codeFafnir/ECE228_ML_for_Physical_Applications

ECE228_ML_for_Physical_Applications

We study the compression of a large physics-informed neural network (PINN) for mmWave channel estimation that fuses pilot-based initial estimates with an RSS map. The baseline (358.9M parameters, 1.4 GB) is unsuitable for edge deployment, so we explore three approaches: mixed-precision quantization-aware training (QAT), post-training quantization with Hadamard rotation (H-GPTQ), and physics-guided knowledge distillation (PG-KD). On a Boston ray-tracing benchmark, these methods yield large memory and speed gains while retaining or improving NMSE.


Steps to reproduce results

Dataset Generation and Baseline PINN Training

First, navigate to the PINN_channel-estimation-main/ folder.

cd PINN_channel-estimation-main

Step 1 — Build ground-truth channel tensors from the ray-tracing CSVs

make_correct_channels.py parses the Wireless InSite ray-tracing data and produces a complex tensor of shape (num_snapshots, D, Nr, Nt).

# Boston, 15 GHz, 400 MHz
python make_correct_channels.py \
    --csv Dataset/15GHz_concatenated_data.csv \
    --out 3D_channel_15GHz_2x2_Pt50.npy \
    --pt 50 --bw 4e8

Step 2 — Generate the initial LS-OFDM channel estimates

For a given (SNR, Np) operating point, init_estimation.py simulates OFDM pilot transmission, performs LS estimation with interpolation, and saves a .npy with the same shape as the Step 1 tensor. We limit ourselves to a pilot spacing of 4 (Np = 256 pilots over 1024 subcarriers) and run five SNR values.

# SNR = 0 dB, Np = 256 (pilot_spacing = 4) — repeat for other SNR values
python init_estimation.py \
    --true-channels 3D_channel_15GHz_2x2_Pt50.npy \
    --output initial_estimate_ls_snr0.npy \
    --snr 0 \
    --n-subcarriers 1024 \
    --pilot-spacing 4
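As a rough illustration of what this step computes, the sketch below performs LS estimation at the pilot subcarriers and linear interpolation in between, for a single antenna pair and without noise. It is a simplified stand-in, not the code in init_estimation.py: the function name `ls_estimate`, the linear interpolation, and the hold-last-value edge handling are all assumptions.

```python
def ls_estimate(rx, tx, n_subcarriers=1024, pilot_spacing=4):
    """Return (pilot indices, per-subcarrier channel estimate)."""
    pilot_idx = list(range(0, n_subcarriers, pilot_spacing))
    # LS estimate at each pilot tone: H_ls[k] = Y[k] / X[k]
    h_pilot = {k: rx[k] / tx[k] for k in pilot_idx}
    h_full = []
    for k in range(n_subcarriers):
        left = (k // pilot_spacing) * pilot_spacing
        right = left + pilot_spacing
        if right not in h_pilot:
            # Past the last pilot: hold its value (assumed edge handling)
            h_full.append(h_pilot[left])
        else:
            # Linear interpolation between the two neighbouring pilots
            w = (k - left) / pilot_spacing
            h_full.append((1 - w) * h_pilot[left] + w * h_pilot[right])
    return pilot_idx, h_full


# Noise-free toy channel that varies linearly across subcarriers
true_h = [(1 + 2j) + k * (0.01 - 0.02j) for k in range(1024)]
tx = [1 + 0j] * 1024
rx = [h * x for h, x in zip(true_h, tx)]
pilot_idx, h_hat = ls_estimate(rx, tx)
print(len(pilot_idx))  # 256 pilots for spacing 4 over 1024 subcarriers
```

With 1024 subcarriers and a pilot spacing of 4, this gives the 256 pilots quoted in the command comment above.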

Step 3 — Train the PINN

Edit the config dict at the bottom of train.py so that smomp_file points at the initial estimate from Step 2 and accurate_file points at the ground-truth tensor from Step 1. Then:

python train.py

This trains for 500 epochs, saves the best-validation-NMSE checkpoint to name_val, and the last-epoch checkpoint to name_train. Run once per SNR, keeping accurate_file the same.
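For reference, NMSE in dB is commonly defined as the squared estimation error normalized by the energy of the true channel. The sketch below uses that standard definition on flat lists of complex values. Whether train.py normalizes per sample or over the whole batch is not stated here, so treat the exact reduction as an assumption.

```python
import math


def nmse_db(h_est, h_true):
    """10*log10( sum|h_est - h_true|^2 / sum|h_true|^2 )."""
    err = sum(abs(e - t) ** 2 for e, t in zip(h_est, h_true))
    ref = sum(abs(t) ** 2 for t in h_true)
    return 10.0 * math.log10(err / ref)


# Toy example: one of two unit-magnitude taps is off by 0.1
print(round(nmse_db([1.1 + 0j, 0 + 1j], [1 + 0j, 0 + 1j]), 2))
```

More negative values are better: −20 dB means the error energy is 1% of the channel energy.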


Experiment 1: QAT — Physics-Shielded Mixed-Precision Quantization Aware Training

All scripts live in experiment3_qat/.

Repository layout expected by train_qat.py

project_root/
├── PINN_channel-estimation-main/
│   ├── initial_estimate_ls_snr0.npy      # generated by Step 2
│   ├── 3D_channel_15GHz_2x2_Pt50.npy    # generated by Step 1
│   ├── ue_positions_noisy.txt
│   ├── simple_ls_0_val.pth               # generated by Step 3
│   ├── find_in_map.py                    # imported by train_qat.py
│   └── Dataset/
│       └── 50_15GHz.jpg
└── experiment3_qat/
    ├── train_qat.py                      # QAT training script
    └── Model.py                          

Method overview

Four physics-critical submodules are kept in FP32 ("physics shield") while the convolutional encoder/decoder is quantized to INT8:

| Submodule | Precision | Reason |
| --- | --- | --- |
| enc1, enc2, enc3 | INT8 (per-channel weights) | Bulk of FLOPs; linear loss path |
| dec1, dec2, dec3 | INT8 (per-tensor weights) | fbgemm requires per-tensor for ConvTranspose2d |
| skip_conv1/2, to_sequence, from_sequence | INT8 | Linear path; low sensitivity |
| rss_encoder | FP32 | Feeds the RSS power-matching term directly; quantization noise propagates quadratically through the physics constraint |
| transformer_decoder | FP32 | Softmax inside attention saturates under INT8 Q/K |
| cross_attention.multihead_attn | FP32 | Same attention-stability reason |
| final_conv | FP32 | Complex output projection; keeps the output manifold intact |

Observer choice: activations use HistogramObserver (2048-bin histogram with percentile-clipped min/max), which gives better calibration than moving-average observers for the heavy-tailed activation distributions produced by skip connections and attention residuals. Weights use PerChannelMinMaxObserver for Conv2d/Linear and MinMaxObserver for ConvTranspose2d (fbgemm constraint).
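To see why per-channel weight scales help where the backend allows them, the sketch below applies a simplified symmetric INT8 fake-quantization to a toy weight matrix whose two output channels differ in magnitude by about 100×. It is plain Python for illustration and does not reproduce PyTorch's observers. The helper names and the toy weights are assumptions.

```python
def symmetric_scale(values, qmax=127):
    """Scale that maps the largest magnitude onto the INT8 limit."""
    return max(abs(v) for v in values) / qmax


def fake_quant(values, scale, qmin=-128, qmax=127):
    """Quantize to the integer grid, clamp, then dequantize."""
    return [max(qmin, min(qmax, round(v / scale))) * scale for v in values]


def max_abs_error(original, quantized):
    return max(abs(o - q) for o, q in zip(original, quantized))


# Two output channels with very different dynamic ranges (toy values)
weights = [[0.02, -0.01, 0.015], [2.0, -1.5, 0.7]]

# Per-tensor: one scale shared by every channel
shared = symmetric_scale([v for row in weights for v in row])
per_tensor_err = max_abs_error(weights[0], fake_quant(weights[0], shared))

# Per-channel: each output channel gets its own scale
own = symmetric_scale(weights[0])
per_channel_err = max_abs_error(weights[0], fake_quant(weights[0], own))

print(per_tensor_err > per_channel_err)
```

With a shared scale, the small-magnitude channel is rounded on a grid sized for the large one, so its relative error is far higher than with its own scale.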

Training uses a three-phase schedule to stabilize quantization ranges:

| Phase | Epochs | Observers | Fake-quant | Effect |
| --- | --- | --- | --- | --- |
| 1: Calibration | 0 → freeze_bn_epoch | ON | ON | Histogram ranges accumulate from training data |
| 2: Range freeze | freeze_bn_epoch → freeze_obs_epoch | OFF | ON | Ranges locked, weights adapt to fixed grid |
| 3: FP32 polish | freeze_obs_epoch → end | OFF | OFF | Final fine-tuning without rounding noise |
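The schedule can be read as a function from epoch to observer and fake-quant state. The sketch below encodes it using the default thresholds of 60 and 80 from the hyperparameter list. The function name `qat_phase` is hypothetical, and treating each threshold epoch as the first epoch of the next phase is an assumption about train_qat.py.

```python
def qat_phase(epoch, freeze_bn_epoch=60, freeze_obs_epoch=80):
    """Return the phase number and the observer / fake-quant switches."""
    if epoch < freeze_bn_epoch:
        # Phase 1: ranges are still being collected
        return {"phase": 1, "observers": True, "fake_quant": True}
    if epoch < freeze_obs_epoch:
        # Phase 2: ranges frozen, weights adapt to the fixed grid
        return {"phase": 2, "observers": False, "fake_quant": True}
    # Phase 3: rounding disabled for the final polish
    return {"phase": 3, "observers": False, "fake_quant": False}


for epoch in (0, 59, 60, 79, 80, 99):
    print(epoch, qat_phase(epoch))
```

In PyTorch, such switches are typically applied with `model.apply(torch.ao.quantization.disable_observer)` and `model.apply(torch.ao.quantization.disable_fake_quant)`; whether train_qat.py does it this way is not confirmed here.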

Running QAT on Colab G4

cd experiment3_qat/
python train_qat.py

The script resolves all data paths automatically (see layout above). No arguments needed.

The script will:

  1. Validate all required input files and raise a clear error if any are missing.
  2. Load simple_ls_0_val.pth as the FP32 pretrained checkpoint.
  3. Insert fake-quantization nodes with the physics-shielded QConfigMapping (HistogramObserver for activations).
  4. Fine-tune for 100 epochs using the three-phase schedule.
  5. Print an in-memory live evaluation (fake-quant on GPU) immediately after training.
  6. Export and save pinn_qat_best_val.pth and pinn_int8.pth to experiment3_qat/.
  7. Print a full side-by-side comparison table of FP32 vs INT8 NMSE, latency, disk size, and parameter footprint.

Configuring for a different SNR

Only two entries in the config dict at the bottom of train_qat.py need to change. Both paths are constructed with os.path.join(_pinn, ...) where _pinn is already resolved, so just update the filenames:

"smomp_file":            os.path.join(_pinn, "initial_estimate_ls_snrn10.npy"),
"pretrained_checkpoint": os.path.join(_pinn, "simple_ls_n10_val.pth"),

All other paths (accurate_file, user_positions_file, rss_image_path) remain the same because they are SNR-independent.
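If you sweep several SNR values, a small helper can build both filenames. The sketch below is hypothetical and not part of train_qat.py. It assumes the naming pattern seen in the two examples above, where a negative SNR is written with an `n` prefix (`snrn10`, `simple_ls_n10_val.pth`).

```python
import os


def snr_tag(snr_db):
    """Assumed convention: -10 -> 'n10', 0 -> '0', 10 -> '10'."""
    return f"n{abs(snr_db)}" if snr_db < 0 else str(snr_db)


def snr_paths(pinn_dir, snr_db):
    """Build the two SNR-dependent config entries (hypothetical helper)."""
    tag = snr_tag(snr_db)
    return {
        "smomp_file": os.path.join(pinn_dir, f"initial_estimate_ls_snr{tag}.npy"),
        "pretrained_checkpoint": os.path.join(pinn_dir, f"simple_ls_{tag}_val.pth"),
    }


print(snr_paths("PINN_channel-estimation-main", -10))
```

Check the generated names against the files actually present before training, since the convention is inferred rather than documented.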

Key hyperparameters

| Parameter | Default | Description |
| --- | --- | --- |
| qat_epochs | 100 | Total fine-tuning epochs |
| qat_lr | 1e-4 | Initial Adam learning rate (cosine decay to 5e-6) |
| freeze_bn_epoch | 60 | Phase 1→2 transition: freeze histogram observer ranges |
| freeze_obs_epoch | 80 | Phase 2→3 transition: disable fake-quant for FP32 polish |
| batch_size | 32 | DataLoader batch size |
| qat_backend | x86 | PyTorch quantized engine (x86 for Colab/T4) |
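The cosine decay of the learning rate from 1e-4 to 5e-6 can be written in closed form. The sketch below uses the standard cosine-annealing expression and assumes that the decay spans all qat_epochs, which the hyperparameter list does not state explicitly.

```python
import math


def cosine_lr(epoch, total_epochs=100, lr_max=1e-4, lr_min=5e-6):
    """lr_min + (lr_max - lr_min) * 0.5 * (1 + cos(pi * epoch / total_epochs))."""
    cos_term = 0.5 * (1.0 + math.cos(math.pi * epoch / total_epochs))
    return lr_min + (lr_max - lr_min) * cos_term


for epoch in (0, 50, 60, 80, 100):
    print(epoch, f"{cosine_lr(epoch):.2e}")
```

Under this assumption, the rate has already dropped to roughly a third of its initial value by the Phase 1→2 transition at epoch 60.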

Outputs

All outputs are written to experiment3_qat/ (same directory as train_qat.py).

| File | Contents |
| --- | --- |
| pinn_qat_best_val.pth | Fake-quant model state dict (best validation NMSE) |
| pinn_int8.pth | Converted INT8 model state dict |
| qat_training_log.txt | Epoch-by-epoch train loss, train NMSE, val NMSE |

Experiment 2: GPTQ Post-Training Quantization

All scripts live in experiment1_turboquant/. Two PTQ variants are implemented:

| Method | Description |
| --- | --- |
| GPTQ | Layer-by-layer second-order quantization using the OBQ formulation. One Hessian in memory at a time (MPS/CPU safe). |
| Hadamard-GPTQ (H-GPTQ) | Applies a Hadamard rotation $W \leftarrow WH^\top$ before quantization to redistribute outliers, then absorbs the rotation into the adjacent layer at save time, so there is zero inference overhead. |

Both target INT8 and INT4 weight precision. Norm layers (GroupNorm, LayerNorm) are skipped. Conv2d inputs are unfolded so that all layer types are treated as linear maps for Hessian accumulation.
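The effect of the Hadamard rotation can be checked on a toy matrix. The sketch below builds an orthonormal Sylvester Hadamard matrix, rotates a weight matrix containing one large outlier, and confirms that applying $H$ on the input side leaves the layer output unchanged because $H^\top H = I$. The matrix sizes and values are illustrative assumptions, and the quantization step itself is omitted.

```python
def hadamard(n):
    """Orthonormal Sylvester Hadamard matrix; n must be a power of two."""
    h = [[1.0]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    scale = n ** -0.5
    return [[v * scale for v in row] for row in h]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


H = hadamard(4)
W = [[8.0, 0.1, -0.1, 0.2],   # row with a single large outlier
     [0.1, 0.2, 0.1, -0.3]]
x = [[1.0], [2.0], [3.0], [4.0]]

W_rot = matmul(W, transpose(H))          # W <- W H^T
y_ref = matmul(W, x)                     # original layer output
y_rot = matmul(W_rot, matmul(H, x))      # rotated weights, counter-rotated input

print(max(abs(v) for v in W[0]), max(abs(v) for v in W_rot[0]))
```

The outlier's energy is spread across the row, which lowers the peak magnitude the quantization grid must cover. In the repository the counter-rotation is absorbed into the adjacent layer rather than applied to the input at run time.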

Running GPTQ

cd experiment1_turboquant

# GPTQ at 8-bit
python quantize_pinn.py \
    --method gptq \
    --bits 8 \
    --checkpoint ../simple_ls_0_val.pth \
    --smomp_file  ../PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
    --accurate_file ../PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
    --user_positions ../PINN_channel-estimation-main/ue_positions_noisy.txt \
    --rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg

# Hadamard-GPTQ at 4-bit
python quantize_pinn.py \
    --method hadamard_gptq \
    --bits 4 \
    --checkpoint ../simple_ls_0_val.pth \
    ...

Evaluating GPTQ

python eval_all.py \
    --checkpoint ../simple_ls_0_val.pth \
    --smomp_file  ../PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
    --accurate_file ../PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
    --user_positions ../PINN_channel-estimation-main/ue_positions_noisy.txt \
    --rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
    --device mps

Prints a comparison table: Teacher FP32 vs GPTQ vs H-GPTQ at each bit-width with NMSE (dB) and compression ratio.


Experiment 3: Physics-Guided Knowledge Distillation (PG-KD)

All scripts live in experiment2_PhysicsInformedKnowledgeDistillation/. A compact student U-Net is trained to imitate the 358.9M-parameter teacher using a four-term loss:

$$\mathcal{L} = \mathcal{L}_\text{NMSE} + \alpha\,\mathcal{L}_\text{KD}^\text{soft} + \beta\,\mathcal{L}_\text{physics} + \gamma\,\mathcal{L}_\text{xattn}$$

| Term | Role |
| --- | --- |
| $\mathcal{L}_\text{NMSE}$ | Supervised NMSE vs ground-truth channel |
| $\mathcal{L}_\text{KD}^\text{soft}$ | Temperature-scaled MSE from teacher outputs |
| $\mathcal{L}_\text{physics}$ | RSS power-matching constraint (same as teacher training) |
| $\mathcal{L}_\text{xattn}$ | L2 alignment of student vs teacher cross-attention features |
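The way the four terms combine can be sketched in plain Python on flat lists. This is a schematic only: the weights `alpha`, `beta`, `gamma` are placeholders, the temperature scaling of the KD term is left out because its exact form is not given here, and the physics term is passed in as a precomputed number rather than derived from the RSS map.

```python
def mse(a, b):
    """Mean squared difference between two equal-length sequences."""
    return sum(abs(x - y) ** 2 for x, y in zip(a, b)) / len(a)


def nmse(est, true):
    """Error energy normalized by the energy of the reference."""
    err = sum(abs(e - t) ** 2 for e, t in zip(est, true))
    return err / sum(abs(t) ** 2 for t in true)


def pgkd_loss(student_out, teacher_out, target, student_feat, teacher_feat,
              physics_term, alpha, beta, gamma):
    l_nmse = nmse(student_out, target)        # supervised term
    l_kd = mse(student_out, teacher_out)      # imitation of teacher outputs
    l_xattn = mse(student_feat, teacher_feat) # cross-attention feature alignment
    return l_nmse + alpha * l_kd + beta * physics_term + gamma * l_xattn


# Toy values; the weights below are illustrative, not the repository's settings
loss = pgkd_loss([1.0, 0.0], [1.0, 0.5], [1.0, 1.0], [0.3, 0.3], [0.3, 0.3],
                 physics_term=0.2, alpha=0.5, beta=2.0, gamma=1.0)
print(loss)
```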

The student architecture replaces every Conv2d with a depthwise-separable block, and replaces the teacher's 340M-parameter FC bottleneck and 5-layer Transformer with a single 8-head cross-attention layer.
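The parameter saving from a depthwise-separable block follows from simple counting: a standard k×k convolution has k²·C_in·C_out weights, while the depthwise plus pointwise pair has k²·C_in + C_in·C_out. The sketch below evaluates both for one layer. The channel counts are arbitrary examples, not the student's actual widths.

```python
def conv2d_params(c_in, c_out, k, bias=True):
    """Standard convolution: one k x k filter per (input, output) channel pair."""
    return k * k * c_in * c_out + (c_out if bias else 0)


def depthwise_separable_params(c_in, c_out, k, bias=True):
    """Depthwise k x k conv per input channel, followed by a 1 x 1 pointwise conv."""
    depthwise = k * k * c_in + (c_in if bias else 0)
    pointwise = c_in * c_out + (c_out if bias else 0)
    return depthwise + pointwise


standard = conv2d_params(64, 128, 3)
separable = depthwise_separable_params(64, 128, 3)
print(standard, separable, round(standard / separable, 2))
```

For a 3×3 layer going from 64 to 128 channels this is roughly an 8× reduction, and the saving grows with kernel size and output width.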

Three presets:

| Preset | Params | Compression | Val NMSE |
| --- | --- | --- | --- |
| light | 36.0M | 10× | −19.49 dB |
| moderate | 17.8M | 20× | −19.48 dB |
| extreme | 9.3M | 38.7× | −18.76 dB |

(Evaluated on 790-sample true holdout split, never seen during training.)
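The compression column can be reproduced from the parameter counts. The sketch below divides the teacher's 358.9M parameters by each preset's size and also gives an approximate FP32 footprint, assuming 4 bytes per parameter and no additional buffers.

```python
TEACHER_PARAMS_M = 358.9  # teacher size in millions of parameters


def compression_ratio(student_params_m, teacher_params_m=TEACHER_PARAMS_M):
    return teacher_params_m / student_params_m


def fp32_size_mb(params_m):
    """4 bytes per FP32 parameter; millions of parameters -> MB (10^6 bytes)."""
    return params_m * 4.0


for name, params_m in {"light": 36.0, "moderate": 17.8, "extreme": 9.3}.items():
    print(f"{name}: {compression_ratio(params_m):.1f}x, "
          f"~{fp32_size_mb(params_m):.0f} MB in FP32")
```

The rounded 9.3M figure gives about 38.6×, so the reported 38.7× presumably comes from the unrounded parameter count.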

Step 1 — Precompute teacher cache

Run the frozen teacher once and save outputs + cross-attention features as FP16 memmaps. This only needs to be done once.

cd experiment2_PhysicsInformedKnowledgeDistillation

python precompute_teacher.py \
    --checkpoint ../simple_ls_0_val.pth \
    --smomp_file  ../PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
    --accurate_file ../PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
    --user_positions ../PINN_channel-estimation-main/ue_positions_noisy.txt \
    --rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
    --cache_dir teacher_cache
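The idea behind an FP16 memory-mapped cache is that each value takes 2 bytes on disk and individual entries can be read without loading the whole file. The sketch below shows this with the Python standard library only. It is a stand-in for illustration: precompute_teacher.py presumably uses NumPy memmaps, and the file name and helper names here are hypothetical.

```python
import mmap
import os
import struct
import tempfile


def write_fp16(path, values):
    """Store floats as little-endian IEEE half precision (2 bytes each)."""
    with open(path, "wb") as f:
        f.write(struct.pack(f"<{len(values)}e", *values))


def read_fp16(path, index):
    """Read one element through a memory map without loading the whole file."""
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            return struct.unpack_from("<e", mm, 2 * index)[0]


# Hypothetical cache file holding three teacher outputs
cache_path = os.path.join(tempfile.mkdtemp(), "teacher_out.f16")
write_fp16(cache_path, [0.5, 1.0, -2.25])
print(os.path.getsize(cache_path), read_fp16(cache_path, 2))
```

Half precision halves the cache size relative to FP32 at the cost of about three significant decimal digits, which is usually acceptable for distillation targets.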

Step 2 — Train student presets (in order)

# Light (~36M, 10x)
python train_kd.py --preset light \
    --smomp_file  ../PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
    --accurate_file ../PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
    --user_positions ../PINN_channel-estimation-main/ue_positions_noisy.txt \
    --rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
    --cache_dir teacher_cache --epochs 40 --device mps

# Moderate (~18M, 20x)
python train_kd.py --preset moderate  [same flags]

# Extreme (~9.3M, 38.7x)
python train_kd.py --preset extreme   [same flags]

Checkpoints are saved to checkpoints/student_{preset}.pth. Training history is saved to checkpoints/training_history_{preset}.json for plotting.

Step 3 — Evaluate on validation split

python eval_kd.py \
    --checkpoint ../simple_ls_0_val.pth \
    --smomp_file  ../PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
    --accurate_file ../PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
    --user_positions ../PINN_channel-estimation-main/ue_positions_noisy.txt \
    --rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
    --device mps

Step 4 — Create and evaluate on the true holdout set

# Create 790-sample holdout bundle (run once)
python create_holdout.py --output_dir holdout

# Evaluate on holdout
python eval_holdout.py \
    --holdout_dir holdout \
    --rss_image ../PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
    --checkpoint ../simple_ls_0_val.pth \
    --checkpoint_dir checkpoints \
    --device mps

Running on Google Colab

from google.colab import drive
drive.mount('/content/drive')

BASE = "/content/drive/MyDrive/ECE228"

# Precompute (once)
!python precompute_teacher.py \
    --checkpoint {BASE}/simple_ls_0_val.pth \
    --smomp_file  {BASE}/PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
    --accurate_file {BASE}/PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
    --user_positions {BASE}/PINN_channel-estimation-main/ue_positions_noisy.txt \
    --rss_image {BASE}/PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
    --cache_dir {BASE}/teacher_cache

# Train
!python train_kd.py --preset light \
    --smomp_file  {BASE}/PINN_channel-estimation-main/initial_estimate_ls_snr0.npy \
    --accurate_file {BASE}/PINN_channel-estimation-main/3D_channel_15GHz_2x2_Pt50.npy \
    --user_positions {BASE}/PINN_channel-estimation-main/ue_positions_noisy.txt \
    --rss_image {BASE}/PINN_channel-estimation-main/Dataset/50_15GHz.jpg \
    --cache_dir {BASE}/teacher_cache \
    --save_dir {BASE}/checkpoints \
    --epochs 40 --device cuda
