Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.dtypes import DType
from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.utils.version import has_transformers_for
from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_decoder_layer import (
QuantQwen3VLTextDecoderLayer,
)

skip_msg = "required transformers not installed — skipping Qwen3VLTextDecoderLayer tests"


@unittest.skipUnless(has_transformers_for("qwen3-vl"), skip_msg)
class TestQuantQwen3VLTextDecoderLayer(unittest.TestCase):
    """Behavioral tests for the QuantQwen3VLTextDecoderLayer wrapper.

    Uses a deliberately tiny Qwen3-VL text config so each test runs in
    milliseconds while still exercising the full calibrate → freeze →
    quantized-forward lifecycle.
    """

    fp_layer: torch.nn.Module

    @classmethod
    def setUpClass(cls):
        torch.manual_seed(0)

        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
            Qwen3VLTextConfig,
        )
        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
            Qwen3VLTextDecoderLayer,
        )

        # Minimal config: 2 heads × head_dim 8 on a 16-dim hidden state.
        cls.cfg = Qwen3VLTextConfig(
            hidden_size=16,
            intermediate_size=32,
            num_attention_heads=2,
            num_key_value_heads=1,
            head_dim=8,
            attention_bias=False,
            attention_dropout=0.0,
            max_position_embeddings=256,
        )
        # Force the eager attention path for deterministic reference outputs,
        # regardless of how this transformers version initializes the field.
        if hasattr(cls.cfg, "_attn_implementation"):
            cls.cfg._attn_implementation = "eager"
        else:
            setattr(cls.cfg, "_attn_implementation", "eager")

        cls.fp_layer = Qwen3VLTextDecoderLayer(cls.cfg, layer_idx=0)

    def _rand_rope(self, B: int, S: int):
        """Return random (cos, sin) RoPE tensors shaped (B, S, head_dim)."""
        angles = torch.randn(B, S, self.cfg.head_dim)
        return angles.cos(), angles.sin()

    def test_mode_transitions(self):
        """Wrapper walks NO_QUANT → CALIB → QUANT."""
        qlayer = QuantQwen3VLTextDecoderLayer(self.fp_layer)
        self.assertIs(qlayer._mode, Mode.NO_QUANT)

        qlayer.enable_calibration()
        self.assertIs(qlayer._mode, Mode.CALIB)

        seq_len = 16
        # One calibration pass; mask/RoPE are built inside the wrapper.
        _ = qlayer(torch.randn(1, seq_len, self.cfg.hidden_size))

        qlayer.freeze_qparams()
        self.assertIs(qlayer._mode, Mode.QUANT)

    def test_forward_diff(self):
        """Quantized output differs from FP32, but only within tolerance."""
        qlayer = QuantQwen3VLTextDecoderLayer(self.fp_layer)
        qlayer.enable_calibration()

        seq_len = 16
        for _ in range(4):
            _ = qlayer(torch.randn(1, seq_len, self.cfg.hidden_size))
        qlayer.freeze_qparams()

        hidden = torch.randn(1, seq_len, self.cfg.hidden_size)
        rope = self._rand_rope(1, seq_len)

        # Additive causal mask for the FP32 reference: large negative value
        # on strictly-upper-triangular (future) positions, zero elsewhere.
        causal = torch.full((1, 1, seq_len, seq_len), float("-120"))
        causal.triu_(1)

        with torch.no_grad():
            q_out = qlayer(hidden)
            if isinstance(q_out, tuple):
                q_out = q_out[0]

            fp_out = self.fp_layer(
                hidden, attention_mask=causal, position_embeddings=rope
            )
            if isinstance(fp_out, tuple):
                fp_out = fp_out[0]

        diff = (fp_out - q_out).abs().mean().item()
        self.assertGreater(diff, 0.0)
        self.assertLess(diff, 0.5)
        self.assertEqual(fp_out.shape, q_out.shape)

    def test_with_precomputed_embeddings(self):
        """position_embeddings injected from outside (model-level sharing pattern)."""
        qlayer = QuantQwen3VLTextDecoderLayer(self.fp_layer)
        qlayer.enable_calibration()

        seq_len = 16
        hidden = torch.randn(1, seq_len, self.cfg.hidden_size)
        rope = self._rand_rope(1, seq_len)

        causal = torch.full((1, 1, seq_len, seq_len), float("-120"))
        causal.triu_(1)

        # Calibrate with externally supplied mask + RoPE, then freeze.
        _ = qlayer(hidden, attention_mask=causal, position_embeddings=rope)
        qlayer.freeze_qparams()
        self.assertIs(qlayer._mode, Mode.QUANT)

    def test_dtype_override(self):
        """Per-observer dtype override in PTQConfig reaches the observer."""
        cfg = PTQConfig(
            default_dtype=DType.int(16),
            overrides={"mlp_residual_out": {"dtype": DType.uint(8)}},
        )
        qlayer = QuantQwen3VLTextDecoderLayer(self.fp_layer, qcfg=cfg)
        self.assertEqual(qlayer.obs_mlp_residual_out.dtype, DType.uint(8))
