[quantization] Pass static data as inputs #537

Merged: mhs4670go merged 1 commit into Samsung:main from stamalakhov:three_inputs_layer_PR on Mar 10, 2026
Conversation

@stamalakhov (Contributor) commented Mar 5, 2026:

This PR adds a static_data_inside_the_layer option to the script so that attention_mask and position_embeddings can be passed as inputs.

The quantized model uses globally defined attention_mask and position_embeddings and passes them down to the layers, so each layer does not have to populate its own. This PR therefore keeps decoder_layer in sync with the whole quantized model.
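The idea can be sketched roughly as follows (a minimal illustration with hypothetical names, not the actual TICO wrapper code): when the option is off, the layer expects the mask as an explicit input supplied once by the full model; when it is on, the layer materializes its own static copy.

```python
import torch

def make_causal_mask(L: int) -> torch.Tensor:
    # Additive causal mask: 0 on/below the diagonal, -inf above it.
    return torch.triu(torch.full((L, L), float("-inf")), diagonal=1)

class ToyDecoderLayer:
    """Hypothetical stand-in for a quantized decoder layer."""

    def __init__(self, static_data_inside_the_layer: bool = False):
        self.static_data_inside_the_layer = static_data_inside_the_layer

    def forward(self, hidden_states, attention_mask=None):
        if attention_mask is None:
            if not self.static_data_inside_the_layer:
                raise ValueError("attention_mask must be passed as an input")
            # Per-layer static copy: duplicated in every layer of the model.
            attention_mask = make_causal_mask(hidden_states.size(1))
        return hidden_states, attention_mask
```

With the flag off, one globally built mask can be reused by every layer instead of being rebuilt (and re-serialized) per layer.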

With this change, the parameters of the saved model look like this for Maykeye/TinyLLama-v0:
[screenshot: saved model parameters]

TICO-DCO-1.0-Signed-off-by: s.malakhov s.malakhov@partner.samsung.com

@stamalakhov self-assigned this Mar 5, 2026
Comment on lines +167 to +168
self.quantize_attention_mask = False
self.quantize_position_embeddings = False
stamalakhov (Contributor, Author):
Note for reviewers: I'm not sure, but maybe these parameters should be set via the constructor or a config? A config is an unusual place to set such parameters, while modifying the constructor would require more changes.
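For context, the two alternatives being weighed could look like this (hypothetical sketch, not the actual wrapper class): exposing the flags as constructor arguments touches every call site, while plain attributes can be toggled after construction.

```python
class QuantLayerSketch:
    """Hypothetical sketch of a quant wrapper carrying the two flags."""

    def __init__(
        self,
        quantize_attention_mask: bool = False,
        quantize_position_embeddings: bool = False,
    ):
        # Constructor-style: explicit, but every call site must change.
        self.quantize_attention_mask = quantize_attention_mask
        self.quantize_position_embeddings = quantize_position_embeddings

layer = QuantLayerSketch()
# Attribute-style toggle, as the PR currently does:
layer.quantize_attention_mask = True
```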

@stamalakhov force-pushed the three_inputs_layer_PR branch from e54c541 to 71bc124 on March 6, 2026 13:32
@stamalakhov marked this pull request as ready for review March 6, 2026 13:32
@stamalakhov requested a review from mhs4670go March 6, 2026 13:32
attention_mask = self._fq(
attention_mask, self.obs_causal_mask
) # let it be quantized immediately
elif self.quantize_attention_mask is True:
mhs4670go (Contributor):

How about just calling _fq outside the if conditions?

if attention_mask is None or attention_mask.dtype == torch.bool:
    L = hidden_states.size(1)
    attention_mask = self._slice_causal(L, hidden_states.device)
    attention_mask = attention_mask.squeeze(0)
attention_mask = self._fq(
    attention_mask, self.obs_causal_mask
)  # let it be quantized immediately

if position_embeddings is None:
    position_embeddings = (
        self.rope_cos_template.to(
            dtype=hidden_states.dtype, device=hidden_states.device
        ),
        self.rope_sin_template.to(
            dtype=hidden_states.dtype, device=hidden_states.device
        ),
    )
cos, sin = position_embeddings
position_embeddings = (
    self._fq(cos, self.obs_cos),
    self._fq(sin, self.obs_sin),
)

stamalakhov (Contributor, Author):
@mhs4670go
It was like that in the previous version. It populates every layer with its own duplicate of the masks, and the circle file becomes 3x larger (for SmolLM at least). That's why it was made an option, which is not needed when quantizing the full model.
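A back-of-the-envelope sketch of the size concern (hypothetical numbers, not measured from the actual models): if every layer serializes its own copy of the static mask template, the stored bytes grow linearly with the layer count, while a single shared copy stays constant.

```python
def mask_bytes(L: int = 128, elem_size: int = 4) -> int:
    # Size of one float32 causal-mask template of shape (L, L).
    return L * L * elem_size

def per_layer_total(n_layers: int, L: int = 128) -> int:
    # Every layer stores its own duplicate of the template.
    return n_layers * mask_bytes(L)

def shared_total(n_layers: int, L: int = 128) -> int:
    # One global template is shared by all layers.
    return mask_bytes(L)
```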

mhs4670go (Contributor):
diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py b/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py
index cfc41c5..46de3c9 100644
--- a/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py
+++ b/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py
@@ -205,7 +205,7 @@ class QuantLlamaAttentionPrefill(QuantModuleBase):
                 hidden_states.device
             )
             attention_mask = attention_mask.squeeze(0)
-            attention_mask = self._fq(attention_mask, self.obs_causal_mask)
+        attention_mask = self._fq(attention_mask, self.obs_causal_mask)
 
         attn_weights_parts = []
         attn_out_parts = []
diff --git a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py
index c8b62ee..5677d12 100644
--- a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py
+++ b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py
@@ -193,9 +193,9 @@ class QuantLlamaDecoderLayerPrefill(QuantModuleBase):
             L = hidden_states.size(1)
             attention_mask = self._slice_causal(L, hidden_states.device)
             attention_mask = attention_mask.squeeze(0)
-            attention_mask = self._fq(
-                attention_mask, self.obs_causal_mask
-            )  # let it be quantized immediately
+        attention_mask = self._fq(
+            attention_mask, self.obs_causal_mask
+        )  # let it be quantized immediately
 
         if position_embeddings is None:
             position_embeddings = (
@@ -206,11 +206,11 @@ class QuantLlamaDecoderLayerPrefill(QuantModuleBase):
                     dtype=hidden_states.dtype, device=hidden_states.device
                 ),
             )
-            cos, sin = position_embeddings
-            position_embeddings = (
-                self._fq(cos, self.obs_cos),
-                self._fq(sin, self.obs_sin),
-            )
+        cos, sin = position_embeddings
+        position_embeddings = (
+            self._fq(cos, self.obs_cos),
+            self._fq(sin, self.obs_sin),
+        )
 
         attn_out = self.self_attn(
             hidden_states=hidden_states,

stamalakhov (Contributor, Author):
Yep, you are right: there is no duplication now. Let's use this suggestion.

@stamalakhov force-pushed the three_inputs_layer_PR branch from 71bc124 to c433a89 on March 10, 2026 05:06
@stamalakhov requested a review from mhs4670go March 10, 2026 05:15
This PR adds `static_data_inside_the_layer` option to the script
to have `attention_mask` and `position_embeddings` as inputs.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
@stamalakhov force-pushed the three_inputs_layer_PR branch from c433a89 to e9f1ac2 on March 10, 2026 07:56
@mhs4670go (Contributor) left a comment:

LGTM

@mhs4670go merged commit 94ff52a into Samsung:main on Mar 10, 2026
7 checks passed
@stamalakhov deleted the three_inputs_layer_PR branch March 10, 2026 08:20