I made this tool because default dynamic ONNX quantization, while fast, often degrades the performance of finetuned models, especially those with subgraphs. Excluding a handful of sensitive nodes restores performance without keeping the whole model in FP32 and sacrificing inference speed. To do that, this script employs a two-stage workflow to identify the optimal set of nodes to exclude from quantization in ONNX encoder-decoder models. Still WIP; only encoder-decoder models are supported for now.
- Two-Stage Optimization Workflow: A systematic approach to first identify sensitive operation types and then prune the exclusion list to a minimal set.
- Broad Compatibility: Designed for encoder-decoder models and supports encoder-only, decoder-only, or combined searches.
- Multiple Search Strategies: Offers various strategies (`first`, `best`, `percent`) to discover the initial set of sensitive nodes, providing flexibility for different optimization needs.
- Performance Benchmarking: Integrated benchmarking to evaluate model performance using a suite of metrics (WER, CER, BLEU, ROUGE) and compare it against FP32 and fully quantized baselines.
- Automated Model Export: The script can automatically export the final, optimally quantized models for immediate use.
The script automates the process of discovering which nodes in an ONNX model are most sensitive to quantization. It does this by:
- Establishing Baselines: It first benchmarks the original FP32 model and a fully quantized version to establish performance boundaries.
- Stage 1: Sensitive Operator Discovery: The script iteratively excludes nodes by their operator type (e.g., `MatMul`, `Add`, `Gelu`) and measures the impact on a primary performance metric. This stage identifies a "tipping point" operator, where excluding it and all preceding operator types brings the model's performance to a desired level.
- Stage 2: Pruning: Starting with the broad list of nodes from Stage 1, this stage systematically re-includes nodes one by one to find the smallest possible set of excluded nodes that maintains the performance gains.
- Final Evaluation: The script provides a final benchmark of the partially quantized model and a detailed comparison against the baselines.

The script supports multiprocessing to speed up the search.
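Under simplifying assumptions (lower score is better, e.g. WER), the two stages can be sketched in plain Python. The `benchmark` callable, the node/operator structures, and the function names below are illustrative stand-ins for the script's internals, not its actual API:

```python
# Minimal sketch of the two-stage search. `benchmark(excluded_nodes)` is a
# hypothetical stand-in that quantizes everything except `excluded_nodes`
# and returns the primary metric (lower = better, e.g. WER).

def stage1_tipping_point(nodes_by_op, op_order, benchmark, target_score):
    """Cumulatively exclude operator types until the score reaches target."""
    excluded = []
    for op_type in op_order:                      # e.g. ["MatMul", "Add", "Gelu"]
        excluded += nodes_by_op.get(op_type, [])
        if benchmark(excluded) <= target_score:   # "first" strategy: stop ASAP
            return excluded
    return excluded

def stage2_prune(excluded, benchmark, tipping_score):
    """Re-include nodes one by one; keep a node excluded only if needed."""
    kept = list(excluded)
    for node in excluded:
        trial = [n for n in kept if n != node]
        if benchmark(trial) <= tipping_score:     # "relaxed" strategy
            kept = trial                          # node not needed, re-include
    return kept
```

Stage 1 here corresponds to the `first` strategy and Stage 2 to the `relaxed` strategy; the actual script also implements the `best`/`percent` and `strict` variants described below.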
- Python 3.8+
- CUDA >=12.6
- ONNX Runtime
- Transformers
- Optimum
- Additional libraries for metrics: `jiwer` for WER/CER and `evaluate` for BLEU/ROUGE.
I recommend using a RAM disk for `quant_test_dir`, as the process can write several terabytes of data in a single run and prematurely wear out your SSD/HDD.
- Clone the repository:

  ```bash
  git clone https://github.com/AdamCodd/ONNX-Quant-Optimize.git
  cd ONNX-Quant-Optimize
  ```

- Install the required packages:

  ```bash
  pip install -r requirements-cpu.txt
  ```

  or if you're testing on GPU:

  ```bash
  pip install -r requirements-gpu.txt
  ```
- Configure your settings: Create a `config_quant.json` file to specify model paths, search strategies, and other parameters. A documented example of this file is provided in the repo.

- Provide evaluation samples: Create a `samples.jsonl` file with input prompts and ground-truth references for benchmarking. Each line should be a JSON object: `{"input": "Your input prompt here.", "ground_truth": "The expected output."}`

- Run the script:

  ```bash
  python discover_quant_exclusion.py --config config_quant.json
  ```

- Export the quantized ONNX encoder/decoder (optional):

  ```bash
  python discover_quant_exclusion.py --config config_quant.json --export_final path/to/your/directory
  ```
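The `samples.jsonl` file from the steps above can be written with the standard library; the two input/ground-truth pairs here are made-up placeholders to show the expected one-object-per-line shape:

```python
# Write a small samples.jsonl for benchmarking. The sample pairs are
# placeholders -- replace them with data from your model's actual domain.
import json

samples = [
    {"input": "Translate to German: Hello, world.", "ground_truth": "Hallo, Welt."},
    {"input": "Summarize: The quick brown fox...", "ground_truth": "A fox jumps."},
]

with open("samples.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        # One JSON object per line, no trailing commas or wrapping array
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```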
The script is controlled by a `config_quant.json` file. Here are some of the key options:
"candidate_op_types": A list of ONNX operator types to consider for exclusion during the search."enable_subgraph": A boolean that corresponds to the EnableSubgraph extra option in ONNX Runtime quantization."execution_provider": The ONNX Runtime execution provider to use for inference (e.g., "CPUExecutionProvider" or "CUDAExecutionProvider")."fp32_decoder": The filename of the FP32 ONNX decoder model inside the "onnx_dir"."fp32_encoder": The filename of the FP32 ONNX encoder model inside the "onnx_dir"."max_generation_length": The maximum number of new tokens to generate during the evaluation benchmark."max_nodes_to_exclude": An optional integer. If set, the Stage 2 pruning process will stop early if the number of excluded nodes becomes less than or equal to this value."metrics": A list of metrics to evaluate. Options: "wer", "cer", "bleu", "rouge"."model_dir": Path to the directory containing the source model (e.g., in Hugging Face format)."model_reference": The format of the reference model. Options:safetensors,pytorchoronnx-fp32."multiprocessing": A boolean flag to enable or disable the use of multiple processes for parallelizing the quantization and benchmarking tasks."onnx_dir": Directory to save the exported ONNX models."primary_metric": The main metric to use for optimization decisions."quant_test_dir": A temporary directory for intermediate quantized models."quant_type": The target data type for quantization. Supported options are "QInt8" and "QUInt8"."resume": A boolean. If true, the script will attempt to load a _search_state.json file from the quant_test_dir to resume a previously interrupted search."samples_jsonl": The file path to a JSONL file containing samples for evaluation. Each line should be a JSON object with input prompts and ground truth references."search_target": The part of the model to search. Options:"encoder","decoder","both"."metrics": A list of metrics to evaluate. 
Options:"wer","cer","bleu","rouge"."primary_metric": The main metric to use for optimization decisions."strategy_stage1": The strategy for the first stage of the search. Options:"first": Stops at the first operator type that improves the score."best": Tries all operator types and picks the one with the best cumulative score."percent": Aims to recover a certain percentage of the performance gap between the fully quantized and FP32 models.
"strategy_stage2": The strategy for the pruning stage. Options:"relaxed": A node is kept excluded if removing it doesn't degrade performance below the "tipping point" score from Stage 1."strict": A node is kept excluded only if removing it degrades the current best score.
"target": A float value (e.g., 0.5 for 50%) used only with the "percent" strategy for Stage 1. It defines the target percentage of the performance gap to recover."task": A string describing the model's task, used for metadata (e.g., "text2text-generation")."with-past": A boolean to indicate whether the model should be exported with support for past key-value caches for faster generation. Only for decoders."workers": The number of worker processes to use when "multiprocessing" is enabled. If set to null, it'll automatically usemax(1, os.cpu_count() // 2)workers.
After running, the script will output the recommended exclusion lists for the encoder and decoder, along with a summary table comparing the performance of the different model versions:
```
========================= FINAL RESULTS =========================
Found minimal exclusion list for DECODER with 8 nodes.
RECOMMENDATION: Use this list for the 'nodes_to_exclude' argument when quantizing the decoder.
--------------------------------------------------------------------------------
DECODER_NODES_TO_EXCLUDE = [
    '/model/decoder/layers/0/encoder_attn/MatMul',
    '/model/decoder/layers/1/encoder_attn/MatMul',
    ...
]
--------------------------------------------------------------------------------
No beneficial nodes to exclude were found for the ENCODER.

📊============== FINAL SUMMARY ==============📊
Model                                 Time (s)   ΔTime            WER      CER      ΔWER      ΔCER      Filesize
-----------------------------------------------------------------------------------------------------------------
FP32 Reference                        15.2345    -                0.1234   0.0456   -         -         950.23MB
QUInt8 Dynamic (Fully Quantized)      8.1234     -7.111s (x0.53)  0.1876   0.0789   +0.0642   +0.0333   240.12MB
QUInt8 Dynamic (Partially Quantized)  8.5678     -6.667s (x0.56)  0.1255   0.0467   +0.0021   +0.0011   245.34MB
```
- Only encoder-decoder models are supported for now; support for encoder-only and decoder-only models still needs to be added.