[quantization] Pass static data as inputs #537
Conversation
```python
self.quantize_attention_mask = False
self.quantize_position_embeddings = False
```
Note for reviewers: I'm not sure, but maybe these parameters should be set via the constructor or the config?
The config is an unusual place to set such parameters, while modifying the constructor would require more changes.
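For illustration, a minimal sketch of the constructor route being discussed. The class and flag names are taken from the snippet above, but this is a toy stand-in, not the actual TICO wrapper; only the flag plumbing is shown:

```python
class QuantLlamaDecoderLayerPrefill:
    """Toy stand-in for the wrapper class (flag plumbing only)."""

    def __init__(
        self,
        quantize_attention_mask: bool = False,
        quantize_position_embeddings: bool = False,
    ):
        # Defaults match the hard-coded values in the snippet above,
        # so existing call sites keep their current behavior.
        self.quantize_attention_mask = quantize_attention_mask
        self.quantize_position_embeddings = quantize_position_embeddings


layer = QuantLlamaDecoderLayerPrefill(quantize_attention_mask=True)
print(layer.quantize_attention_mask)       # True
print(layer.quantize_position_embeddings)  # False
```

With defaulted keyword arguments, callers that don't care about the flags need no changes, which limits the churn the reply above is worried about.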
Force-pushed from e54c541 to 71bc124 (compare)
```python
attention_mask = self._fq(
    attention_mask, self.obs_causal_mask
)  # let it be quantized immediately
elif self.quantize_attention_mask is True:
```
How about just calling `_fq` outside the `if` conditions?
```python
if attention_mask is None or attention_mask.dtype == torch.bool:
    L = hidden_states.size(1)
    attention_mask = self._slice_causal(L, hidden_states.device)
    attention_mask = attention_mask.squeeze(0)
attention_mask = self._fq(
    attention_mask, self.obs_causal_mask
)  # let it be quantized immediately

if position_embeddings is None:
    position_embeddings = (
        self.rope_cos_template.to(
            dtype=hidden_states.dtype, device=hidden_states.device
        ),
        self.rope_sin_template.to(
            dtype=hidden_states.dtype, device=hidden_states.device
        ),
    )
cos, sin = position_embeddings
position_embeddings = (
    self._fq(cos, self.obs_cos),
    self._fq(sin, self.obs_sin),
)
```
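For context on what "let it be quantized immediately" means here, a toy sketch of uniform affine fake quantization, the operation a call like `_fq` typically simulates (the real wrapq `_fq`/observer API is not shown; `scale`, `zero_point`, and the int8 range are illustrative assumptions):

```python
import numpy as np


def fake_quantize(x: np.ndarray, scale: float, zero_point: int = 0,
                  qmin: int = -128, qmax: int = 127) -> np.ndarray:
    """Round to the integer grid, clamp, and map back to float.

    This keeps float values but injects int8 quantization error,
    which is what fake-quantizing the attention mask would do.
    """
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale


x = np.array([0.04, -0.26, 1.0])
print(fake_quantize(x, scale=0.1))  # values snapped to multiples of 0.1
```

Calling it unconditionally, as the suggestion does, means the observers see the mask and RoPE tables whether they were computed locally or passed in as inputs.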
@mhs4670go
It was like that in the previous version. It populates every layer with its own duplicate of the mask, and the circle file becomes about 3x larger (for SmoLM at least). That's why it was gated behind an option, which is not needed when quantizing the full model.
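The file-size concern can be illustrated with a back-of-the-envelope sketch: if each layer serializes its own copy of the causal mask, the constant data grows linearly with the layer count, whereas a single shared mask passed in as an input is stored once. The layer count and sequence length below are made up for illustration, not SmolLM's real configuration:

```python
import numpy as np

NUM_LAYERS, SEQ = 30, 256  # hypothetical sizes, for illustration only

# Per-layer copies: every layer serializes its own causal mask constant.
per_layer = [np.tril(np.ones((SEQ, SEQ), dtype=np.float32))
             for _ in range(NUM_LAYERS)]
per_layer_bytes = sum(m.nbytes for m in per_layer)

# Shared input: one mask, computed once and passed to every layer.
shared = np.tril(np.ones((SEQ, SEQ), dtype=np.float32))
shared_bytes = shared.nbytes

print(per_layer_bytes // shared_bytes)  # → 30
```

The same argument applies to the RoPE cos/sin tables, which is why the PR passes them in globally as well.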
```diff
diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py b/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py
index cfc41c5..46de3c9 100644
--- a/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py
+++ b/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py
@@ -205,7 +205,7 @@ class QuantLlamaAttentionPrefill(QuantModuleBase):
                 hidden_states.device
             )
             attention_mask = attention_mask.squeeze(0)
-            attention_mask = self._fq(attention_mask, self.obs_causal_mask)
+        attention_mask = self._fq(attention_mask, self.obs_causal_mask)

         attn_weights_parts = []
         attn_out_parts = []
diff --git a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py
index c8b62ee..5677d12 100644
--- a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py
+++ b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py
@@ -193,9 +193,9 @@ class QuantLlamaDecoderLayerPrefill(QuantModuleBase):
             L = hidden_states.size(1)
             attention_mask = self._slice_causal(L, hidden_states.device)
             attention_mask = attention_mask.squeeze(0)
-            attention_mask = self._fq(
-                attention_mask, self.obs_causal_mask
-            )  # let it be quantized immediately
+        attention_mask = self._fq(
+            attention_mask, self.obs_causal_mask
+        )  # let it be quantized immediately

         if position_embeddings is None:
             position_embeddings = (
@@ -206,11 +206,11 @@ class QuantLlamaDecoderLayerPrefill(QuantModuleBase):
                 dtype=hidden_states.dtype, device=hidden_states.device
             ),
         )
-            cos, sin = position_embeddings
-            position_embeddings = (
-                self._fq(cos, self.obs_cos),
-                self._fq(sin, self.obs_sin),
-            )
+        cos, sin = position_embeddings
+        position_embeddings = (
+            self._fq(cos, self.obs_cos),
+            self._fq(sin, self.obs_sin),
+        )

         attn_out = self.self_attn(
             hidden_states=hidden_states,
```
Yep. You are right. No duplication currently. Let's use this suggestion.
Force-pushed from 71bc124 to c433a89 (compare)
This PR adds a `static_data_inside_the_layer` option to the script to have `attention_mask` and `position_embeddings` as inputs. TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
Force-pushed from c433a89 to e9f1ac2 (compare)
This PR adds a `static_data_inside_the_layer` option to the script to have `attention_mask` and `position_embeddings` as inputs. The quantized model uses globally defined `attention_mask` and `position_embeddings` and passes them to the layers, to avoid populating each layer with its own copies. So this PR tries to keep `decoder_layer` in sync with the whole quantized model. The parameters of the saved model look like this for Maykeye/TinyLLama-v0:

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
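The option described above might be wired into the script roughly like this. The flag name comes from the PR; the argparse setup and help text are assumptions, not the actual script:

```python
import argparse

parser = argparse.ArgumentParser(description="Export a quantized model.")
parser.add_argument(
    "--static_data_inside_the_layer",
    action="store_true",
    help="Bake attention_mask/position_embeddings into each layer instead of "
         "taking them as model inputs (produces a larger circle file).",
)

# Default: mask and RoPE tables are model inputs, shared across layers.
args = parser.parse_args([])
print(args.static_data_inside_the_layer)  # False

# Opt in to the old per-layer behavior.
args = parser.parse_args(["--static_data_inside_the_layer"])
print(args.static_data_inside_the_layer)  # True
```

Keeping the shared-inputs behavior as the default matches the discussion above: per-layer duplicates are only useful when quantizing a single decoder layer in isolation.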