
[quantization] Quantization of Llama#492

Merged
mhs4670go merged 1 commit into Samsung:main from stamalakhov:quant_full_model_PR
Feb 20, 2026
Conversation

@stamalakhov
Contributor

@stamalakhov stamalakhov commented Feb 13, 2026

This PR quantizes the full Llama model and converts it to circle format.

Log of `python tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py --model Maykeye/TinyLLama-v0 --save_circles_to_folder "." --max_seq_len 2048`

Namespace(model='Maykeye/TinyLLama-v0', device='cuda', dtype='float32', seed=42, trust_remote_code=False, hf_token=None, no_tqdm=False, no_GPTQ=False, no_PTQ=False, save_circle_to_folder='.', cache_dir='/mnt/storage/transformers_cache', nsamples_for_qcalibration=128, linear_weight_bits=4, gptq_mse=False, max_seq_len=2048, embedding_weight_bits=8, lm_head_weight_bits=4, eval_tasks=None)
=== Config ===
Model            : Maykeye/TinyLLama-v0
Device           : cuda
DType            : float32

Loading FP model …
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.

Calculating original perplexities …
Token indices sequence length is longer than the specified maximum sequence length for this model (324381 > 2048). Running this sequence through the model will result in indexing errors
PPL:  99%|███████████████▌ | 158/159 [00:04<00:00, 37.60it/s]

┌── Wikitext-2 test perplexity ─────────────
│ FP32 :  7584.31
└───────────────────────────────────────────
Applying GPTQ …
Quantizing layers: 100%|███████████████| 8/8 [00:08<00:00,  1.07s/layer]
Wrapping layers with PTQWrapper …
Calibrating PTQ observers …
100%|███████████████| 128/128 [00:31<00:00,  4.12it/s]

Calculating perplexities …
PPL:  99%|███████████████▌ | 158/159 [00:35<00:00,  4.47it/s]

┌── Wikitext-2 test perplexity ─────────────
│ int16 :  7410.80
└───────────────────────────────────────────
saving input embedding to /mnt/storage/slow_repos/TICO/embedding.q.circle
saving lm_head to /mnt/storage/slow_repos/TICO/lm_head.q.circle
saving layers
saving model layer_0 to /mnt/storage/slow_repos/TICO/decoder_layer_0.q.circle
saving model layer_1 to /mnt/storage/slow_repos/TICO/decoder_layer_1.q.circle
saving model layer_2 to /mnt/storage/slow_repos/TICO/decoder_layer_2.q.circle
saving model layer_3 to /mnt/storage/slow_repos/TICO/decoder_layer_3.q.circle
saving model layer_4 to /mnt/storage/slow_repos/TICO/decoder_layer_4.q.circle
saving model layer_5 to /mnt/storage/slow_repos/TICO/decoder_layer_5.q.circle
saving model layer_6 to /mnt/storage/slow_repos/TICO/decoder_layer_6.q.circle
saving model layer_7 to /mnt/storage/slow_repos/TICO/decoder_layer_7.q.circle
saving model.model to /mnt/storage/slow_repos/TICO/model.model.q.circle
saving the whole model to /mnt/storage/slow_repos/TICO/model.q.circle

Draft: #436
TICO-DCO-1.0-Signed-off-by: s.malakhov s.malakhov@partner.samsung.com

@stamalakhov stamalakhov self-assigned this Feb 13, 2026
@stamalakhov stamalakhov marked this pull request as draft February 13, 2026 10:38
@stamalakhov
Contributor Author

@mhs4670go
Should I provide tests for it and/or split it into smaller PRs?

@stamalakhov stamalakhov marked this pull request as ready for review February 13, 2026 11:43
@stamalakhov stamalakhov marked this pull request as draft February 13, 2026 13:54
@stamalakhov stamalakhov force-pushed the quant_full_model_PR branch 2 times, most recently from 2dc8751 to 5f315b3 Compare February 18, 2026 11:23
@mhs4670go
Contributor

Should I provide tests for it and/or split it into smaller PRs?

No, you don't have to. The example script can be a test by itself.

@stamalakhov stamalakhov changed the title [quantization][draft] Quantization of Llama [quantization] Quantization of Llama Feb 19, 2026
@stamalakhov stamalakhov marked this pull request as ready for review February 19, 2026 13:43
@stamalakhov
Contributor Author

@mhs4670go
Not sure about the CI failure in check-style.

@mhs4670go
Contributor

mhs4670go commented Feb 20, 2026

Not sure about the CI failure in check-style.

This is a problem with ufmt: it uses a feature that is deprecated in Python 3.14. Let's try ufmt==2.9.1. I'll post a PR for this.

@mhs4670go mhs4670go mentioned this pull request Feb 20, 2026
@mhs4670go
Contributor

@stamalakhov Could you rebase the PR?

@stamalakhov
Contributor Author

@stamalakhov Could you rebase the PR?

@mhs4670go
Rebased, but still no luck 😢

@mhs4670go
Contributor

mhs4670go commented Feb 20, 2026

Then, please set the Python version to 3.12.

check-style:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.x'  # HERE

      - name: "Run configure"
        run: |
          ./ccex configure format

      - name: "Run linters"
        run: |
          ./ccex format --no-apply-patches
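
For reference, the suggested pin would be this one-value change in the same step (a sketch against the workflow shown above, not the merged fix):

```yaml
      - uses: actions/setup-python@v5
        with:
          python-version: '3.12'  # pinned instead of '3.x', which resolves to 3.14
```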

@stamalakhov
Contributor Author

Then, please set the Python version to 3.12.

@mhs4670go
Thank you very much. The problem finally showed up. I think I should revert the Python version change afterwards.

This PR quantizes the full `Llama` model and converts it to circle format.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
Comment on lines +188 to +190
# To avoid introducing attention_mask as a graph input, use the preset attention mask
L = hidden_states.size(1)
attention_mask = self._slice_causal(L, hidden_states.device)
Contributor


Could you elaborate on this change? Is it necessary?

Contributor Author


Right now, without an explicit attention mask, LlamaModel sets its own default causal mask, and it has wrong dimensions.
I'll try to set it explicitly during conversion.

Contributor Author


@mhs4670go
This change is not necessary. I'll remove it. Thank you!

Contributor


Without an explicit attention mask, LlamaModel sets its own default causal mask, and it has wrong dimensions.

Ah, I understood.

Contributor Author


@mhs4670go
If attention_mask is set explicitly, like this:

attention_mask = torch.ones(1, q_m.config.max_position_embeddings, dtype=torch.bool)
cm = tico.convert(q_m, (calib_inputs[0], attention_mask), strict=False)
cm.save(save_path)

`_prepare_4d_causal_attention_mask_with_cache_position` in transformers will turn it into floats, so we would need to quantize it inside QuantLLamaDecoderLayer.

Contributor Author


@mhs4670go
Sorry, right now we do need this change 😢 .
attention_mask, if set from outside, goes through some non-trivial transforms inside transformers, each of them unquantized.
Setting attention_mask to None during conversion behaves the same way.
So resetting it explicitly in quant_decoder_layer.py breaks any dependency on the unquantized attention_mask and produces a correct model without any floats.
Please correct me if I'm wrong.
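
For context, the `_slice_causal(L, device)` call from the diff above could be sketched roughly like this. Only the call signature comes from the PR; the class name `CausalMaskSlicer`, the `max_seq_len` default, and the additive `-inf` mask form are assumptions for illustration, not the actual implementation:

```python
import torch


class CausalMaskSlicer:
    """Hypothetical sketch: prebuild a causal mask once, then slice it per
    forward pass, so the exported graph never takes attention_mask as an
    input (and transformers never rebuilds it in float)."""

    def __init__(self, max_seq_len: int = 2048):
        # Additive causal mask: 0 on/below the diagonal, -inf strictly above,
        # precomputed once at the maximum sequence length.
        full = torch.full((max_seq_len, max_seq_len), float("-inf"))
        self.causal = torch.triu(full, diagonal=1)

    def _slice_causal(self, L: int, device: torch.device) -> torch.Tensor:
        # Top-left L x L corner, expanded to (1, 1, L, L) as the decoder
        # layer's attention expects (batch and head broadcast dims).
        return self.causal[:L, :L].to(device)[None, None, :, :]
```

Because the mask is a constant of the wrapped layer rather than a model input, it gets baked into the converted graph, which matches the stated goal of avoiding the unquantized mask transforms inside transformers.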

Contributor

@mhs4670go mhs4670go left a comment


LGTM

@mhs4670go mhs4670go merged commit 4ad84c7 into Samsung:main Feb 20, 2026
7 checks passed
@stamalakhov stamalakhov deleted the quant_full_model_PR branch February 20, 2026 08:22