From 91bd0ab93627c5d6962dec8d1ebd24bbcf24ed57 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Tue, 3 Mar 2026 22:58:04 +0900 Subject: [PATCH 1/4] [quantization] Introduce wrappers for Qwen3VLTextDecoderLayer and Qwen3VLTextModel - Add `QuantQwen3VLTextDecoderLayer`: wraps attention, MLP, and layernorm blocks; pre-builds static causal mask and RoPE templates to avoid dynamic ops in forward pass - Add `QuantQwen3VLTextModel`: pre-computes shared causal mask and RoPE once and passes them to every decoder layer, so they are quantized exactly once rather than independently in each layer - Register both wrappers in `_CORE_MODULES` Co-Authored-By: Claude Sonnet 4.6 --- .../qwen_vl/test_quant_text_decoder_layer.py | 135 ++++++++++ .../wrappers/qwen_vl/test_quant_text_model.py | 98 +++++++ .../qwen_vl/quant_text_decoder_layer.py | 205 +++++++++++++++ .../wrappers/qwen_vl/quant_text_model.py | 241 ++++++++++++++++++ tico/quantization/wrapq/wrappers/registry.py | 2 + 5 files changed, 681 insertions(+) create mode 100644 test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_decoder_layer.py create mode 100644 test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py create mode 100644 tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py create mode 100644 tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_decoder_layer.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_decoder_layer.py new file mode 100644 index 00000000..0f42b5fd --- /dev/null +++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_decoder_layer.py @@ -0,0 +1,135 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.dtypes import DType +from tico.quantization.wrapq.mode import Mode +from tico.quantization.wrapq.utils.version import has_transformers_for +from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_decoder_layer import ( + QuantQwen3VLTextDecoderLayer, +) + +skip_msg = "required transformers not installed — skipping Qwen3VLTextDecoderLayer tests" + + +@unittest.skipUnless(has_transformers_for("qwen3-vl"), skip_msg) +class TestQuantQwen3VLTextDecoderLayer(unittest.TestCase): + fp_layer: torch.nn.Module + + @classmethod + def setUpClass(cls): + torch.manual_seed(0) + + from transformers.models.qwen3_vl.configuration_qwen3_vl import ( + Qwen3VLTextConfig, + ) + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLTextDecoderLayer, + ) + + cls.cfg = Qwen3VLTextConfig( + hidden_size=16, + intermediate_size=32, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=8, + attention_bias=False, + attention_dropout=0.0, + max_position_embeddings=256, + ) + if not hasattr(cls.cfg, "_attn_implementation"): + setattr(cls.cfg, "_attn_implementation", "eager") + else: + cls.cfg._attn_implementation = "eager" + + cls.fp_layer = Qwen3VLTextDecoderLayer(cls.cfg, layer_idx=0) + + def _rand_rope(self, B: int, S: int): + h = self.cfg.head_dim + emb = torch.randn(B, S, h) + return emb.cos(), emb.sin() + + def test_mode_transitions(self): + qlayer = QuantQwen3VLTextDecoderLayer(self.fp_layer) + self.assertIs(qlayer._mode, Mode.NO_QUANT) + + 
qlayer.enable_calibration() + self.assertIs(qlayer._mode, Mode.CALIB) + + SEQ_LEN = 16 + hidden = torch.randn(1, SEQ_LEN, self.cfg.hidden_size) + _ = qlayer(hidden) + + qlayer.freeze_qparams() + self.assertIs(qlayer._mode, Mode.QUANT) + + def test_forward_diff(self): + qlayer = QuantQwen3VLTextDecoderLayer(self.fp_layer) + qlayer.enable_calibration() + + SEQ_LEN = 16 + for _ in range(4): + hidden = torch.randn(1, SEQ_LEN, self.cfg.hidden_size) + _ = qlayer(hidden) + qlayer.freeze_qparams() + + hidden = torch.randn(1, SEQ_LEN, self.cfg.hidden_size) + pos = self._rand_rope(1, SEQ_LEN) + + mask = torch.full((1, 1, SEQ_LEN, SEQ_LEN), float("-120")) + mask.triu_(1) + + with torch.no_grad(): + q_out = qlayer(hidden) + q_out = q_out[0] if isinstance(q_out, tuple) else q_out + + fp_out = self.fp_layer( + hidden, attention_mask=mask, position_embeddings=pos + ) + fp_out = fp_out[0] if isinstance(fp_out, tuple) else fp_out + + diff = (fp_out - q_out).abs().mean().item() + self.assertGreater(diff, 0.0) + self.assertLess(diff, 0.5) + self.assertEqual(fp_out.shape, q_out.shape) + + def test_with_precomputed_embeddings(self): + """position_embeddings injected from outside (model-level sharing pattern).""" + qlayer = QuantQwen3VLTextDecoderLayer(self.fp_layer) + qlayer.enable_calibration() + + SEQ_LEN = 16 + hidden = torch.randn(1, SEQ_LEN, self.cfg.hidden_size) + pos = self._rand_rope(1, SEQ_LEN) + + mask = torch.full((1, 1, SEQ_LEN, SEQ_LEN), float("-120")) + mask.triu_(1) + + _ = qlayer(hidden, attention_mask=mask, position_embeddings=pos) + qlayer.freeze_qparams() + self.assertIs(qlayer._mode, Mode.QUANT) + + def test_dtype_override(self): + cfg = PTQConfig( + default_dtype=DType.int(16), + overrides={ + "mlp_residual_out": {"dtype": DType.uint(8)}, + }, + ) + qlayer = QuantQwen3VLTextDecoderLayer(self.fp_layer, qcfg=cfg) + self.assertEqual(qlayer.obs_mlp_residual_out.dtype, DType.uint(8)) diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py 
b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py new file mode 100644 index 00000000..4c4d4600 --- /dev/null +++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py @@ -0,0 +1,98 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tico.quantization.wrapq.mode import Mode +from tico.quantization.wrapq.utils.version import has_transformers_for +from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_model import ( + QuantQwen3VLTextModel, +) + +skip_msg = "required transformers not installed — skipping Qwen3VLTextModel tests" + + +@unittest.skipUnless(has_transformers_for("qwen3-vl"), skip_msg) +class TestQuantQwen3VLTextModel(unittest.TestCase): + fp_model: torch.nn.Module + vocab_size: int + seq_len: int + + @classmethod + def setUpClass(cls): + torch.manual_seed(0) + + from transformers.models.qwen3_vl.configuration_qwen3_vl import ( + Qwen3VLTextConfig, + ) + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel + + cls.seq_len = 16 + cls.vocab_size = 512 + + cfg = Qwen3VLTextConfig( + hidden_size=8, + intermediate_size=16, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=4, + num_hidden_layers=2, + attention_bias=False, + attention_dropout=0.0, + max_position_embeddings=cls.seq_len, + vocab_size=cls.vocab_size, + use_cache=False, + return_dict=False, + ) + if not hasattr(cfg, 
"_attn_implementation"): + setattr(cfg, "_attn_implementation", "eager") + else: + cfg._attn_implementation = "eager" + + cls.fp_model = Qwen3VLTextModel(cfg) + + def test_mode_transitions(self): + qmodel = QuantQwen3VLTextModel(self.fp_model) + self.assertIs(qmodel._mode, Mode.NO_QUANT) + + qmodel.enable_calibration() + self.assertIs(qmodel._mode, Mode.CALIB) + + x = torch.randint(0, self.vocab_size, (1, self.seq_len)) + _ = qmodel(x) + + qmodel.freeze_qparams() + self.assertIs(qmodel._mode, Mode.QUANT) + + def test_forward_diff(self): + qmodel = QuantQwen3VLTextModel(self.fp_model) + qmodel.enable_calibration() + + calib_set = [] + for _ in range(4): + inp = torch.randint(0, self.vocab_size, (1, self.seq_len)) + _ = qmodel(inp) + calib_set.append(inp) + qmodel.freeze_qparams() + + with torch.no_grad(): + q_out = qmodel(calib_set[0])[0] + fp_out = self.fp_model(calib_set[0])[0] + + diff = (fp_out - q_out).abs().mean().item() + self.assertGreater(diff, 0.0) + self.assertLess(diff, 0.4) + self.assertEqual(fp_out.shape, q_out.shape) diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py new file mode 100644 index 00000000..3c9acd49 --- /dev/null +++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py @@ -0,0 +1,205 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper +from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase +from tico.quantization.wrapq.wrappers.registry import try_register + + +@try_register("transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextDecoderLayer") +class QuantQwen3VLTextDecoderLayer(QuantModuleBase): + """ + Quant-aware drop-in replacement for HF `Qwen3VLTextDecoderLayer`. + + Attention & MLP blocks are replaced by their quantized counterparts. + A "static" causal mask and RoPE templates are pre-built in `__init__` + to avoid dynamic ops inside `forward`. + + Notes + ----- + - Prefill-only: `use_cache` is not supported because + `QuantQwen3VLTextAttention` does not return KV cache. + - `position_embeddings` can be injected from the parent model-level wrapper + (shared across all layers); if omitted the layer uses its own pre-computed + templates as fallback. 
+ """ + + def __init__( + self, + fp_layer: nn.Module, + *, + qcfg: Optional[PTQConfig] = None, + fp_name: Optional[str] = None, + return_type: Optional[str] = None, + ): + self.return_type = return_type + if self.return_type is None: + import transformers + + v = tuple(map(int, transformers.__version__.split(".")[:2])) + self.return_type = "tensor" if v >= (4, 54) else "tuple" + assert self.return_type is not None + + super().__init__(qcfg, fp_name=fp_name) + + # ----- child configs ------------------------------------------------ + attn_cfg = qcfg.child("self_attn") if qcfg else None + mlp_cfg = qcfg.child("mlp") if qcfg else None + input_ln_cfg = qcfg.child("input_layernorm") if qcfg else None + post_attn_ln_cfg = qcfg.child("post_attention_layernorm") if qcfg else None + + # ----- assertions --------------------------------------------------- + assert hasattr(fp_layer, "self_attn") and isinstance( + fp_layer.self_attn, nn.Module + ) + assert hasattr(fp_layer, "mlp") and isinstance(fp_layer.mlp, nn.Module) + assert hasattr(fp_layer, "input_layernorm") and isinstance( + fp_layer.input_layernorm, nn.Module + ) + assert hasattr(fp_layer, "post_attention_layernorm") and isinstance( + fp_layer.post_attention_layernorm, nn.Module + ) + + # ----- wrap children ------------------------------------------------ + self.self_attn = PTQWrapper( + fp_layer.self_attn, qcfg=attn_cfg, fp_name=f"{fp_name}.self_attn" + ) + self.mlp = PTQWrapper(fp_layer.mlp, qcfg=mlp_cfg, fp_name=f"{fp_name}.mlp") + self.input_layernorm = PTQWrapper( + fp_layer.input_layernorm, + qcfg=input_ln_cfg, + fp_name=f"{fp_name}.input_layernorm", + ) + self.post_attention_layernorm = PTQWrapper( + fp_layer.post_attention_layernorm, + qcfg=post_attn_ln_cfg, + fp_name=f"{fp_name}.post_attention_layernorm", + ) + + # ----- local observers ---------------------------------------------- + self.obs_mlp_residual_out = self._make_obs("mlp_residual_out") + self.obs_causal_mask = self._make_obs("causal_mask") + 
self.obs_cos = self._make_obs("cos") + self.obs_sin = self._make_obs("sin") + + # ----- static buffers: causal mask template ------------------------- + cfg = fp_layer.self_attn.config + assert hasattr(cfg, "max_position_embeddings") + max_seq = cfg.max_position_embeddings + mask = torch.full((1, 1, max_seq, max_seq), float("-120")) + mask.triu_(1) + self.register_buffer("causal_mask_template", mask, persistent=False) + + # ----- static buffers: RoPE templates -------------------------------- + head_dim = getattr(cfg, "head_dim", None) or ( + cfg.hidden_size // cfg.num_attention_heads + ) + + rotary = getattr(fp_layer, "rotary_emb", None) + if rotary is not None and hasattr(rotary, "inv_freq"): + inv_freq = rotary.inv_freq.detach().float() + attn_scaling = float(getattr(rotary, "attention_scaling", 1.0)) + else: + base = float(getattr(cfg, "rope_theta", 10000.0)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) + ) + attn_scaling = 1.0 + + pos = torch.arange(max_seq, dtype=torch.float32) + freqs = torch.outer(pos, inv_freq) + emb = torch.cat([freqs, freqs], dim=-1) + cos_t = emb.cos() * attn_scaling + sin_t = emb.sin() * attn_scaling + half_dim = head_dim // 2 + sin_t[..., :half_dim] = -sin_t[..., :half_dim] + cos_t = cos_t.unsqueeze(0) # [1, max_seq, head_dim] + sin_t = sin_t.unsqueeze(0) # [1, max_seq, head_dim] + + self.register_buffer("rope_cos_template", cos_t, persistent=False) + self.register_buffer("rope_sin_template", sin_t, persistent=False) + + def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor: + assert isinstance(self.causal_mask_template, torch.Tensor) + return self.causal_mask_template[..., :seq_len, :seq_len].to(device) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> 
Tuple[torch.Tensor] | torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Build causal mask if not provided (or provided as bool) + if attention_mask is None or attention_mask.dtype == torch.bool: + L = hidden_states.size(1) + attention_mask = self._slice_causal(L, hidden_states.device) + attention_mask = attention_mask.squeeze(0) + attention_mask = self._fq(attention_mask, self.obs_causal_mask) + + # Build position embeddings if not provided + if position_embeddings is None: + position_embeddings = ( + self.rope_cos_template.to( + dtype=hidden_states.dtype, device=hidden_states.device + ), + self.rope_sin_template.to( + dtype=hidden_states.dtype, device=hidden_states.device + ), + ) + cos, sin = position_embeddings + position_embeddings = ( + self._fq(cos, self.obs_cos), + self._fq(sin, self.obs_sin), + ) + + # Attention block + # QuantQwen3VLTextAttention returns (out, attn_weights) + attn_out = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + ) + hidden_states = residual + attn_out[0] + + # MLP block + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self._fq(residual + hidden_states, self.obs_mlp_residual_out) + + if self.return_type == "tuple": + return (hidden_states,) + elif self.return_type == "tensor": + return hidden_states + else: + raise RuntimeError(f"Invalid return_type: {self.return_type!r}") + + def _all_observers(self): + yield from (self.obs_causal_mask, self.obs_cos, self.obs_sin) + yield from self.self_attn._all_observers() + yield from self.mlp._all_observers() + yield from self.input_layernorm._all_observers() + yield from self.post_attention_layernorm._all_observers() + yield self.obs_mlp_residual_out diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py 
b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py new file mode 100644 index 00000000..2d9f797b --- /dev/null +++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py @@ -0,0 +1,241 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers.modeling_outputs import BaseModelOutputWithPast + +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper +from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase +from tico.quantization.wrapq.wrappers.registry import try_register + + +@try_register("transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextModel") +class QuantQwen3VLTextModel(QuantModuleBase): + """ + Quant-aware drop-in replacement for the Qwen3-VL language model text backbone + (the `language_model` sub-module inside `Qwen3VLModel`). + + Pre-computes shared RoPE templates and a static causal mask once in `__init__`, + then passes them to every decoder layer so they are quantized exactly once + rather than independently in each layer. 
+ """ + + def __init__( + self, + model_fp: nn.Module, + *, + qcfg: Optional[PTQConfig] = None, + fp_name: Optional[str] = None, + ): + super().__init__(qcfg, fp_name=fp_name) + + # ----- child configs ------------------------------------------------ + embed_cfg = qcfg.child("embed_tokens") if qcfg else None + norm_cfg = qcfg.child("norm") if qcfg else None + layers_cfg = qcfg.child("layers") if qcfg else None + + # ----- assertions --------------------------------------------------- + assert hasattr(model_fp, "embed_tokens") and isinstance( + model_fp.embed_tokens, nn.Module + ) + assert hasattr(model_fp, "norm") and isinstance(model_fp.norm, nn.Module) + assert hasattr(model_fp, "layers") and isinstance( + model_fp.layers, nn.ModuleList + ) + + # ----- wrap children ------------------------------------------------ + self.embed_tokens = PTQWrapper( + model_fp.embed_tokens, embed_cfg, fp_name=f"{fp_name}.embed_tokens" + ) + self.norm = PTQWrapper(model_fp.norm, norm_cfg, fp_name=f"{fp_name}.norm") + + new_list = nn.ModuleList() + for idx, layer in enumerate(model_fp.layers): + child_scope = f"{idx}" + child_cfg = ( + layers_cfg.child(child_scope) if layers_cfg is not None else None + ) + new_list.append( + PTQWrapper(layer, child_cfg, fp_name=child_scope) + ) + self.layers = new_list + + # ----- local observers ---------------------------------------------- + self.obs_causal_mask = self._make_obs("causal_mask") + self.obs_cos = self._make_obs("cos") + self.obs_sin = self._make_obs("sin") + + self.config = model_fp.config + + # ----- static buffers: causal mask template ------------------------- + assert isinstance(self.config.max_position_embeddings, int) + max_seq = self.config.max_position_embeddings + mask = torch.full((1, 1, max_seq, max_seq), float("-120")) + mask.triu_(1) + self.register_buffer("causal_mask_template", mask, persistent=False) + + # ----- static buffers: RoPE templates -------------------------------- + head_dim = getattr(self.config, 
"head_dim", None) or ( + self.config.hidden_size // self.config.num_attention_heads + ) + + rotary = getattr(model_fp, "rotary_emb", None) + assert rotary is not None, ( + "Qwen3VLTextModel must have a `rotary_emb` attribute for RoPE pre-computation" + ) + if hasattr(rotary, "inv_freq"): + inv_freq = rotary.inv_freq.detach().float() + attn_scaling = float(getattr(rotary, "attention_scaling", 1.0)) + else: + base = float(getattr(self.config, "rope_theta", 10000.0)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) + ) + attn_scaling = 1.0 + + pos = torch.arange(max_seq, dtype=torch.float32, device=inv_freq.device) + freqs = torch.outer(pos, inv_freq) + emb = torch.cat([freqs, freqs], dim=-1) + cos_t = emb.cos() * attn_scaling + sin_t = emb.sin() * attn_scaling + half_dim = head_dim // 2 + sin_t[..., :half_dim] = -sin_t[..., :half_dim] + cos_t = cos_t.unsqueeze(0) # [1, max_seq, head_dim] + sin_t = sin_t.unsqueeze(0) # [1, max_seq, head_dim] + + self.register_buffer("rope_cos_template", cos_t, persistent=False) + self.register_buffer("rope_sin_template", sin_t, persistent=False) + + def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor: + assert isinstance(self.causal_mask_template, torch.Tensor) + return self.causal_mask_template[..., :seq_len, :seq_len].to(device) + + def get_attention_mask_for(self, hidden_states: torch.Tensor) -> torch.Tensor: + L = hidden_states.size(1) + return self._slice_causal(L, hidden_states.device) + + def get_position_embeddings_for( + self, hidden_states: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + return ( + self.rope_cos_template.to( # type: ignore[union-attr] + dtype=hidden_states.dtype, device=hidden_states.device + ), + self.rope_sin_template.to( # type: ignore[union-attr] + dtype=hidden_states.dtype, device=hidden_states.device + ), + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = 
None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # Pre-compute shared causal mask and RoPE (quantized once, shared across layers) + causal_mask = self.get_attention_mask_for(hidden_states) + causal_mask = causal_mask.squeeze(0) + causal_mask = self._fq(causal_mask, self.obs_causal_mask) + + position_embeddings = self.get_position_embeddings_for(hidden_states) + cos, sin = position_embeddings + position_embeddings = ( + self._fq(cos, self.obs_cos), + self._fq(sin, self.obs_sin), + ) + + # Decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) # type: ignore[operator] + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + ) + + if hasattr(decoder_layer, "wrapped") and hasattr( + decoder_layer.wrapped, "return_type" + ): + if decoder_layer.wrapped.return_type == "tuple": + hidden_states 
= layer_outputs[0] + else: + hidden_states = layer_outputs + else: + hidden_states = ( + layer_outputs[0] + if isinstance(layer_outputs, tuple) + else layer_outputs + ) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) # type: ignore[operator] + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _all_observers(self): + yield from (self.obs_causal_mask, self.obs_cos, self.obs_sin) + for m in (self.embed_tokens, self.norm): + yield from m._all_observers() + for m in self.layers: + yield from m._all_observers() diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py index 0359a090..22c6b67e 100644 --- a/tico/quantization/wrapq/wrappers/registry.py +++ b/tico/quantization/wrapq/wrappers/registry.py @@ -47,6 +47,8 @@ ## qwen_vl ## "tico.quantization.wrapq.wrappers.qwen_vl.quant_text_attn", "tico.quantization.wrapq.wrappers.qwen_vl.quant_text_mlp", + "tico.quantization.wrapq.wrappers.qwen_vl.quant_text_decoder_layer", + "tico.quantization.wrapq.wrappers.qwen_vl.quant_text_model", "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_attn", "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp", "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_embed", From 4a2e8e45057808b9aacb69ed1d5557d7cb282d60 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Wed, 11 Mar 2026 16:13:42 +0900 Subject: [PATCH 2/4] Add examples --- .../qwen/quantize_text_decoder_layer.py | 136 +++++++++++++++++ .../examples/qwen/quantize_text_model.py | 140 ++++++++++++++++++ 2 files changed, 276 insertions(+) create mode 100644 tico/quantization/wrapq/examples/qwen/quantize_text_decoder_layer.py create mode 100644 tico/quantization/wrapq/examples/qwen/quantize_text_model.py diff --git 
a/tico/quantization/wrapq/examples/qwen/quantize_text_decoder_layer.py b/tico/quantization/wrapq/examples/qwen/quantize_text_decoder_layer.py new file mode 100644 index 00000000..1ea3b0b6 --- /dev/null +++ b/tico/quantization/wrapq/examples/qwen/quantize_text_decoder_layer.py @@ -0,0 +1,136 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pathlib + +import torch +from transformers import AutoModelForImageTextToText, AutoTokenizer + +from tico.quantization import convert, prepare +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.evaluation.metric import compute_peir +from tico.quantization.evaluation.utils import plot_two_outputs +from tico.quantization.wrapq.mode import Mode +from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_decoder_layer import ( + QuantQwen3VLTextDecoderLayer, +) +from tico.utils.utils import SuppressWarning + +# ------------------------------------------------------------------------- +# 0. 
Load a Qwen3-VL model (text tower) + tokenizer +# ------------------------------------------------------------------------- +name = "Qwen/Qwen3-VL-2B-Instruct" +model = AutoModelForImageTextToText.from_pretrained( + name, + device_map="cpu", + trust_remote_code=True, + dtype=torch.float32, +) +tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True) + +if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + +MAX_SEQ = 128 +text_cfg = model.config.text_config +text_cfg.max_position_embeddings = MAX_SEQ + +# ------------------------------------------------------------------------- +# 1. Wrap layer-0's decoder layer with QuantQwen3VLTextDecoderLayer +# +# QuantQwen3VLTextDecoderLayer pre-computes static causal mask and RoPE +# templates internally, so calibration only requires hidden_states input. +# ------------------------------------------------------------------------- +orig_layer = model.model.language_model.layers[0] +model.model.language_model.layers[0] = prepare(orig_layer, PTQConfig()) +model.eval() + +layer_q = model.model.language_model.layers[0] +assert isinstance(layer_q.wrapped, QuantQwen3VLTextDecoderLayer) + +# ------------------------------------------------------------------------- +# Helpers: tokenize → embed to get hidden states for calibration +# ------------------------------------------------------------------------- +def make_hidden(prompt: str) -> torch.Tensor: + batch = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=MAX_SEQ, + ) + with torch.no_grad(): + return model.model.language_model.embed_tokens(batch["input_ids"]) + + +# ------------------------------------------------------------------------- +# 2. Calibration +# ------------------------------------------------------------------------- +PROMPTS = [ + "The quick brown fox jumps over the lazy dog.", + "In 2025, AI systems accelerated hardware-software co-design at scale.", + "양자화는 왜 어려울까? 
분포, 길이, 마스크가 관건이다.", + "今日はいい天気ですね。ところでRoPE角度は長さに依存します。", + "def quicksort(arr):\n if len(arr) <= 1: return arr\n ...", + "Prices rose 3.14% — see Figure 2; emails: foo@bar.com!", +] + +with torch.no_grad(): + for prompt in PROMPTS: + hidden = make_hidden(prompt) + # position_embeddings and attention_mask are built internally + _ = layer_q(hidden) + +convert(layer_q) +assert layer_q._mode is Mode.QUANT, "Quantization mode should be active now." + +# ------------------------------------------------------------------------- +# 3. Quick diff check (INT-sim vs FP32) +# ------------------------------------------------------------------------- +hidden = make_hidden("check") + +mask = torch.full((1, 1, MAX_SEQ, MAX_SEQ), float("-120")) +mask.triu_(1) + +rotary = model.model.language_model.rotary_emb +position_ids = torch.arange(MAX_SEQ).unsqueeze(0) + +with torch.no_grad(): + q_out = layer_q(hidden) + q_out = q_out[0] if isinstance(q_out, tuple) else q_out + + pos = rotary(hidden, position_ids) + fp_out = orig_layer(hidden, attention_mask=mask, position_embeddings=pos) + fp_out = fp_out[0] if isinstance(fp_out, tuple) else fp_out + +print("┌───────────── Quantization Error Summary ─────────────") +print(f"│ Mean |diff|: {(q_out - fp_out).abs().mean().item():.6f}") +print(f"│ PEIR : {compute_peir(fp_out, q_out) * 100:.6f} %") +print("└──────────────────────────────────────────────────────") +print(plot_two_outputs(fp_out, q_out)) + +# ------------------------------------------------------------------------- +# 4. 
Export the quantized decoder layer to Circle +# ------------------------------------------------------------------------- +import tico + +save_path = pathlib.Path("qwen3vl_text_decoder_layer.q.circle") +B, S, D = 1, MAX_SEQ, text_cfg.hidden_size +example_hidden = torch.randn(B, S, D) + +with SuppressWarning(UserWarning, ".*"): + cm = tico.convert(layer_q, (example_hidden,)) +cm.save(save_path) + +print(f"Quantized Circle model saved to {save_path.resolve()}") diff --git a/tico/quantization/wrapq/examples/qwen/quantize_text_model.py b/tico/quantization/wrapq/examples/qwen/quantize_text_model.py new file mode 100644 index 00000000..526bf0f6 --- /dev/null +++ b/tico/quantization/wrapq/examples/qwen/quantize_text_model.py @@ -0,0 +1,140 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers import AutoModelForImageTextToText, AutoTokenizer + +from tico.quantization import convert, prepare +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.evaluation.metric import compute_peir +from tico.quantization.evaluation.utils import plot_two_outputs +from tico.quantization.wrapq.mode import Mode +from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_model import ( + QuantQwen3VLTextModel, +) + +# ------------------------------------------------------------------------- +# 0. 
Load a Qwen3-VL model + tokenizer +# ------------------------------------------------------------------------- +name = "Qwen/Qwen3-VL-2B-Instruct" +model = AutoModelForImageTextToText.from_pretrained( + name, + device_map="cpu", + trust_remote_code=True, + dtype=torch.float32, +) +tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True) + +if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + +MAX_SEQ = 128 +text_cfg = model.config.text_config +text_cfg.max_position_embeddings = MAX_SEQ + +# ------------------------------------------------------------------------- +# 1. Wrap the language model backbone with QuantQwen3VLTextModel +# +# QuantQwen3VLTextModel replaces the text backbone (language_model) and: +# - Pre-computes a shared static causal mask +# - Pre-computes shared RoPE cos/sin templates (sliced per seq_len) +# - Passes them to every decoder layer once, avoiding redundant computation +# ------------------------------------------------------------------------- +orig_lm = model.model.language_model +model.model.language_model = prepare(orig_lm, PTQConfig()) +model.eval() + +lm_q = model.model.language_model +assert isinstance(lm_q.wrapped, QuantQwen3VLTextModel) + +# ------------------------------------------------------------------------- +# Helpers: fixed-length tokenize → input_ids +# ------------------------------------------------------------------------- +def make_input_ids(prompt: str) -> torch.Tensor: + batch = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=MAX_SEQ, + ) + return batch["input_ids"] + + +# ------------------------------------------------------------------------- +# 2. Calibration +# ------------------------------------------------------------------------- +PROMPTS = [ + "The quick brown fox jumps over the lazy dog.", + "In 2025, AI systems accelerated hardware-software co-design at scale.", + "양자화는 왜 어려울까? 
분포, 길이, 마스크가 관건이다.", + "今日はいい天気ですね。ところでRoPE角度は長さに依存します。", + "def quicksort(arr):\n if len(arr) <= 1: return arr\n ...", + "Prices rose 3.14% — see Figure 2; emails: foo@bar.com!", +] + +with torch.no_grad(): + for prompt in PROMPTS: + input_ids = make_input_ids(prompt) + _ = lm_q(input_ids) + +convert(lm_q) +assert lm_q._mode is Mode.QUANT, "Quantization mode should be active now." + +# ------------------------------------------------------------------------- +# 3. Quick diff check (INT-sim vs FP32) +# ------------------------------------------------------------------------- +input_ids = make_input_ids("check") + +with torch.no_grad(): + q_out = lm_q(input_ids, return_dict=False)[0] # last_hidden_state + fp_out = orig_lm(input_ids, return_dict=False)[0] + +print("┌───────────── Quantization Error Summary ─────────────") +print(f"│ Mean |diff|: {(q_out - fp_out).abs().mean().item():.6f}") +print(f"│ PEIR : {compute_peir(fp_out, q_out) * 100:.6f} %") +print("└──────────────────────────────────────────────────────") +print(plot_two_outputs(fp_out, q_out)) + +# Note on PEIR at the model level +# -------------------------------- +# PEIR here measures the L2 divergence of the full text backbone output, not a +# single layer. It is expected to be significantly higher than per-layer PEIR +# (e.g. QuantQwen3VLTextAttention) for two reasons: +# +# 1. Error accumulation: quantization errors from every decoder layer compound +# multiplicatively. Even 1-2 % per layer becomes tens of percent over 28+ +# layers. +# +# 2. Padding in calibration data: the static causal mask used by this wrapper is +# a pure upper-triangular mask and does not mask padding tokens. Calibration +# sequences padded to MAX_SEQ therefore attend to padding positions that the +# original model would have masked, inflating the calibration statistics and +# the resulting PEIR. +# +# PEIR is not a direct proxy for task accuracy. Use downstream metrics +# (e.g. 
perplexity, VQA score) to evaluate the quantized model's real accuracy. + +# ------------------------------------------------------------------------- +# Note on Circle export +# ------------------------------------------------------------------------- +# Exporting QuantQwen3VLTextModel directly to Circle is not shown here +# because torch.export.export() requires the model to return a plain +# Tensor or a flat tuple of Tensors. The current forward() returns +# BaseModelOutputWithPast (a dict-like ModelOutput, not a tuple), which +# requires a thin adapter wrapper before calling tico.convert(). +# +# For subgraph-level export, see the individual layer examples: +# - quantize_text_attn.py +# - quantize_text_decoder_layer.py From e71a9b1680e5b9457109de992d565b86de999314 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Wed, 11 Mar 2026 16:14:03 +0900 Subject: [PATCH 3/4] Improve PEIR by fixing rope and padding --- .../wrappers/qwen_vl/test_quant_text_model.py | 8 +- .../examples/qwen/quantize_text_model.py | 32 ++++--- .../qwen_vl/quant_text_decoder_layer.py | 45 +++++---- .../wrappers/qwen_vl/quant_text_model.py | 93 +++++++++++-------- 4 files changed, 107 insertions(+), 71 deletions(-) diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py index 4c4d4600..ee36baa7 100644 --- a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py +++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py @@ -43,12 +43,13 @@ def setUpClass(cls): cls.seq_len = 16 cls.vocab_size = 512 + # head_dim=8 → head_dim//2=4; mrope_section must sum to head_dim//2 cfg = Qwen3VLTextConfig( - hidden_size=8, - intermediate_size=16, + hidden_size=16, + intermediate_size=32, num_attention_heads=2, num_key_value_heads=1, - head_dim=4, + head_dim=8, num_hidden_layers=2, attention_bias=False, attention_dropout=0.0, @@ -56,6 +57,7 @@ def setUpClass(cls): vocab_size=cls.vocab_size, use_cache=False,
return_dict=False, + rope_scaling={"rope_type": "default", "mrope_section": [1, 1, 2]}, ) if not hasattr(cfg, "_attn_implementation"): setattr(cfg, "_attn_implementation", "eager") diff --git a/tico/quantization/wrapq/examples/qwen/quantize_text_model.py b/tico/quantization/wrapq/examples/qwen/quantize_text_model.py index 526bf0f6..cee80a01 100644 --- a/tico/quantization/wrapq/examples/qwen/quantize_text_model.py +++ b/tico/quantization/wrapq/examples/qwen/quantize_text_model.py @@ -59,9 +59,14 @@ assert isinstance(lm_q.wrapped, QuantQwen3VLTextModel) # ------------------------------------------------------------------------- -# Helpers: fixed-length tokenize → input_ids +# Helpers: fixed-length tokenize → (input_ids, attention_mask) +# +# We return the 2D attention_mask alongside input_ids. Passing it to the +# model lets QuantQwen3VLTextModel combine the static causal mask with a +# per-sequence padding correction, so padding tokens are masked out during +# both calibration and quantized inference. # ------------------------------------------------------------------------- -def make_input_ids(prompt: str) -> torch.Tensor: +def make_inputs(prompt: str): batch = tokenizer( prompt, return_tensors="pt", @@ -69,7 +74,7 @@ def make_input_ids(prompt: str) -> torch.Tensor: truncation=True, max_length=MAX_SEQ, ) - return batch["input_ids"] + return batch["input_ids"], batch["attention_mask"] # ------------------------------------------------------------------------- @@ -86,8 +91,8 @@ def make_input_ids(prompt: str) -> torch.Tensor: with torch.no_grad(): for prompt in PROMPTS: - input_ids = make_input_ids(prompt) - _ = lm_q(input_ids) + input_ids, attention_mask = make_inputs(prompt) + _ = lm_q(input_ids, attention_mask=attention_mask) convert(lm_q) assert lm_q._mode is Mode.QUANT, "Quantization mode should be active now." @@ -95,11 +100,11 @@ def make_input_ids(prompt: str) -> torch.Tensor: # ------------------------------------------------------------------------- # 3. 
Quick diff check (INT-sim vs FP32) # ------------------------------------------------------------------------- -input_ids = make_input_ids("check") +input_ids, attention_mask = make_inputs("check") with torch.no_grad(): - q_out = lm_q(input_ids, return_dict=False)[0] # last_hidden_state - fp_out = orig_lm(input_ids, return_dict=False)[0] + q_out = lm_q(input_ids, attention_mask=attention_mask, return_dict=False)[0] + fp_out = orig_lm(input_ids, attention_mask=attention_mask, return_dict=False)[0] print("┌───────────── Quantization Error Summary ─────────────") print(f"│ Mean |diff|: {(q_out - fp_out).abs().mean().item():.6f}") @@ -117,11 +122,12 @@ def make_input_ids(prompt: str) -> torch.Tensor: # multiplicatively. Even 1-2 % per layer becomes tens of percent over 28+ # layers. # -# 2. Padding in calibration data: the static causal mask used by this wrapper is -# a pure upper-triangular mask and does not mask padding tokens. Calibration -# sequences padded to MAX_SEQ therefore attend to padding positions that the -# original model would have masked, inflating the calibration statistics and -# the resulting PEIR. +# 2. Padding in calibration data: the static causal mask baked into the +# quantized model carries no sequence-specific padding information, so the +# 2D attention_mask is passed explicitly and applied as an additive padding +# correction. This keeps padding out of the observer statistics during +# calibration. The PEIR comparison above passes the same attention_mask to +# orig_lm, so both models see the same mask and the measurement is fair. # # PEIR is not a direct proxy for task accuracy. Use downstream metrics # (e.g. perplexity, VQA score) to evaluate the quantized model's real accuracy.
diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py index 3c9acd49..228673ff 100644 --- a/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py +++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py @@ -29,16 +29,31 @@ class QuantQwen3VLTextDecoderLayer(QuantModuleBase): Quant-aware drop-in replacement for HF `Qwen3VLTextDecoderLayer`. Attention & MLP blocks are replaced by their quantized counterparts. - A "static" causal mask and RoPE templates are pre-built in `__init__` - to avoid dynamic ops inside `forward`. + A static causal mask and RoPE templates are pre-built in `__init__` to avoid + dynamic ops inside `forward`. Notes ----- - - Prefill-only: `use_cache` is not supported because - `QuantQwen3VLTextAttention` does not return KV cache. - - `position_embeddings` can be injected from the parent model-level wrapper - (shared across all layers); if omitted the layer uses its own pre-computed - templates as fallback. + Prefill-only + `use_cache` is not supported because `QuantQwen3VLTextAttention` does not + return KV cache. + + position_embeddings injection + When used inside `QuantQwen3VLTextModel`, the parent wrapper calls the + model's actual `Qwen3VLTextRotaryEmbedding` and passes the resulting + (cos, sin) into every decoder layer. This is the recommended path as it + produces exact MRoPE position encodings. + + Standalone RoPE fallback limitation + When `position_embeddings=None` (standalone usage without a parent model + wrapper), this layer falls back to a simplified 1D cos/sin table computed + from `rope_theta`. For Qwen3-VL this is an approximation: the model + actually uses MRoPE where `inv_freq` is split across three sections + (temporal, height, width) by `mrope_section`. The simplified table + assigns frequencies uniformly and therefore does not match the original + model's position encodings exactly. 
Calibration PEIR will be higher in + standalone mode for this reason; use the parent `QuantQwen3VLTextModel` + wrapper for accurate calibration. """ def __init__( @@ -158,17 +173,15 @@ def forward( attention_mask = attention_mask.squeeze(0) attention_mask = self._fq(attention_mask, self.obs_causal_mask) - # Build position embeddings if not provided + # Build position embeddings if not provided; slice to actual seq_len if position_embeddings is None: - position_embeddings = ( - self.rope_cos_template.to( - dtype=hidden_states.dtype, device=hidden_states.device - ), - self.rope_sin_template.to( - dtype=hidden_states.dtype, device=hidden_states.device - ), + S = hidden_states.size(1) + cos = self.rope_cos_template[:, :S, :].to( # type: ignore[index] + dtype=hidden_states.dtype, device=hidden_states.device + ) + sin = self.rope_sin_template[:, :S, :].to( # type: ignore[index] + dtype=hidden_states.dtype, device=hidden_states.device ) - cos, sin = position_embeddings position_embeddings = ( self._fq(cos, self.obs_cos), self._fq(sin, self.obs_sin), diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py index 2d9f797b..8a6c01e9 100644 --- a/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py +++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py @@ -31,9 +31,29 @@ class QuantQwen3VLTextModel(QuantModuleBase): Quant-aware drop-in replacement for the Qwen3-VL language model text backbone (the `language_model` sub-module inside `Qwen3VLModel`). - Pre-computes shared RoPE templates and a static causal mask once in `__init__`, - then passes them to every decoder layer so they are quantized exactly once - rather than independently in each layer. + Computes shared RoPE position embeddings and a static causal mask once per + forward pass, then passes them to every decoder layer so they are quantized + exactly once rather than independently in each layer. 
+ + RoPE note + --------- + Position embeddings are computed by calling the model's own `rotary_emb` module + (Qwen3VLTextRotaryEmbedding) directly in `forward()`, rather than using a + simplified 1D pre-computed table. This is important for accuracy because + Qwen3-VL uses MRoPE: `inv_freq` is split into three sections (temporal, height, + width) via `mrope_section`, which means a naively-computed 1D cos/sin table + would differ from the model's actual position encodings and accumulate error + across all decoder layers. + + Known PEIR limitation + --------------------- + PEIR measured on this full-backbone wrapper will be higher than for single-layer + wrappers (e.g. QuantQwen3VLTextAttention) because quantization errors from all + decoder layers compound multiplicatively. This is expected behavior and does not + reflect task-level accuracy degradation, which should be measured via downstream + metrics (e.g. perplexity). Additionally, the static causal mask alone does not + account for padding tokens; unless a 2D `attention_mask` is passed to `forward()`, + padded calibration sequences can inflate PEIR further. """ def __init__( @@ -90,37 +110,17 @@ def __init__( mask.triu_(1) self.register_buffer("causal_mask_template", mask, persistent=False) - # ----- static buffers: RoPE templates -------------------------------- - head_dim = getattr(self.config, "head_dim", None) or ( - self.config.hidden_size // self.config.num_attention_heads - ) - + # ----- rotary embedding module --------------------------------------- + # Store the model's actual rotary_emb (Qwen3VLTextRotaryEmbedding) and + # call it dynamically in forward() to get exact MRoPE position embeddings. + # Using a simplified 1D pre-computed table would mis-assign frequencies + # across mrope_section boundaries, producing wrong cos/sin and degrading + # calibration accuracy significantly.
rotary = getattr(model_fp, "rotary_emb", None) assert rotary is not None, ( - "Qwen3VLTextModel must have a `rotary_emb` attribute for RoPE pre-computation" + "Qwen3VLTextModel must have a `rotary_emb` attribute" ) - if hasattr(rotary, "inv_freq"): - inv_freq = rotary.inv_freq.detach().float() - attn_scaling = float(getattr(rotary, "attention_scaling", 1.0)) - else: - base = float(getattr(self.config, "rope_theta", 10000.0)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) - ) - attn_scaling = 1.0 - - pos = torch.arange(max_seq, dtype=torch.float32, device=inv_freq.device) - freqs = torch.outer(pos, inv_freq) - emb = torch.cat([freqs, freqs], dim=-1) - cos_t = emb.cos() * attn_scaling - sin_t = emb.sin() * attn_scaling - half_dim = head_dim // 2 - sin_t[..., :half_dim] = -sin_t[..., :half_dim] - cos_t = cos_t.unsqueeze(0) # [1, max_seq, head_dim] - sin_t = sin_t.unsqueeze(0) # [1, max_seq, head_dim] - - self.register_buffer("rope_cos_template", cos_t, persistent=False) - self.register_buffer("rope_sin_template", sin_t, persistent=False) + self.rotary_emb = rotary def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor: assert isinstance(self.causal_mask_template, torch.Tensor) @@ -133,14 +133,11 @@ def get_attention_mask_for(self, hidden_states: torch.Tensor) -> torch.Tensor: def get_position_embeddings_for( self, hidden_states: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - return ( - self.rope_cos_template.to( # type: ignore[union-attr] - dtype=hidden_states.dtype, device=hidden_states.device - ), - self.rope_sin_template.to( # type: ignore[union-attr] - dtype=hidden_states.dtype, device=hidden_states.device - ), - ) + # Delegate to the model's actual Qwen3VLTextRotaryEmbedding so that + # MRoPE frequencies are split correctly by mrope_section. 
+ S = hidden_states.size(1) + position_ids = torch.arange(S, device=hidden_states.device).unsqueeze(0) + return self.rotary_emb(hidden_states, position_ids) def forward( self, @@ -180,10 +177,28 @@ def forward( hidden_states = inputs_embeds # Pre-compute shared causal mask and RoPE (quantized once, shared across layers) + # + # The static causal part is quantized once and shared across all decoder layers + # (suitable for on-device baking as a constant). + # + # If a 2D padding mask (attention_mask: [B, L], 1=real token, 0=padding) is + # provided, a padding correction is applied additively on top of the quantized + # causal mask. This ensures that calibration statistics are not inflated by + # activations at padding positions. The padding correction is NOT quantized + # separately — it is applied as a float addend at runtime. causal_mask = self.get_attention_mask_for(hidden_states) causal_mask = causal_mask.squeeze(0) causal_mask = self._fq(causal_mask, self.obs_causal_mask) + if attention_mask is not None and attention_mask.ndim == 2: + # attention_mask: [B, L], 1 = real token, 0 = padding + # Build additive correction: 0 for real, -120 for padding → [B, 1, L] + pad_corr = (1.0 - attention_mask.to(dtype=hidden_states.dtype, + device=hidden_states.device)) + pad_corr = pad_corr * float("-120") + pad_corr = pad_corr.unsqueeze(1) # [B, 1, L] + causal_mask = causal_mask + pad_corr # [B, L, L] (broadcast) + position_embeddings = self.get_position_embeddings_for(hidden_states) cos, sin = position_embeddings position_embeddings = ( From 2c996916a650b94a57c518a9d8a6c0e3ec3d39bc Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Wed, 11 Mar 2026 16:41:08 +0900 Subject: [PATCH 4/4] Add padding masking test --- .../wrappers/qwen_vl/test_quant_text_model.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py index 
ee36baa7..e24ee973 100644 --- a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py +++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py @@ -98,3 +98,28 @@ def test_forward_diff(self): self.assertGreater(diff, 0.0) self.assertLess(diff, 0.4) self.assertEqual(fp_out.shape, q_out.shape) + + def test_with_padding_mask(self): + """Verify that a 2D padding attention_mask is accepted and changes the output. + + When padding tokens are masked out, the forward result differs from the + causal-only (no padding mask) output because padding positions no longer + contribute to attention weights. + """ + qmodel = QuantQwen3VLTextModel(self.fp_model) + qmodel.enable_calibration() + + inp = torch.randint(0, self.vocab_size, (1, self.seq_len)) + + # Simulate padding: last quarter of positions are padding tokens + pad_start = self.seq_len * 3 // 4 + attn_mask = torch.ones(1, self.seq_len, dtype=torch.long) + attn_mask[:, pad_start:] = 0 # mark as padding + + with torch.no_grad(): + out_with_mask = qmodel(inp, attention_mask=attn_mask)[0] + out_no_mask = qmodel(inp)[0] + + # Outputs should differ because masked attention changes activations + diff = (out_with_mask - out_no_mask).abs().max().item() + self.assertGreater(diff, 0.0)