Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 25 additions & 146 deletions tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@
from lm_eval.utils import make_table
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import KwargsForCausalLM, LlamaForCausalLM
from transformers.processing_utils import Unpack

import tico

from tico.quantization import convert, prepare
Expand Down Expand Up @@ -107,60 +102,12 @@ def inject_gptq_qparams(
def save_circles_to(q_m, calib_inputs, save_circle_to_folder):
q_m.eval()
q_m.cpu()
save_path = pathlib.Path(save_circle_to_folder, "embedding.q.circle")
pathlib.Path()
print(f"saving input embedding to {save_path.resolve()}")
with torch.no_grad():
with SuppressWarning(UserWarning, ".*"):
cm = tico.convert(
q_m.model.embed_tokens,
(calib_inputs[0],),
strict=False,
)
cm.save(save_path)

save_path = pathlib.Path(save_circle_to_folder, "lm_head.q.circle")
print(f"saving lm_head to {save_path.resolve()}")
with torch.no_grad():
with SuppressWarning(UserWarning, ".*"):
B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size
example_hidden = torch.randn(B, S, D)
cm = tico.convert(
q_m.lm_head,
(example_hidden,),
strict=False,
)
cm.save(save_path)

print("saving layers")
for i in range(len(q_m.model.layers)):
save_path = pathlib.Path(save_circle_to_folder, f"decoder_layer_{i}.q.circle")
print(f"saving model layer_{i} to {save_path.resolve()}")
B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size
example_hidden = torch.randn(B, S, D)

with torch.no_grad():
with SuppressWarning(UserWarning, ".*"):
cm = tico.convert(
q_m.model.layers[i],
(example_hidden,),
strict=False,
)
cm.save(save_path)

save_path = pathlib.Path(save_circle_to_folder, "model.model.q.circle")
print(f"saving model.model to {save_path.resolve()}")
with torch.no_grad():
with SuppressWarning(UserWarning, ".*"):
cm = tico.convert(q_m.model, (calib_inputs[0],), strict=False)

cm.save(save_path)

save_path = pathlib.Path(save_circle_to_folder, "model.q.circle")
print(f"saving the whole model to {save_path.resolve()}")
with torch.no_grad():
with SuppressWarning(UserWarning, ".*"):
cm = tico.convert(q_m, (calib_inputs[0],), strict=False)
cm = tico.convert(q_m.wrapped, (calib_inputs[0],), strict=False)

cm.save(save_path)

Expand Down Expand Up @@ -222,13 +169,19 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
default_dtype=DType.int(16),
default_qscheme=QScheme.PER_TENSOR_SYMM,
overrides={
"model.embeddings": {
"weight": {
"dtype": (
DType.uint(args.embedding_weight_bits)
if args.embedding_weight_bits < 16
else DType.int(args.embedding_weight_bits)
),
"model": {
"embed_tokens": {
"weight": {
"dtype": (
DType.uint(args.embedding_weight_bits)
if args.embedding_weight_bits < 16
else DType.int(args.embedding_weight_bits)
),
},
},
"layers": {},
"norm": {
"weight": {"dtype": DType.int(16)},
},
},
"lm_head": {
Expand All @@ -240,17 +193,14 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
),
},
},
"model.norm": {
"weight": {"dtype": DType.int(16)},
},
},
)
for i in range(len(q_m.model.layers)):
child_scope = f"layer{i}"
cfg.overrides[child_scope] = w_cfg # type: ignore[index]
child_scope = f"{i}"
cfg.overrides["model"]["layers"][child_scope] = w_cfg # type: ignore[index]

qcfg = cfg
prepare(q_m, qcfg)
q_m = prepare(q_m, qcfg)