125 changes: 125 additions & 0 deletions test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.utils.version import has_transformers_for
from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_model import (
QuantQwen3VLTextModel,
)

skip_msg = "required transformers not installed — skipping Qwen3VLTextModel tests"


@unittest.skipUnless(has_transformers_for("qwen3-vl"), skip_msg)
class TestQuantQwen3VLTextModel(unittest.TestCase):
    """Behavioral tests for the QuantQwen3VLTextModel wrapper.

    Built on a tiny 2-layer Qwen3-VL text config so the whole suite stays
    fast while covering calibration, quantized inference, and masking.
    """

    fp_model: torch.nn.Module
    vocab_size: int
    seq_len: int

    @classmethod
    def setUpClass(cls):
        torch.manual_seed(0)

        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
            Qwen3VLTextConfig,
        )
        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel

        cls.seq_len = 16
        cls.vocab_size = 512

        # head_dim=8 → head_dim//2=4; mrope_section must sum to head_dim//2
        cfg = Qwen3VLTextConfig(
            hidden_size=16,
            intermediate_size=32,
            num_attention_heads=2,
            num_key_value_heads=1,
            head_dim=8,
            num_hidden_layers=2,
            attention_bias=False,
            attention_dropout=0.0,
            max_position_embeddings=cls.seq_len,
            vocab_size=cls.vocab_size,
            use_cache=False,
            return_dict=False,
            rope_scaling={"rope_type": "default", "mrope_section": [1, 1, 2]},
        )
        # Pin eager attention so reference outputs are deterministic across
        # transformers versions.
        if hasattr(cfg, "_attn_implementation"):
            cfg._attn_implementation = "eager"
        else:
            setattr(cfg, "_attn_implementation", "eager")

        cls.fp_model = Qwen3VLTextModel(cfg)

    def _token_batch(self) -> torch.Tensor:
        """Return a random (1, seq_len) batch of token ids."""
        return torch.randint(0, self.vocab_size, (1, self.seq_len))

    def test_mode_transitions(self):
        """Wrapper walks NO_QUANT → CALIB → QUANT."""
        qmodel = QuantQwen3VLTextModel(self.fp_model)
        self.assertIs(qmodel._mode, Mode.NO_QUANT)

        qmodel.enable_calibration()
        self.assertIs(qmodel._mode, Mode.CALIB)

        _ = qmodel(self._token_batch())

        qmodel.freeze_qparams()
        self.assertIs(qmodel._mode, Mode.QUANT)

    def test_forward_diff(self):
        """Quantized forward stays close to (but not identical to) FP32."""
        qmodel = QuantQwen3VLTextModel(self.fp_model)
        qmodel.enable_calibration()

        calib_set = []
        for _ in range(4):
            sample = self._token_batch()
            _ = qmodel(sample)
            calib_set.append(sample)
        qmodel.freeze_qparams()

        with torch.no_grad():
            q_out = qmodel(calib_set[0])[0]
            fp_out = self.fp_model(calib_set[0])[0]

        diff = (fp_out - q_out).abs().mean().item()
        self.assertGreater(diff, 0.0)
        self.assertLess(diff, 0.4)
        self.assertEqual(fp_out.shape, q_out.shape)

    def test_with_padding_mask(self):
        """Verify that a 2D padding attention_mask is accepted and changes the output.

        When padding tokens are masked out, the forward result differs from the
        fully-attended (no mask) output because padding positions no longer
        contribute to attention weights.
        """
        qmodel = QuantQwen3VLTextModel(self.fp_model)
        qmodel.enable_calibration()

        inp = self._token_batch()

        # Simulate padding: last quarter of positions are padding tokens
        pad_start = self.seq_len * 3 // 4
        attn_mask = torch.ones(1, self.seq_len, dtype=torch.long)
        attn_mask[:, pad_start:] = 0  # mark as padding

        with torch.no_grad():
            out_with_mask = qmodel(inp, attention_mask=attn_mask)[0]
            out_no_mask = qmodel(inp)[0]

        # Outputs should differ because masked attention changes activations
        diff = (out_with_mask - out_no_mask).abs().max().item()
        self.assertGreater(diff, 0.0)
136 changes: 136 additions & 0 deletions tico/quantization/wrapq/examples/qwen/quantize_text_decoder_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib

import torch
from transformers import AutoModelForImageTextToText, AutoTokenizer

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.evaluation.metric import compute_peir
from tico.quantization.evaluation.utils import plot_two_outputs
from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_decoder_layer import (
QuantQwen3VLTextDecoderLayer,
)
from tico.utils.utils import SuppressWarning

# -------------------------------------------------------------------------
# 0. Load a Qwen3-VL model (text tower) + tokenizer
# -------------------------------------------------------------------------
name = "Qwen/Qwen3-VL-2B-Instruct"
# NOTE: loaded on CPU in float32 so the FP32-vs-quantized diff below is not
# polluted by half-precision rounding.
model = AutoModelForImageTextToText.from_pretrained(
    name,
    device_map="cpu",
    trust_remote_code=True,
    dtype=torch.float32,
)
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)

