We study compressing a large physics-informed neural network (PINN) for mmWave channel estimation that fuses pilot-based initial estimates with an RSS map. The baseline (358.9M parameters, about 1.4 GB in FP32) is too large for edge deployment, so we explore three approaches: mixed-precision quantization-aware training (QAT), post-training quantization with Hadamard rotation (H-GPTQ), and physics-guided knowledge distillation (PG-KD). On a Boston ray-tracing benchmark, these methods give large memory and speed gains (up to 38.7× fewer parameters for the distilled student) while retaining or improving NMSE.
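For reference, the 1.4 GB figure follows from 4 bytes per FP32 parameter. The sketch below estimates raw weight storage at the bit-widths used in the experiments. It counts weights only: quantization scales, zero-points, FP32-shielded layers, and activations are ignored, so real checkpoints will be somewhat larger.

```python
# Weight-storage estimate for the 358.9M-parameter baseline at several bit-widths.
# Counts raw weight bits only; scales, zero-points and non-weight state are ignored.
BASELINE_PARAMS = 358.9e6


def weight_bytes(num_params: float, bits: int) -> float:
    """Bytes needed to store num_params weights at the given bit-width."""
    return num_params * bits / 8


for bits in (32, 8, 4):
    gb = weight_bytes(BASELINE_PARAMS, bits) / 1e9  # decimal gigabytes
    print(f"{bits:>2}-bit: {gb:.2f} GB ({32 // bits}x smaller than FP32)")
```

This gives roughly 1.44 GB (FP32), 0.36 GB (INT8), and 0.18 GB (INT4) of weight storage.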
First, navigate to the PINN_channel-estimation-main/ folder.
cd PINN_channel-estimation-main

Step 1. make_correct_channels.py parses the Wireless InSite ray-tracing data and produces a (num_snapshots, D, Nr, Nt) complex tensor.
# Boston, 15 GHz, 400 MHz
python make_correct_channels.py \
--csv Dataset/15GHz_concatenated_data.csv \
--out 3D_channel_15GHz_2x2_Pt50.npy \
--pt 50 --bw 4e8

Step 2. For a given (SNR, Np) operating point, init_estimation.py simulates OFDM pilot transmission, performs LS interpolation, and saves a .npy with the same shape as the Step 1 tensor.
We fix the pilot spacing at 4 (Np = 256 pilots across 1024 subcarriers) and run for 5 SNR values.
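A minimal single-antenna sketch of LS estimation at pilot spacing 4, under stated assumptions: unit-modulus pilots on every 4th subcarrier, complex Gaussian noise scaled to the target SNR relative to unit pilot power, a toy two-tap channel instead of the ray-traced one, and linear interpolation between pilots (edge subcarriers held at the nearest pilot). init_estimation.py may use a different noise model or interpolator; all function names below are hypothetical.

```python
import bisect
import cmath
import math
import random


def ls_pilot_estimate(y, x, pilot_idx):
    """LS estimate H_p = Y_p / X_p at each pilot subcarrier."""
    return {k: y[k] / x[k] for k in pilot_idx}


def interpolate_linear(h_pilots, n_subcarriers):
    """Linearly interpolate pilot estimates; hold edge values outside the pilot span."""
    ks = sorted(h_pilots)
    out = []
    for n in range(n_subcarriers):
        if n <= ks[0]:
            out.append(h_pilots[ks[0]])
        elif n >= ks[-1]:
            out.append(h_pilots[ks[-1]])
        else:
            j = bisect.bisect_right(ks, n)  # ks[j - 1] <= n < ks[j]
            k0, k1 = ks[j - 1], ks[j]
            t = (n - k0) / (k1 - k0)
            out.append((1 - t) * h_pilots[k0] + t * h_pilots[k1])
    return out


def simulate(n_subcarriers=1024, pilot_spacing=4, snr_db=0.0, seed=0):
    """Toy end-to-end run: true response, noisy pilots, LS + interpolation."""
    rng = random.Random(seed)
    # Toy two-tap channel (assumption, not the Wireless InSite channel).
    h = [1.0 + 0.5 * cmath.exp(-2j * math.pi * 3 * n / n_subcarriers)
         for n in range(n_subcarriers)]
    x = [1.0 + 0j] * n_subcarriers  # unit-modulus pilot symbols
    # Noise variance 10^(-SNR/10), split evenly between real and imaginary parts.
    noise_std = math.sqrt(10 ** (-snr_db / 10) / 2)
    y = [h[n] * x[n] + complex(rng.gauss(0, noise_std), rng.gauss(0, noise_std))
         for n in range(n_subcarriers)]
    pilot_idx = list(range(0, n_subcarriers, pilot_spacing))
    h_hat = interpolate_linear(ls_pilot_estimate(y, x, pilot_idx), n_subcarriers)
    return h, h_hat, pilot_idx
```

With 1024 subcarriers and spacing 4 this places 256 pilots, matching Np = 256 in the command below.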
# SNR = 0 dB, Np = 256 (pilot_spacing = 4) — repeat for other SNR values
python init_estimation.py \
--true-channels 3D_channel_15GHz_2x2_Pt50.npy \
--output initial_estimate_ls_snr0.npy \
--snr 0 \
--n-subcarriers 1024 \
--pilot-spacing 4

Step 3. Edit the config dict at the bottom of train.py so that smomp_file points at the initial estimate from Step 2 and accurate_file points at the ground-truth tensor from Step 1. Then run:

python train.py

This trains for 500 epochs, saves the best-validation-NMSE checkpoint to name_val, and saves the last-epoch checkpoint to name_train. Run once per SNR, keeping accurate_file the same.
All scripts live in experiment3_qat/.
project_root/
├── PINN_channel-estimation-main/
│ ├── initial_estimate_ls_snr0.npy # generated by Step 2
│ ├── 3D_channel_15GHz_2x2_Pt50.npy # generated by Step 1
│ ├── ue_positions_noisy.txt
│ ├── simple_ls_0_val.pth # generated by Step 3
│ ├── find_in_map.py # imported by train_qat.py
│ └── Dataset/
│ └── 50_15GHz.jpg
└── experiment3_qat/
├── train_qat.py # QAT training script
└── Model.py
Four physics-critical submodules are kept in FP32 (the "physics shield"), while the convolutional encoder/decoder and the remaining linear-path layers are quantized to INT8:
| Submodule | Precision | Reason |
|---|---|---|
| enc1, enc2, enc3 | INT8 (per-channel weights) | Bulk of FLOPs; linear loss path |
| dec1, dec2, dec3 | INT8 (per-tensor weights) | fbgemm requires per-tensor weights for ConvTranspose2d |
| skip_conv1/2, to_sequence, from_sequence | INT8 | Linear path; low sensitivity |
| rss_encoder | FP32 | Feeds the RSS power-matching term directly; quantization noise propagates quadratically through the physics constraint |
| transformer_decoder | FP32 | Softmax inside attention saturates under INT8 Q/K |
| cross_attention.multihead_attn | FP32 | Same attention-stability reason |
| final_conv | FP32 | Complex output projection; keeps the output manifold intact |
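The shield can be expressed as a name-prefix lookup. The sketch below is plain Python and hypothetical: the submodule names are copied from the table, while longest-prefix matching and the INT8 default for unlisted modules are assumptions. In PyTorch's FX workflow, an FP32 entry would typically correspond to QConfigMapping.set_module_name(name, None), which leaves that submodule unquantized.

```python
# Precision plan mirroring the physics-shield table (module-name prefix -> precision).
PRECISION_PLAN = {
    "enc1": "int8_per_channel",
    "enc2": "int8_per_channel",
    "enc3": "int8_per_channel",
    "dec1": "int8_per_tensor",
    "dec2": "int8_per_tensor",
    "dec3": "int8_per_tensor",
    "skip_conv1": "int8",
    "skip_conv2": "int8",
    "to_sequence": "int8",
    "from_sequence": "int8",
    "rss_encoder": "fp32",
    "transformer_decoder": "fp32",
    "cross_attention.multihead_attn": "fp32",
    "final_conv": "fp32",
}


def precision_for(module_name: str, default: str = "int8") -> str:
    """Precision of the longest plan prefix matching module_name (assumed rule)."""
    best = None
    for prefix in PRECISION_PLAN:
        # Match the module itself or any of its children, not partial names.
        if module_name == prefix or module_name.startswith(prefix + "."):
            if best is None or len(prefix) > len(best):
                best = prefix
    return PRECISION_PLAN[best] if best is not None else default


# The four shielded submodules, in sorted order.
FP32_SHIELD = sorted(p for p, prec in PRECISION_PLAN.items() if prec == "fp32")
```

Matching on a trailing "." keeps, for example, a hypothetical enc10 from inheriting enc1's setting.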
Observer choice: activations use HistogramObserver, which builds a 2048-bin histogram and picks the quantization range by searching for the min/max bounds that minimize quantization error. This gives better calibration than moving-average min/max observers for the heavy-tailed activation distributions produced by skip connections and attention residuals. Weights use PerChannelMinMaxObserver for Conv2d/Linear and MinMaxObserver for ConvTranspose2d (fbgemm constraint).
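Why per-channel weight observers help: a per-tensor scale is set by the largest-magnitude weight in the whole tensor, so channels with small weights get coarse quantization steps. The sketch below uses symmetric INT8 min/max quantization in plain Python (a simplification: PyTorch observers also support affine zero-points) on a toy weight whose two output channels differ in scale by 100x.

```python
def symmetric_scale(values, n_bits=8):
    """Symmetric min/max scale: the largest |value| maps to the largest positive integer."""
    qmax = 2 ** (n_bits - 1) - 1  # 127 for INT8
    amax = max(abs(v) for v in values)
    return amax / qmax if amax > 0 else 1.0


def fake_quant(values, scale, n_bits=8):
    """Quantize to integers, clamp to the INT range, then dequantize."""
    qmin, qmax = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    return [max(qmin, min(qmax, round(v / scale))) * scale for v in values]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# Toy 2-channel weight: channel 0 is ~100x larger than channel 1.
weights = [
    [1.00, -0.73, 0.52, -0.21],
    [0.010, -0.0073, 0.0052, -0.0021],
]
flat = [w for row in weights for w in row]

# One scale for the whole tensor vs. one scale per output channel.
err_per_tensor = mse(flat, fake_quant(flat, symmetric_scale(flat)))
err_per_channel = mse(
    flat,
    [q for row in weights for q in fake_quant(row, symmetric_scale(row))],
)
```

With the per-tensor scale of 1/127, every value in channel 1 rounds to 0 or ±1 step; the per-channel scale gives it the full INT8 grid.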
Training uses a three-phase schedule to stabilize quantization ranges:
| Phase | Epochs | Observers | Fake-quant | Effect |
|---|---|---|---|---|
| 1 (Calibration) | 0 → freeze_bn_epoch | ON | ON | Histogram ranges accumulate from training data |
| 2 (Range freeze) | freeze_bn_epoch → freeze_obs_epoch | OFF | ON | Ranges locked; weights adapt to the fixed grid |
| 3 (FP32 polish) | freeze_obs_epoch → end | OFF | OFF | Final fine-tuning without rounding noise |
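The schedule reduces to two boolean flags per epoch. A sketch of that mapping, with defaults taken from the config table (freeze_bn_epoch = 60, freeze_obs_epoch = 80); the boundary convention (a transition epoch already belongs to the next phase) is an assumption. In PyTorch, such flags are typically applied with model.apply(torch.ao.quantization.disable_observer) and model.apply(torch.ao.quantization.disable_fake_quant), though train_qat.py may toggle them differently.

```python
from typing import NamedTuple


class PhaseFlags(NamedTuple):
    phase: int
    observers_on: bool
    fake_quant_on: bool


def qat_phase(epoch: int, freeze_bn_epoch: int = 60, freeze_obs_epoch: int = 80) -> PhaseFlags:
    """Map a 0-based epoch index to the three-phase QAT schedule."""
    if not 0 <= freeze_bn_epoch <= freeze_obs_epoch:
        raise ValueError("expected 0 <= freeze_bn_epoch <= freeze_obs_epoch")
    if epoch < freeze_bn_epoch:
        return PhaseFlags(1, True, True)    # calibration: ranges still updating
    if epoch < freeze_obs_epoch:
        return PhaseFlags(2, False, True)   # ranges frozen, training through fake-quant
    return PhaseFlags(3, False, False)      # FP32 polish: no rounding noise
```

With qat_epochs = 100, this yields 60 calibration epochs, 20 range-freeze epochs, and 20 polish epochs.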
cd experiment3_qat/
python train_qat.py

The script resolves all data paths automatically (see the layout above), so no arguments are needed.
The script will:
- Validate all required input files and raise a clear error if any are missing.
- Load simple_ls_0_val.pth as the FP32 pretrained checkpoint.
- Insert fake-quantization nodes with the physics-shielded QConfigMapping (HistogramObserver for activations).
- Fine-tune for 100 epochs using the three-phase schedule.
- Print an in-memory live evaluation (fake-quant on GPU) immediately after training.
- Export and save pinn_qat_best_val.pth and pinn_int8.pth to experiment3_qat/.
- Print a full side-by-side comparison table of FP32 vs. INT8 NMSE, latency, disk size, and parameter footprint.
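The first step (fail fast on missing inputs) can be sketched as below. These are hypothetical helpers, not the code in train_qat.py: the config keys mirror those named in this README, the filenames follow the directory layout above, and every missing file is reported at once rather than stopping at the first.

```python
import os


def required_inputs(pinn_dir: str) -> dict:
    """Inputs from the directory layout (keys follow the config names in this README)."""
    return {
        "smomp_file": os.path.join(pinn_dir, "initial_estimate_ls_snr0.npy"),
        "accurate_file": os.path.join(pinn_dir, "3D_channel_15GHz_2x2_Pt50.npy"),
        "user_positions_file": os.path.join(pinn_dir, "ue_positions_noisy.txt"),
        "pretrained_checkpoint": os.path.join(pinn_dir, "simple_ls_0_val.pth"),
        "rss_image_path": os.path.join(pinn_dir, "Dataset", "50_15GHz.jpg"),
    }


def find_missing(paths: dict) -> list:
    """Return 'key: path' for every entry that is not an existing file."""
    return [f"{key}: {path}" for key, path in paths.items() if not os.path.isfile(path)]


def check_inputs(paths: dict) -> None:
    """Raise one FileNotFoundError listing all missing inputs."""
    missing = find_missing(paths)
    if missing:
        raise FileNotFoundError("Missing required input files:\n  " + "\n  ".join(missing))
```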
To run QAT at a different SNR, only two entries in the config dict at the bottom of train_qat.py need to change. Both paths are constructed with os.path.join(_pinn, ...), where _pinn is already resolved, so just update the filenames. For example, for SNR = −10 dB:
"smomp_file": os.path.join(_pinn, "initial_estimate_ls_snrn10.npy"),
"pretrained_checkpoint": os.path.join(_pinn, "simple_ls_n10_val.pth"),

All other paths (accurate_file, user_positions_file, rss_image_path) remain the same because they are SNR-independent.
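The two filenames follow a pattern visible in the examples in this README (snr0 / 0 for 0 dB, snrn10 / n10 for −10 dB). A hypothetical helper that builds both entries from an SNR value; writing positive SNRs as plain digits is an assumption, since only 0 and −10 dB appear here.

```python
import os


def snr_tag(snr_db: int) -> str:
    """Filename tag for an SNR: 0 -> '0', -10 -> 'n10' (minus sign written as 'n')."""
    return f"n{-snr_db}" if snr_db < 0 else str(snr_db)


def snr_paths(pinn_dir: str, snr_db: int) -> dict:
    """The two SNR-dependent config entries for a given operating point."""
    tag = snr_tag(snr_db)
    return {
        "smomp_file": os.path.join(pinn_dir, f"initial_estimate_ls_snr{tag}.npy"),
        "pretrained_checkpoint": os.path.join(pinn_dir, f"simple_ls_{tag}_val.pth"),
    }
```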
| Parameter | Default | Description |
|---|---|---|
| qat_epochs | 100 | Total fine-tuning epochs |
| qat_lr | 1e-4 | Initial Adam learning rate (cosine decay to 5e-6) |
| freeze_bn_epoch | 60 | Phase 1→2 transition: freeze histogram observer ranges |
| freeze_obs_epoch | 80 | Phase 2→3 transition: disable fake-quant for FP32 polish |
| batch_size | 32 | DataLoader batch size |
| qat_backend | x86 | PyTorch quantized engine (x86 for Colab/T4) |
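The qat_lr entry describes a cosine decay from 1e-4 to 5e-6. Assuming the schedule is stepped once per epoch over qat_epochs, it follows the closed form used by torch.optim.lr_scheduler.CosineAnnealingLR with T_max = qat_epochs and eta_min = 5e-6; whether train_qat.py uses that scheduler or an equivalent is an assumption of this sketch.

```python
import math


def cosine_lr(epoch: int, total_epochs: int = 100,
              lr_max: float = 1e-4, lr_min: float = 5e-6) -> float:
    """Cosine decay from lr_max at epoch 0 to lr_min at epoch total_epochs."""
    t = min(max(epoch, 0), total_epochs) / total_epochs  # progress in [0, 1]
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t))
```

Halfway through (epoch 50) the rate is the midpoint, (1e-4 + 5e-6) / 2 = 5.25e-5.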
All outputs are written to experiment3_qat/ (same directory as train_qat.py).
| File | Contents |
|---|---|
| pinn_qat_best_val.pth | Fake-quant model state dict (best validation NMSE) |
| pinn_int8.pth | Converted INT8 model state dict |
| qat_training_log.txt | Epoch-by-epoch train loss, train NMSE, and val NMSE |
All scripts live in experiment1_turboquant/. Two PTQ variants are implemented:
| Method | Description |
|---|---|
| GPTQ | Layer-by-layer second-order quantization using the OBQ formulation. One Hessian in memory at a time (MPS/CPU safe). |
| Hadamard-GPTQ (H-GPTQ) | Applies a Hadamard rotation, then quantizes with GPTQ. |
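The table does not say exactly where H-GPTQ applies the rotation, so the following is a generic illustration rather than the quantize_pinn.py code. An orthonormal Hadamard matrix H satisfies HᵀH = I, so a linear layer can be rewritten as W x = (W Hᵀ)(H x) without changing its output, and the rotation spreads a single large outlier across all coordinates, shrinking the range a min/max quantizer must cover. The fast Walsh-Hadamard transform below (Sylvester ordering, so H is symmetric and self-inverse) needs a power-of-two length.

```python
import math


def fwht(vec):
    """Orthonormal fast Walsh-Hadamard transform (length must be a power of two)."""
    n = len(vec)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    a = list(vec)
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                x, y = a[j], a[j + h]
                a[j], a[j + h] = x + y, x - y  # butterfly
        h *= 2
    norm = 1 / math.sqrt(n)  # makes H orthonormal
    return [v * norm for v in a]


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


# One weight row w and an input x with a single outlier at index 3.
w = [0.3, 0.1, -0.4, 0.2, 0.5, -0.1, 0.05, 0.25]
x = [0.1, -0.2, 0.05, 8.0, 0.0, 0.15, -0.1, 0.2]

# Rotating both sides leaves the output unchanged: (H w) . (H x) = w . x
rotated_out = dot(fwht(w), fwht(x))
```

After rotation the largest |x| component drops from 8.0 to about 3.1, so the same INT4/INT8 grid covers a much narrower range.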
Both target INT8 and INT4 weight precision. Norm layers (GroupNorm, LayerNorm) are skipped. Conv2d inputs are unfolded so that every layer type is treated as a linear map for Hessian accumulation.
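To see what unfolding buys: GPTQ's layer-wise objective ||W X − Ŵ X||² has Hessian H = 2 X Xᵀ with respect to each weight row, where each column of X is one input vector to the layer. For a convolution, each column is a flattened receptive-field patch, so the conv weight (flattened to out_channels × C·kh·kw) acts as a linear map on patches. The sketch below unfolds a single-channel 2-D input (stride 1, no padding) and accumulates the sample-averaged Hessian (2/N) Σ x xᵀ in plain Python; it illustrates the idea and is not the code in quantize_pinn.py.

```python
def unfold2d(image, kh, kw):
    """Flattened kh x kw patches of a 2-D list (stride 1, no padding)."""
    rows, cols = len(image), len(image[0])
    patches = []
    for r in range(rows - kh + 1):
        for c in range(cols - kw + 1):
            patches.append([image[r + i][c + j] for i in range(kh) for j in range(kw)])
    return patches


def accumulate_hessian(columns):
    """H = (2 / N) * sum_n x_n x_n^T over N input vectors (GPTQ-style proxy Hessian)."""
    d, n = len(columns[0]), len(columns)
    H = [[0.0] * d for _ in range(d)]
    for x in columns:
        for i in range(d):
            for j in range(d):
                H[i][j] += x[i] * x[j]
    scale = 2.0 / n
    return [[v * scale for v in row] for row in H]


image = [[float(r * 4 + c) for c in range(4)] for r in range(4)]  # toy 4x4 input
patches = unfold2d(image, 2, 2)   # 9 patches, each of length 4
H = accumulate_hessian(patches)   # 4 x 4, symmetric positive semi-definite
```

Only one d × d matrix (d = C·kh·kw for a conv) is alive per layer, which is what keeps the layer-by-layer pass memory-friendly.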
cd experiment1_turboquant
# GPTQ at 8-bit
python quantize_pinn.py \
--method gptq \
--bits 8 \
--checkpoint ../simple_ls_0_val.pth \
--smomp_file ../PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
--accurate_file ../PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
--user_positions ../PINN_channel-estimation-main/ue_positions_noisy.txt \
--rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg
# Hadamard-GPTQ at 4-bit
python quantize_pinn.py \
--method hadamard_gptq \
--bits 4 \
--checkpoint ../simple_ls_0_val.pth \
...

Then compare the teacher and all quantized variants:

python eval_all.py \
--checkpoint ../simple_ls_0_val.pth \
--smomp_file ../PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
--accurate_file ../PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
--user_positions ../PINN_channel-estimation-main/ue_positions_noisy.txt \
--rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
--device mps

This prints a comparison table of Teacher FP32 vs. GPTQ vs. H-GPTQ at each bit-width, with NMSE (dB) and compression ratio.
All scripts live in experiment2_PhysicsInformedKnowledgeDistillation/. A compact student
U-Net is trained to imitate the 358.9M-parameter teacher using a four-term loss:
| Term | Role |
|---|---|
| Supervised | NMSE vs. ground-truth channel |
| Distillation | Temperature-scaled MSE from teacher outputs |
| Physics | RSS power-matching constraint (same as teacher training) |
| Feature | L2 alignment of student vs. teacher cross-attention features |
The student architecture replaces every Conv2d with a depthwise-separable block, and replaces the 340M-parameter FC bottleneck and 5-layer Transformer with a single 8-head cross-attention layer.
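Since the FC bottleneck alone holds roughly 340M of the teacher's 358.9M parameters, removing it drives most of the reduction; the depthwise-separable swap trims the convolutions. For a k×k convolution with C_in inputs and C_out outputs, the weight count drops from k²·C_in·C_out to k²·C_in + C_in·C_out. The channel widths in the example are hypothetical, and norm or activation layers inside the student's blocks are not counted.

```python
def conv_params(c_in: int, c_out: int, k: int, bias: bool = True) -> int:
    """Parameters of a standard k x k Conv2d."""
    return k * k * c_in * c_out + (c_out if bias else 0)


def dw_separable_params(c_in: int, c_out: int, k: int, bias: bool = True) -> int:
    """Depthwise k x k conv (one filter per input channel) followed by a 1x1 pointwise conv."""
    depthwise = k * k * c_in + (c_in if bias else 0)
    pointwise = c_in * c_out + (c_out if bias else 0)
    return depthwise + pointwise


# Hypothetical 256 -> 256 channel 3x3 layer, biases ignored.
standard = conv_params(256, 256, 3, bias=False)
separable = dw_separable_params(256, 256, 3, bias=False)
```

For this layer the count falls from 589,824 to 67,840 weights, about 8.7x fewer.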
Three presets:
| Preset | Params | Compression | Val NMSE |
|---|---|---|---|
| light | 36.0M | 10× | −19.49 dB |
| moderate | 17.8M | 20× | −19.48 dB |
| extreme | 9.3M | 38.7× | −18.76 dB |
(Evaluated on a 790-sample true holdout split that is never seen during training.)
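All results in this README are NMSE in dB. A minimal sketch of the usual definition, NMSE = Σ|Ĥ − H|² / Σ|H|² reported as 10·log10(NMSE), over flattened complex channel entries. Whether the evaluation scripts average per-sample NMSE before converting to dB (as mean_nmse_db does here) or pool errors over the whole split is an assumption; both conventions are common.

```python
import math


def nmse_linear(h_true, h_est):
    """Error energy divided by channel energy for one sample (flattened complex entries)."""
    err = sum(abs(e - t) ** 2 for t, e in zip(h_true, h_est))
    ref = sum(abs(t) ** 2 for t in h_true)
    return err / ref


def nmse_db(h_true, h_est):
    return 10 * math.log10(nmse_linear(h_true, h_est))


def mean_nmse_db(pairs):
    """Average per-sample linear NMSE, then convert to dB (assumed convention)."""
    vals = [nmse_linear(t, e) for t, e in pairs]
    return 10 * math.log10(sum(vals) / len(vals))
```

For reference, −19.49 dB means the residual error energy is about 1.1% of the channel energy (10^(−1.949) ≈ 0.0112).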
First, run the frozen teacher and save its outputs and cross-attention features as FP16 memmaps. This only needs to be done once.
cd experiment2_PhysicsInformedKnowledgeDistillation
python precompute_teacher.py \
--checkpoint ../simple_ls_0_val.pth \
--smomp_file ../PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
--accurate_file ../PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
--user_positions ../PINN_channel-estimation-main/ue_positions_noisy.txt \
--rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
--cache_dir teacher_cache

Then train a student for each preset:

# Light (~36M, 10x)
python train_kd.py --preset light \
--smomp_file ../PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
--accurate_file ../PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
--user_positions ../PINN_channel-estimation-main/ue_positions_noisy.txt \
--rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
--cache_dir teacher_cache --epochs 40 --device mps
# Moderate (~18M, 20x)
python train_kd.py --preset moderate [same flags]
# Extreme (~9.3M, 38.7x)
python train_kd.py --preset extreme [same flags]

Checkpoints are saved to checkpoints/student_{preset}.pth, and the training history is saved to checkpoints/training_history_{preset}.json for plotting.

To evaluate the students against the teacher:
python eval_kd.py \
--checkpoint ../simple_ls_0_val.pth \
--smomp_file ../PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
--accurate_file ../PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
--user_positions ../PINN_channel-estimation-main/ue_positions_noisy.txt \
--rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
--device mps

To evaluate on the 790-sample true holdout split:

# Create 790-sample holdout bundle (run once)
python create_holdout.py --output_dir holdout
# Evaluate on holdout
python eval_holdout.py \
--holdout_dir holdout \
--rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
--checkpoint ../simple_ls_0_val.pth \
--checkpoint_dir checkpoints \
--device mps

On Google Colab, mount Drive and run the same pipeline from there:

from google.colab import drive
drive.mount('/content/drive')
BASE = "/content/drive/MyDrive/ECE228"
# Precompute (once)
!python precompute_teacher.py \
--checkpoint {BASE}/simple_ls_0_val.pth \
--smomp_file {BASE}/PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
--accurate_file {BASE}/PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
--user_positions {BASE}/PINN_channel-estimation-main/ue_positions_noisy.txt \
--rss_image {BASE}/PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
--cache_dir {BASE}/teacher_cache
# Train
!python train_kd.py --preset light \
--smomp_file {BASE}/PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
--accurate_file {BASE}/PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
--user_positions {BASE}/PINN_channel-estimation-main/ue_positions_noisy.txt \
--rss_image {BASE}/PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
--cache_dir {BASE}/teacher_cache \
--save_dir {BASE}/checkpoints \
--epochs 40 --device cuda