# -------------------------------------------------------------------------
# Single-pass activation calibration
Expand All @@ -260,6 +210,12 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
# Overwrite weight observers with GPTQ statistics
if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
inject_gptq_qparams(q_m, q_m.quantizers)
elif (
hasattr(q_m, "wrapped")
and hasattr(q_m.wrapped, "quantizers")
and isinstance(q_m.wrapped.quantizers, dict)
):
inject_gptq_qparams(q_m.wrapped, q_m.wrapped.quantizers)
else:
print(
"[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
Expand All @@ -276,91 +232,14 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
return q_m


def fix_inputs(model, tokenizer, input_ids):
if tokenizer.pad_token_id is not None:
pads = torch.full(
(
input_ids.shape[0],
model.config.max_position_embeddings - input_ids.shape[1],
),
fill_value=tokenizer.pad_token_id,
device=input_ids.device,
)
elif tokenizer.eos_token_id is not None:
pads = torch.full(
(
input_ids.shape[0],
model.config.max_position_embeddings - input_ids.shape[1],
),
fill_value=tokenizer.eos_token_id,
device=input_ids.device,
)
else:
raise RuntimeError(
"failed to pad sequence - tokenizer doesn't have pad_token_id/eos_token_id"
)

return torch.cat((input_ids, pads), dim=1)


class LLamaWithFixedInput(LlamaForCausalLM):
def __init__(self, parent: LlamaForCausalLM, tokenizer):
assert parent.config is not None, "config is a must have"
super().__init__(parent.config)
self.__dict__.update(parent.__dict__)

def forward(
self,
input_ids: torch.LongTensor = None, # type: ignore[assignment]
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
# fixed input size, due to position_ids fixed
orig_len = input_ids.shape[-1]
input_ids = fix_inputs(self, self.tokenizer, input_ids)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now, fix_inputs can be removed.

if labels is not None:
labels = fix_inputs(self, self.tokenizer, labels)
res = super().forward(
input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
labels,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
logits_to_keep,
**kwargs,
)
# we need to trim to the original size
res.logits = res.logits[..., :orig_len, :]
return res

self.forward = types.MethodType(forward, self)
self.tokenizer = tokenizer


def evaluate(q_m, tokenizer, dataset_test, args):
# -------------------------------------------------------------------------
# Evaluate perplexity on Wikitext-2
# -------------------------------------------------------------------------
print("\nCalculating perplexities …")
enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
ppl_uint8 = perplexity(
q_m, enc, args.device, stride=q_m.config.max_position_embeddings
q_m, enc, args.device, stride=q_m.wrapped.config.max_position_embeddings
)

print("\n┌── Wikitext-2 test perplexity ─────────────")
Expand Down Expand Up @@ -576,7 +455,7 @@ def main():
q_m = quantize_using_PTQ(q_m, calib_inputs, args)

# after PTQ quantizer only fixed-length input sequences are valid
evaluate(LLamaWithFixedInput(q_m, tokenizer), tokenizer, dataset_test, args)
evaluate(q_m, tokenizer, dataset_test, args)

if args.save_circle_to_folder is not None:
save_circles_to(q_m, calib_inputs, args.save_circle_to_folder)
Expand Down
2 changes: 1 addition & 1 deletion tico/quantization/wrapq/examples/quantize_with_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase


# Token-budget presets for activation calibration
TOKENS: dict[str, int] = {
# Smoke test (<1 min turnaround on CPU/GPU)
Expand All @@ -66,6 +65,7 @@
TRAIN_SPLIT = "train"
TEST_SPLIT = "test"


# -------------------------------------------------------------------------
# 1. Helper — copy GPTQ (scale, zp) into PTQ observers
# -------------------------------------------------------------------------
Expand Down
14 changes: 12 additions & 2 deletions tico/quantization/wrapq/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,18 @@ def _wrap_supported(
Recursively attempt to wrap boundaries. Strictness is applied at every boundary.
"""
assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
try:
return PTQWrapper(root, qcfg=qcfg, fp_name="model")
except NotImplementedError as e:
print("no special wrapper for model, wrappig using general case")

# Case A: HuggingFace-style transformers: model.model.layers
lm = getattr(root, "model", None)

embeddings = (
getattr(lm, "embed_tokens", None) if isinstance(lm, nn.Module) else None
getattr(lm, "embed_tokens", None)
if isinstance(lm.embed_tokens, nn.Module) # type: ignore[union-attr]
else None
)
if isinstance(embeddings, nn.Module):
child_scope = "model.embeddings"
Expand All @@ -99,7 +105,11 @@ def _wrap_supported(
)
lm.embed_tokens = wrapped # type: ignore[union-attr]

model_norm = getattr(lm, "norm", None) if isinstance(lm, nn.Module) else None
model_norm = (
getattr(lm, "norm", None)
if isinstance(lm.norm, nn.Module) # type: ignore[union-attr]
else None
)
if isinstance(model_norm, nn.Module):
child_scope = "model.norm"
child_cfg = qcfg.child(child_scope)
Expand Down
13 changes: 9 additions & 4 deletions tico/quantization/wrapq/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,15 @@ def perplexity(
input_ids_full = input_ids_full.to(device)

if max_length is None:
assert hasattr(model, "config")
model_config = model.config
if hasattr(model.config, "text_config"):
model_config = model.config.text_config
if hasattr(model, "config"):
assert hasattr(model, "config")
model_config = model.config
else:
assert hasattr(model.wrapped, "config")
model_config = model.wrapped.config

if hasattr(model_config, "text_config"):
model_config = model_config.text_config
assert hasattr(model_config, "max_position_embeddings")
assert isinstance(model_config.max_position_embeddings, int)
max_length = model_config.max_position_embeddings
Expand Down
7 changes: 3 additions & 4 deletions tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,6 @@ def forward(

# Rope tables
cos, sin = position_embeddings
cos = self._fq(cos, self.obs_cos)
sin = self._fq(sin, self.obs_sin)

# --- KV for attention & present_key_value -------------
present_key_value: Tuple[torch.Tensor, torch.Tensor]
Expand All @@ -205,7 +203,7 @@ def forward(
attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
hidden_states.device
)
attention_mask = self._fq(attention_mask, self.obs_causal_mask)
attention_mask = self._fq(attention_mask, self.obs_causal_mask)

attn_weights_parts = []
attn_out_parts = []
Expand Down Expand Up @@ -251,8 +249,9 @@ def forward(
logits_i = self._fq(q_i @ k_i.transpose(-2, -1), self.obs_logits)

# mask add
assert attention_mask.shape == logits_i.shape # check for compatiblity
logits_i = self._fq(
logits_i + attention_mask.view(1, q_i.size(1), k_i.size(1)),
logits_i + attention_mask,
self.obs_mask_add,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def __init__(
qcfg=post_attention_layernorm,
fp_name=f"{fp_name}.post_attention_layernorm",
)
self.obs_causal_mask = self._make_obs("causal_mask")
self.obs_cos = self._make_obs("cos")
self.obs_sin = self._make_obs("sin")

# Static causal mask template ---------------------------------------
assert hasattr(fp_layer.self_attn, "config") and hasattr(
Expand Down Expand Up @@ -184,18 +187,28 @@ def forward(
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# to prevent introduction of attention_mask as a parameter let's use preset attention_mask
L = hidden_states.size(1)
attention_mask = self._slice_causal(L, hidden_states.device)

position_embeddings = (
self.rope_cos_template.to(
dtype=hidden_states.dtype, device=hidden_states.device
),
self.rope_sin_template.to(
dtype=hidden_states.dtype, device=hidden_states.device
),
)
if attention_mask is None or attention_mask.dtype == torch.bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this condition come again?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mhs4670go
Ahhh. Sorry.

  • It was recently removed from quant_decoder_layer.py to have fully quantized model( , because causal_mask received from modeling_llama.py of transformers was float, so to have a fully integer model, the line 206 was removed).
  • This draft uses quantized causal_mask from quant_model.py so the check can be restored to have a chance to convert decoder_layer like this tico.convert(layer, (inp,)) without causal_mask in parameters.
  • In case it's left as it is (no check), all decoder layers will be populated with their own attention_masks which will be disk consuming.

L = hidden_states.size(1)
attention_mask = self._slice_causal(L, hidden_states.device)
attention_mask = attention_mask.squeeze(0)
attention_mask = self.fq(
attention_mask, self.obs_causal_mask
) # let it be quantized immediately

if position_embeddings is None:
position_embeddings = (
self.rope_cos_template.to(
dtype=hidden_states.dtype, device=hidden_states.device
),
self.rope_sin_template.to(
dtype=hidden_states.dtype, device=hidden_states.device
),
)
cos, sin = position_embeddings
position_embeddings = (
self._fq(cos, self.obs_cos),
self._fq(sin, self.obs_sin),
)

attn_out = self.self_attn(
hidden_states=hidden_states,
Expand Down Expand Up @@ -241,6 +254,7 @@ def forward(

# No local observers; just recurse into children
def _all_observers(self):
yield from (self.obs_causal_mask, self.obs_cos, self.obs_sin)
yield from self.self_attn._all_observers()
yield from self.mlp._all_observers()
yield self.obs_mlp_residual_out
Loading
Loading