
[quantization] Use QuantLlamaForCausalLM#531

Merged
mhs4670go merged 1 commit into Samsung:main from stamalakhov:use_quant_autocasual_lm_PR on Feb 27, 2026
Conversation

@stamalakhov
Contributor

@stamalakhov stamalakhov commented Feb 27, 2026

This PR:

  1. uses `QuantLlamaForCausalLM` for Llama quantization
  2. removes additional transforms of attention_mask/position_ids to prevent them from exploding in size
  3. removes unused code
  4. adjusts tests accordingly

In short, this PR computes the static data (causal_mask/position_embeddings) once at the QuantLlamaModel level and shares it across all layers, preventing it from exploding in size. Previously, every layer computed its own static data, which induced a large overhead for non-small max_seq_len (~2048).
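The compute-once, share-everywhere idea above can be sketched as follows. This is a minimal illustration with hypothetical names (`Model`, `Layer`, `build_static_inputs` are not the actual TICO classes), not the PR's implementation:

```python
# Sketch: compute the static causal mask / position ids once at the model
# level and pass them into every layer, instead of letting each layer
# materialize its own copy (the source of the per-layer overhead for
# large max_seq_len).

def build_static_inputs(max_seq_len: int):
    # One (L, L) causal mask, built a single time.
    causal_mask = [[0.0 if j <= i else float("-inf")
                    for j in range(max_seq_len)]
                   for i in range(max_seq_len)]
    position_ids = list(range(max_seq_len))
    return causal_mask, position_ids

class Layer:
    def forward(self, hidden, causal_mask, position_ids):
        # Consumes the shared constants; stores no per-layer copy.
        return hidden

class Model:
    def __init__(self, n_layers: int, max_seq_len: int):
        # Before: each of the n_layers layers held its own mask -> n copies.
        # After: a single copy owned by the model.
        self.static = build_static_inputs(max_seq_len)
        self.layers = [Layer() for _ in range(n_layers)]

    def forward(self, hidden):
        mask, pos = self.static
        for layer in self.layers:
            hidden = layer.forward(hidden, mask, pos)
        return hidden
```

With 30 layers and max_seq_len=2048, holding one mask instead of 30 is where the size reduction reported below comes from.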

```
$ python tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py \
    --model HuggingFaceTB/SmolLM2-135M-Instruct --gptq_mse \
    --max_seq_len 2048 --save_circle_to_folder "."
Namespace(model='HuggingFaceTB/SmolLM2-135M-Instruct', device='cuda', dtype='float32',
          seed=42, trust_remote_code=False, hf_token=None, no_tqdm=False, no_GPTQ=False,
          no_PTQ=False, save_circle_to_folder='.', cache_dir='/mnt/storage/transformers_cache',
          nsamples_for_qcalibration=128, linear_weight_bits=4, gptq_mse=True, max_seq_len=2048,
          embedding_weight_bits=8, lm_head_weight_bits=4, eval_tasks=None)
=== Config ===
Model  : HuggingFaceTB/SmolLM2-135M-Instruct
Device : cuda
DType  : float32

Loading FP model …

Calculating original perplexities …
Token indices sequence length is longer than the specified maximum sequence length for this model (304978 > 8192). Running this sequence through the model will result in indexing errors
PPL: 99%|█████████▍| 148/149 [00:18<00:00, 7.86it/s]

┌── Wikitext-2 test perplexity ─────────────
│ FP32 : 17.40
└───────────────────────────────────────────
Applying GPTQ …
Quantizing layers: 100%|██████████| 30/30 [01:34<00:00, 3.15s/layer]
Wrapping layers with PTQWrapper …
Calibrating PTQ observers …
100%|██████████| 128/128 [01:26<00:00, 1.48it/s]

Calculating perplexities …
PPL: 99%|█████████▍| 148/149 [01:15<00:00, 1.95it/s]

┌── Wikitext-2 test perplexity ─────────────
│ int16 : 27.73
└───────────────────────────────────────────
saving the whole model to /mnt/storage/slow_repos/TICO/model.q.circle
```

This produced a model.q.circle file of 111.2 MB, instead of the >300 MB produced by the previous version.

Draft: #495
TICO-DCO-1.0-Signed-off-by: s.malakhov s.malakhov@partner.samsung.com

@stamalakhov stamalakhov self-assigned this Feb 27, 2026
@stamalakhov stamalakhov force-pushed the use_quant_autocasual_lm_PR branch from 1023a8b to 515381a on February 27, 2026 at 08:46
This PR:
1. uses `QuantLlamaForCausalLM` for Llama quantization
2. removes additional transforms of attention_mask/position_ids to prevent them from exploding in size
3. removes unused code
4. adjusts tests accordingly

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
@stamalakhov stamalakhov force-pushed the use_quant_autocasual_lm_PR branch from 515381a to f2964d8 on February 27, 2026 at 09:04
@stamalakhov
Contributor Author

@mhs4670go
Should I split it?

hidden_states.device
)
attention_mask = self._fq(attention_mask, self.obs_causal_mask)
attention_mask = attention_mask.squeeze(0)
Contributor

If we generate a squeezed attention_mask whose shape is (1, L, L), do you think this squeeze would become redundant? Of course, constant folding would achieve the same thing.
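The shape question above can be illustrated with a minimal example (the variable names are illustrative, not from the PR). A 4-D causal mask of shape (1, 1, L, L) needs the `squeeze(0)` to become (1, L, L); if the mask were generated as (1, L, L) from the start, that call could simply be dropped:

```python
import numpy as np

L = 4  # illustrative sequence length

# 4-D causal mask as typically produced by HF-style mask utilities:
# (batch, heads, L, L)
mask_4d = np.zeros((1, 1, L, L))
squeezed = mask_4d.squeeze(0)      # -> (1, L, L), what the code computes today

# Generating the mask as (1, L, L) from the start yields the same
# shape with no squeeze at all.
mask_3d = np.zeros((1, L, L))
```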

Contributor Author


Yes, I think it will be redundant.

@mhs4670go
Contributor

> Should I split it?

No, I think I can review them here.

Contributor

@mhs4670go mhs4670go left a comment


LGTM

@mhs4670go mhs4670go merged commit 6980493 into Samsung:main Feb 27, 2026
7 checks passed