diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index 404efd1d..5e002d87 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -36,6 +36,7 @@ import torch import tqdm from datasets import load_dataset +from lm_eval.utils import make_table from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.cache_utils import Cache @@ -48,6 +49,7 @@ from tico.quantization import convert, prepare from tico.quantization.config.gptq import GPTQConfig from tico.quantization.config.ptq import PTQConfig +from tico.quantization.evaluation.script.llm_tasks_eval import evaluate_llm_on_tasks from tico.quantization.wrapq.dtypes import DType from tico.quantization.wrapq.observers.affine_base import AffineObserverBase from tico.quantization.wrapq.qscheme import QScheme @@ -364,6 +366,11 @@ def evaluate(q_m, tokenizer, dataset_test, args): print(f"│ int16 : {ppl_uint8:8.2f}") print("└───────────────────────────────────────────") + if args.eval_tasks is not None: + results = evaluate_llm_on_tasks(q_m, tokenizer, args.eval_tasks) + print("Quantized RESULTS ARE:") + print(make_table(results)) + def main(): parser = argparse.ArgumentParser( @@ -521,6 +528,11 @@ def main(): print(f"│ FP32 : {ppl_fp32:8.2f}") print("└───────────────────────────────────────────") + if args.eval_tasks is not None: + results = evaluate_llm_on_tasks(model, tokenizer, args.eval_tasks) + print("Original RESULTS ARE:") + print(make_table(results)) + # ------------------------------------------------------------------------- # Prepare calibration dataset # -------------------------------------------------------------------------