From e9f1ac28b20c2d60098c68572d7204bf499ea189 Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Tue, 10 Mar 2026 10:53:26 +0300 Subject: [PATCH] [quantization] Pass static data as inputs This PR adds `static_data_inside_the_layer` option to the script to have `attention_mask` and `position_embeddings` as inputs. TICO-DCO-1.0-Signed-off-by: s.malakhov --- .../llama/quantize_decoder_layer_prefill.py | 25 ++++++++++++++++--- .../llama/quant_decoder_layer_prefill.py | 16 ++++++------ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_prefill.py b/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_prefill.py index b306b830..448d3c38 100644 --- a/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_prefill.py +++ b/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_prefill.py @@ -43,6 +43,7 @@ MODEL_NAME = "Maykeye/TinyLLama-v0" MAX_SEQ = 256 +static_data_inside_the_layer = True model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype="float32") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, legacy=False) @@ -139,10 +140,26 @@ def make_fixed_inputs(prompt: str): example_hidden = torch.randn(B, S, D) with SuppressWarning(UserWarning, ".*"): - cm = tico.convert( - qlayer, - (example_hidden,), - ) + if static_data_inside_the_layer is True: + cm = tico.convert( + qlayer, + (example_hidden,), + ) + else: + qattn = qlayer.wrapped + attn_mask = qattn._slice_causal(S, "cpu").squeeze(0) + dtype = example_hidden.dtype + pos_embeds = ( + qattn.rope_cos_template.cpu().to(dtype), + qattn.rope_sin_template.cpu().to(dtype), + ) + + cm = tico.convert( + qlayer, + (example_hidden,), + kwargs={"attention_mask": attn_mask, "position_embeddings": pos_embeds}, + ) + # Note that the model is not fully quantized. 
cm.save(save_path) diff --git a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py index c8b62eeb..5677d125 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py @@ -193,9 +193,9 @@ def forward( L = hidden_states.size(1) attention_mask = self._slice_causal(L, hidden_states.device) attention_mask = attention_mask.squeeze(0) - attention_mask = self._fq( - attention_mask, self.obs_causal_mask - ) # let it be quantized immediately + attention_mask = self._fq( + attention_mask, self.obs_causal_mask + ) # let it be quantized immediately if position_embeddings is None: position_embeddings = ( @@ -206,11 +206,11 @@ def forward( dtype=hidden_states.dtype, device=hidden_states.device ), ) - cos, sin = position_embeddings - position_embeddings = ( - self._fq(cos, self.obs_cos), - self._fq(sin, self.obs_sin), - ) + cos, sin = position_embeddings + position_embeddings = ( + self._fq(cos, self.obs_cos), + self._fq(sin, self.obs_sin), + ) attn_out = self.self_attn( hidden_states=hidden_states,