diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index 34281d43..e35edd2a 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -39,11 +39,6 @@ from lm_eval.utils import make_table from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.llama.modeling_llama import KwargsForCausalLM, LlamaForCausalLM -from transformers.processing_utils import Unpack - import tico from tico.quantization import convert, prepare @@ -107,60 +102,12 @@ def inject_gptq_qparams( def save_circles_to(q_m, calib_inputs, save_circle_to_folder): q_m.eval() q_m.cpu() - save_path = pathlib.Path(save_circle_to_folder, "embedding.q.circle") - pathlib.Path() - print(f"saving input embedding to {save_path.resolve()}") - with torch.no_grad(): - with SuppressWarning(UserWarning, ".*"): - cm = tico.convert( - q_m.model.embed_tokens, - (calib_inputs[0],), - strict=False, - ) - cm.save(save_path) - - save_path = pathlib.Path(save_circle_to_folder, "lm_head.q.circle") - print(f"saving lm_head to {save_path.resolve()}") - with torch.no_grad(): - with SuppressWarning(UserWarning, ".*"): - B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size - example_hidden = torch.randn(B, S, D) - cm = tico.convert( - q_m.lm_head, - (example_hidden,), - strict=False, - ) - cm.save(save_path) - - print("saving layers") - for i in range(len(q_m.model.layers)): - save_path = pathlib.Path(save_circle_to_folder, f"decoder_layer_{i}.q.circle") - print(f"saving model layer_{i} to {save_path.resolve()}") - B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size - example_hidden = torch.randn(B, S, D) - - with torch.no_grad(): - with SuppressWarning(UserWarning, ".*"): - cm = tico.convert( - q_m.model.layers[i], - (example_hidden,), - strict=False, - ) - cm.save(save_path) - - save_path = pathlib.Path(save_circle_to_folder, "model.model.q.circle") - print(f"saving model.model to {save_path.resolve()}") - with torch.no_grad(): - with SuppressWarning(UserWarning, ".*"): - cm = tico.convert(q_m.model, (calib_inputs[0],), strict=False) - - cm.save(save_path) save_path = pathlib.Path(save_circle_to_folder, "model.q.circle") print(f"saving the whole model to {save_path.resolve()}") with torch.no_grad(): with SuppressWarning(UserWarning, ".*"): - cm = tico.convert(q_m, (calib_inputs[0],), strict=False) + cm = tico.convert(q_m.wrapped, (calib_inputs[0],), strict=False) cm.save(save_path) @@ -222,13 +169,19 @@ def quantize_using_PTQ(q_m, calib_inputs, args): default_dtype=DType.int(16), default_qscheme=QScheme.PER_TENSOR_SYMM, overrides={ - "model.embeddings": { - "weight": { - "dtype": ( - DType.uint(args.embedding_weight_bits) - if args.embedding_weight_bits < 16 - else DType.int(args.embedding_weight_bits) - ), + "model": { + "embed_tokens": { + "weight": { + "dtype": ( + DType.uint(args.embedding_weight_bits) + if args.embedding_weight_bits < 16 + else DType.int(args.embedding_weight_bits) + ), + }, + }, + "layers": {}, + "norm": { + "weight": {"dtype": DType.int(16)}, }, }, "lm_head": { @@ -240,17 +193,14 @@ def quantize_using_PTQ(q_m, calib_inputs, args): ), }, }, - "model.norm": { - "weight": {"dtype": DType.int(16)}, - }, }, ) for i in range(len(q_m.model.layers)): - child_scope = f"layer{i}" - cfg.overrides[child_scope] = w_cfg # type: ignore[index] + child_scope = f"{i}" + cfg.overrides["model"]["layers"][child_scope] = w_cfg # type: ignore[index] qcfg = cfg - prepare(q_m, qcfg) + q_m = prepare(q_m, qcfg) # ------------------------------------------------------------------------- # Single-pass activation calibration @@ -260,6 +210,12 @@ def quantize_using_PTQ(q_m, calib_inputs, args): # Overwrite weight observers with GPTQ statistics if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict): inject_gptq_qparams(q_m, q_m.quantizers) + elif ( + hasattr(q_m, "wrapped") + and hasattr(q_m.wrapped, "quantizers") + and isinstance(q_m.wrapped.quantizers, dict) + ): + inject_gptq_qparams(q_m.wrapped, q_m.wrapped.quantizers) else: print( "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection." @@ -276,83 +232,6 @@ def quantize_using_PTQ(q_m, calib_inputs, args): return q_m -def fix_inputs(model, tokenizer, input_ids): - if tokenizer.pad_token_id is not None: - pads = torch.full( - ( - input_ids.shape[0], - model.config.max_position_embeddings - input_ids.shape[1], - ), - fill_value=tokenizer.pad_token_id, - device=input_ids.device, - ) - elif tokenizer.eos_token_id is not None: - pads = torch.full( - ( - input_ids.shape[0], - model.config.max_position_embeddings - input_ids.shape[1], - ), - fill_value=tokenizer.eos_token_id, - device=input_ids.device, - ) - else: - raise RuntimeError( - "failed to pad sequence - tokenizer doesn't have pad_token_id/eos_token_id" - ) - - return torch.cat((input_ids, pads), dim=1) - - -class LLamaWithFixedInput(LlamaForCausalLM): - def __init__(self, parent: LlamaForCausalLM, tokenizer): - assert parent.config is not None, "config is a must have" - super().__init__(parent.config) - self.__dict__.update(parent.__dict__) - - def forward( - self, - input_ids: torch.LongTensor = None, # type: ignore[assignment] - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: - # fixed input size, due to position_ids fixed - orig_len = input_ids.shape[-1] - input_ids = fix_inputs(self, self.tokenizer, input_ids) - if labels is not None: - labels = fix_inputs(self, self.tokenizer, labels) - res = super().forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - inputs_embeds, - labels, - use_cache, - output_attentions, - output_hidden_states, - return_dict, - cache_position, - logits_to_keep, - **kwargs, - ) - # we need to trim to the original size - res.logits = res.logits[..., :orig_len, :] - return res - - self.forward = types.MethodType(forward, self) - self.tokenizer = tokenizer - - def evaluate(q_m, tokenizer, dataset_test, args): # ------------------------------------------------------------------------- # Evaluate perplexity on Wikitext-2 @@ -360,7 +239,7 @@ def evaluate(q_m, tokenizer, dataset_test, args): print("\nCalculating perplexities …") enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt") ppl_uint8 = perplexity( - q_m, enc, args.device, stride=q_m.config.max_position_embeddings + q_m, enc, args.device, stride=q_m.wrapped.config.max_position_embeddings ) print("\n┌── Wikitext-2 test perplexity ─────────────") @@ -576,7 +455,7 @@ def main(): q_m = quantize_using_PTQ(q_m, calib_inputs, args) # after PTQ quantizer only fixed-length input sequences are valid - evaluate(LLamaWithFixedInput(q_m, tokenizer), tokenizer, dataset_test, args) + evaluate(q_m, tokenizer, dataset_test, args) if args.save_circle_to_folder is not None: save_circles_to(q_m, calib_inputs, args.save_circle_to_folder) diff --git a/tico/quantization/wrapq/examples/quantize_with_gptq.py b/tico/quantization/wrapq/examples/quantize_with_gptq.py index 56317d83..361267f4 100644 --- a/tico/quantization/wrapq/examples/quantize_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_with_gptq.py @@ -42,7 +42,6 @@ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase - # Token-budget presets for activation calibration TOKENS: dict[str, int] = { # Smoke test (<1 min turnaround on CPU/GPU) @@ -66,6 +65,7 @@ TRAIN_SPLIT = "train" TEST_SPLIT = "test" + # ------------------------------------------------------------------------- # 1. Helper — copy GPTQ (scale, zp) into PTQ observers # ------------------------------------------------------------------------- diff --git a/tico/quantization/wrapq/quantizer.py b/tico/quantization/wrapq/quantizer.py index 901514aa..bc0c84fc 100644 --- a/tico/quantization/wrapq/quantizer.py +++ b/tico/quantization/wrapq/quantizer.py @@ -81,12 +81,18 @@ def _wrap_supported( Recursively attempt to wrap boundaries. Strictness is applied at every boundary. """ assert not isinstance(root, QuantModuleBase), "The module is already wrapped." + try: + return PTQWrapper(root, qcfg=qcfg, fp_name="model") + except NotImplementedError as e: + print("no special wrapper for model, wrappig using general case") # Case A: HuggingFace-style transformers: model.model.layers lm = getattr(root, "model", None) embeddings = ( - getattr(lm, "embed_tokens", None) if isinstance(lm, nn.Module) else None + getattr(lm, "embed_tokens", None) + if isinstance(lm.embed_tokens, nn.Module) # type: ignore[union-attr] + else None ) if isinstance(embeddings, nn.Module): child_scope = "model.embeddings" @@ -99,7 +105,11 @@ def _wrap_supported( ) lm.embed_tokens = wrapped # type: ignore[union-attr] - model_norm = getattr(lm, "norm", None) if isinstance(lm, nn.Module) else None + model_norm = ( + getattr(lm, "norm", None) + if isinstance(lm.norm, nn.Module) # type: ignore[union-attr] + else None + ) if isinstance(model_norm, nn.Module): child_scope = "model.norm" child_cfg = qcfg.child(child_scope) diff --git a/tico/quantization/wrapq/utils/metrics.py b/tico/quantization/wrapq/utils/metrics.py index acd36ea3..caa0f47d 100644 --- a/tico/quantization/wrapq/utils/metrics.py +++ b/tico/quantization/wrapq/utils/metrics.py @@ -90,10 +90,15 @@ def perplexity( input_ids_full = input_ids_full.to(device) if max_length is None: - assert hasattr(model, "config") - model_config = model.config - if hasattr(model.config, "text_config"): - model_config = model.config.text_config + if hasattr(model, "config"): + assert hasattr(model, "config") + model_config = model.config + else: + assert hasattr(model.wrapped, "config") + model_config = model.wrapped.config + + if hasattr(model_config, "text_config"): + model_config = model_config.text_config assert hasattr(model_config, "max_position_embeddings") assert isinstance(model_config.max_position_embeddings, int) max_length = model_config.max_position_embeddings diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py b/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py index b2f14b80..23b71c34 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py @@ -191,8 +191,6 @@ def forward( # Rope tables cos, sin = position_embeddings - cos = self._fq(cos, self.obs_cos) - sin = self._fq(sin, self.obs_sin) # --- KV for attention & present_key_value ------------- present_key_value: Tuple[torch.Tensor, torch.Tensor] @@ -205,7 +203,7 @@ def forward( attention_mask = self.causal_mask_template[..., :q_len, :k_len].to( hidden_states.device ) - attention_mask = self._fq(attention_mask, self.obs_causal_mask) + attention_mask = self._fq(attention_mask, self.obs_causal_mask) attn_weights_parts = [] attn_out_parts = [] @@ -251,8 +249,9 @@ def forward( logits_i = self._fq(q_i @ k_i.transpose(-2, -1), self.obs_logits) # mask add + assert attention_mask.shape == logits_i.shape # check for compatiblity logits_i = self._fq( - logits_i + attention_mask.view(1, q_i.size(1), k_i.size(1)), + logits_i + attention_mask, self.obs_mask_add, ) diff --git a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py index 0d1fb617..5d0efbb4 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py @@ -107,6 +107,9 @@ def __init__( qcfg=post_attention_layernorm, fp_name=f"{fp_name}.post_attention_layernorm", ) + self.obs_causal_mask = self._make_obs("causal_mask") + self.obs_cos = self._make_obs("cos") + self.obs_sin = self._make_obs("sin") # Static causal mask template --------------------------------------- assert hasattr(fp_layer.self_attn, "config") and hasattr( @@ -184,18 +187,28 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - # to prevent introduction of attention_mask as a parameter let's use preset attention_mask - L = hidden_states.size(1) - attention_mask = self._slice_causal(L, hidden_states.device) - - position_embeddings = ( - self.rope_cos_template.to( - dtype=hidden_states.dtype, device=hidden_states.device - ), - self.rope_sin_template.to( - dtype=hidden_states.dtype, device=hidden_states.device - ), - ) + if attention_mask is None or attention_mask.dtype == torch.bool: + L = hidden_states.size(1) + attention_mask = self._slice_causal(L, hidden_states.device) + attention_mask = attention_mask.squeeze(0) + attention_mask = self.fq( + attention_mask, self.obs_causal_mask + ) # let it be quantized immediately + + if position_embeddings is None: + position_embeddings = ( + self.rope_cos_template.to( + dtype=hidden_states.dtype, device=hidden_states.device + ), + self.rope_sin_template.to( + dtype=hidden_states.dtype, device=hidden_states.device + ), + ) + cos, sin = position_embeddings + position_embeddings = ( + self._fq(cos, self.obs_cos), + self._fq(sin, self.obs_sin), + ) attn_out = self.self_attn( hidden_states=hidden_states, @@ -241,6 +254,7 @@ def forward( # No local observers; just recurse into children def _all_observers(self): + yield from (self.obs_causal_mask, self.obs_cos, self.obs_sin) yield from self.self_attn._all_observers() yield from self.mlp._all_observers() yield self.obs_mlp_residual_out diff --git a/tico/quantization/wrapq/wrappers/llama/quant_model.py b/tico/quantization/wrapq/wrappers/llama/quant_model.py index 7b09a8d8..c27f37b2 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_model.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_model.py @@ -203,6 +203,7 @@ def forward( hidden_states = inputs_embeds # create position_embeddings and causal_mask to be shared across all the decoder layers causal_mask = self.get_attention_mask_for(hidden_states) + causal_mask = causal_mask.squeeze(0) causal_mask = self._fq(causal_mask, self.obs_causal_mask) position_embeddings = self.get_position_embeddings_for(hidden_states) diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py index f045e5fc..08d23405 100644 --- a/tico/quantization/wrapq/wrappers/registry.py +++ b/tico/quantization/wrapq/wrappers/registry.py @@ -37,6 +37,7 @@ "tico.quantization.wrapq.wrappers.llama.quant_attn_prefill", "tico.quantization.wrapq.wrappers.llama.quant_decoder_layer_prefill", "tico.quantization.wrapq.wrappers.llama.quant_mlp", + "tico.quantization.wrapq.wrappers.llama.quant_model_for_causal_lm", "tico.quantization.wrapq.wrappers.llama.quant_model", ## fairseq ## "tico.quantization.wrapq.wrappers.fairseq.quant_decoder_layer",