# Some checkpoints ship without a pad token; reuse EOS so that
# padding="max_length" tokenization below works.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

# Calibration/inference sequence length; also shrinks the position table.
MAX_SEQ = 128
text_cfg = model.config.text_config
text_cfg.max_position_embeddings = MAX_SEQ

# -------------------------------------------------------------------------
# 1. Wrap layer-0's decoder layer with QuantQwen3VLTextDecoderLayer
#
# QuantQwen3VLTextDecoderLayer pre-computes static causal mask and RoPE
# templates internally, so calibration only requires hidden_states input.
# -------------------------------------------------------------------------
# Keep a handle to the original FP32 layer for the reference forward later.
orig_layer = model.model.language_model.layers[0]
model.model.language_model.layers[0] = prepare(orig_layer, PTQConfig())
model.eval()

layer_q = model.model.language_model.layers[0]
# prepare() is expected to return a wrapper whose .wrapped is the quant layer.
assert isinstance(layer_q.wrapped, QuantQwen3VLTextDecoderLayer)

# -------------------------------------------------------------------------
# Helpers: tokenize → embed to get hidden states for calibration
# -------------------------------------------------------------------------
def make_hidden(prompt: str) -> torch.Tensor:
    """Tokenize *prompt* (padded/truncated to MAX_SEQ) and embed it.

    Returns the (1, MAX_SEQ, hidden_size) hidden-state tensor that feeds
    directly into the decoder layer under calibration.
    """
    encoded = tokenizer(
        prompt,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ,
        return_tensors="pt",
    )
    embed = model.model.language_model.embed_tokens
    with torch.no_grad():
        return embed(encoded["input_ids"])


# -------------------------------------------------------------------------
# 2. Calibration
# -------------------------------------------------------------------------
# Deliberately diverse prompts (multiple scripts, code, numerals/symbols)
# so the activation observers see a realistic spread of values.
PROMPTS = [
    "The quick brown fox jumps over the lazy dog.",
    "In 2025, AI systems accelerated hardware-software co-design at scale.",
    "양자화는 왜 어려울까? 분포, 길이, 마스크가 관건이다.",
    "今日はいい天気ですね。ところでRoPE角度は長さに依存します。",
    "def quicksort(arr):\n if len(arr) <= 1: return arr\n ...",
    "Prices rose 3.14% — see Figure 2; emails: foo@bar.com!",
]

with torch.no_grad():
    for prompt in PROMPTS:
        hidden = make_hidden(prompt)
        # position_embeddings and attention_mask are built internally
        _ = layer_q(hidden)

# Freeze the observed ranges into quantization parameters.
convert(layer_q)
assert layer_q._mode is Mode.QUANT, "Quantization mode should be active now."

# -------------------------------------------------------------------------
# 3. Quick diff check (INT-sim vs FP32)
# -------------------------------------------------------------------------
hidden = make_hidden("check")

# Additive causal mask for the FP32 reference layer: -120 on strictly
# upper-triangular (future) positions, 0 elsewhere.
mask = torch.full((1, 1, MAX_SEQ, MAX_SEQ), float("-120"))
mask.triu_(1)

# The quant wrapper builds RoPE itself; the FP32 reference needs it passed in.
rotary = model.model.language_model.rotary_emb
position_ids = torch.arange(MAX_SEQ).unsqueeze(0)

with torch.no_grad():
    q_out = layer_q(hidden)
    q_out = q_out[0] if isinstance(q_out, tuple) else q_out

    pos = rotary(hidden, position_ids)
    fp_out = orig_layer(hidden, attention_mask=mask, position_embeddings=pos)
    fp_out = fp_out[0] if isinstance(fp_out, tuple) else fp_out

print("┌───────────── Quantization Error Summary ─────────────")
print(f"│ Mean |diff|: {(q_out - fp_out).abs().mean().item():.6f}")
print(f"│ PEIR : {compute_peir(fp_out, q_out) * 100:.6f} %")
print("└──────────────────────────────────────────────────────")
print(plot_two_outputs(fp_out, q_out))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.071578
│ PEIR       : 9.253764 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 5.1┤                                         •  │
 3.4┤                              • ••••    •   │
 1.7┤                        ••••••••••          │
 0.0┤                 ••••••••••                 │
-1.7┤            • ••••••                        │
-3.4┤   ••••••••                                 │
-5.1┤  •                                         │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -5.1       -2.5        0.0       2.5       5.1 


# -------------------------------------------------------------------------
# 4. Export the quantized decoder layer to Circle
# -------------------------------------------------------------------------
import tico

save_path = pathlib.Path("qwen3vl_text_decoder_layer.q.circle")
# Example input matching the calibration shape: (batch, seq, hidden).
B, S, D = 1, MAX_SEQ, text_cfg.hidden_size
example_hidden = torch.randn(B, S, D)

# Tracing/conversion may emit benign UserWarnings; silence them all here.
with SuppressWarning(UserWarning, ".*"):
    cm = tico.convert(layer_q, (example_hidden,))
    cm.save(save_path)

print(f"Quantized Circle model saved to {save_path.resolve()}")
Loading
Loading