diff --git a/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_prefill.py b/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_prefill.py index b306b830..448d3c38 100644 --- a/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_prefill.py +++ b/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_prefill.py @@ -43,6 +43,7 @@ MODEL_NAME = "Maykeye/TinyLLama-v0" MAX_SEQ = 256 +static_data_inside_the_layer = True model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype="float32") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, legacy=False) @@ -139,10 +140,26 @@ def make_fixed_inputs(prompt: str): example_hidden = torch.randn(B, S, D) with SuppressWarning(UserWarning, ".*"): - cm = tico.convert( - qlayer, - (example_hidden,), - ) + if static_data_inside_the_layer is True: + cm = tico.convert( + qlayer, + (example_hidden,), + ) + else: + qattn = qlayer.wrapped + attn_mask = qattn._slice_causal(S, "cpu").squeeze(0) + dtype = example_hidden.dtype + pos_embeds = ( + qattn.rope_cos_template.cpu().to(dtype), + qattn.rope_sin_template.cpu().to(dtype), + ) + + cm = tico.convert( + qlayer, + (example_hidden,), + kwargs={"attention_mask": attn_mask, "position_embeddings": pos_embeds}, + ) + # Note that the model is not fully quantized. cm.save(save_path) diff --git a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py index c8b62eeb..5677d125 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py @@ -193,9 +193,9 @@ def forward( L = hidden_states.size(1) attention_mask = self._slice_causal(L, hidden_states.device) attention_mask = attention_mask.squeeze(0) - attention_mask = self._fq( - attention_mask, self.obs_causal_mask - ) # let it be quantized immediately + attention_mask = self._fq( + attention_mask, self.obs_causal_mask + ) # let it be quantized immediately if position_embeddings is None: position_embeddings = ( @@ -206,11 +206,11 @@ def forward( dtype=hidden_states.dtype, device=hidden_states.device ), ) - cos, sin = position_embeddings - position_embeddings = ( - self._fq(cos, self.obs_cos), - self._fq(sin, self.obs_sin), - ) + cos, sin = position_embeddings + position_embeddings = ( + self._fq(cos, self.obs_cos), + self._fq(sin, self.obs_sin), + ) attn_out = self.self_attn( hidden_states=hidden_states,