ALEX8642/wafer-defect-classifier

Wafer Defect Classifier

ResNet-18 trained on the public WM-811K wafer map dataset. 9-class spatial defect classification with calibrated confidence, Grad-CAM interpretability, and a one-click Gradio demo.

Macro-F1 0.92 · Balanced accuracy 0.91 · ECE 0.0031 (after temperature scaling)

CE baseline with TTA + thresholds: macro-F1 0.90. Retraining with focal loss and CBAM channel-and-spatial attention adds +2pp macro-F1, with the largest gains on the tail classes (F1: Scratch +4pp, Loc +5pp, Random +4pp).


Results

Focal loss + CBAM + TTA + per-class thresholds

Class        Precision  Recall  F1
Edge-Ring    0.99       0.99    0.99
none         0.99       1.00    0.99
Center       0.97       0.93    0.95
Near-full    0.97       0.93    0.95
Random       0.91       0.91    0.91
Edge-Loc     0.90       0.89    0.89
Scratch      0.84       0.87    0.86
Donut        0.86       0.86    0.86
Loc          0.90       0.79    0.84
Macro avg    0.92       0.91    0.92

CE baseline (class-weighted CE + TTA + per-class τ)

Class        Precision  Recall  F1
Edge-Ring    0.98       0.98    0.98
none         0.99       0.99    0.99
Center       0.95       0.94    0.95
Near-full    0.93       0.93    0.93
Random       0.81       0.94    0.87
Edge-Loc     0.84       0.89    0.87
Scratch      0.73       0.92    0.82
Donut        0.83       0.90    0.86
Loc          0.75       0.84    0.79
Macro avg    0.87       0.93    0.90

Plain accuracy (0.98) is deliberately not a headline metric: a constant "none" predictor scores 0.85 while catching zero defects. Macro-F1 and balanced accuracy are the right metrics when 85% of the samples are "none".
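To see the gap concretely, the sketch below scores a constant "none" predictor on a synthetic 100-map set with the same 85% "none" share. The class index (8 = "none") and the synthetic labels are illustrative assumptions, not the repository's encoding or data.

```python
import numpy as np

NONE = 8  # hypothetical index of the "none" class; defect classes are 0..7


def balanced_accuracy(y_true, y_pred):
    """Mean per-class recall over the classes present in y_true."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))


def macro_f1(y_true, y_pred, n_classes):
    """Unweighted mean of per-class F1 = 2TP / (2TP + FP + FN); classes with no TP score 0."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    f1s = []
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom > 0 else 0.0)
    return float(np.mean(f1s))


# Synthetic test set: 85 "none" maps plus 15 defects spread over the 8 defect classes.
y_true = np.concatenate([np.full(85, NONE), np.arange(15) % 8])
y_pred = np.full_like(y_true, NONE)  # constant "none" predictor

accuracy = float(np.mean(y_true == y_pred))
bal = balanced_accuracy(y_true, y_pred)
mf1 = macro_f1(y_true, y_pred, n_classes=9)
print(f"accuracy {accuracy:.4f}, balanced accuracy {bal:.4f}, macro-F1 {mf1:.4f}")
# accuracy 0.8500, balanced accuracy 0.1111, macro-F1 0.1021
```

Accuracy rewards the majority class; balanced accuracy collapses to 1/9 and macro-F1 to roughly a ninth of the "none" F1, which is why the README reports those two instead.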

Improvement stack: Single-pass baseline ~0.87 → CE + TTA + thresholds 0.90 → focal loss + CBAM + TTA + thresholds 0.92. Each layer is documented in docs/IMPROVEMENTS.md. Scratch F1 0.69 → 0.82 → 0.86; Loc F1 0.74 → 0.79 → 0.84.
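The README does not spell out which test-time augmentations it averages over. As a sketch, assuming TTA means averaging softmax outputs over the eight flips and 90° rotations of the map (a natural choice, since a Scratch or an Edge-Ring keeps its class under these transforms), the mechanics look like this. `predict_proba` is a hypothetical stand-in for the trained model.

```python
import numpy as np


def dihedral_views(x):
    """Yield the 8 rotations and mirror images of a (C, H, W) wafer tensor."""
    for k in range(4):
        rotated = np.rot90(x, k=k, axes=(1, 2))
        yield rotated
        yield np.flip(rotated, axis=2)  # left-right mirror of each rotation


def tta_predict(predict_proba, x):
    """Average class probabilities over all 8 views.

    predict_proba is a hypothetical callable mapping a (C, H, W) array to a
    (K,) probability vector; in practice it would wrap the trained model.
    """
    probs = [predict_proba(np.ascontiguousarray(v)) for v in dihedral_views(x)]
    return np.mean(probs, axis=0)


# Toy check: a "model" that only reports how much mass sits in the top half.
def top_half_score(v):
    top = v[:, : v.shape[1] // 2, :].sum() / v.sum()
    return np.array([top, 1.0 - top])


x = np.zeros((1, 4, 4))
x[0, 0, 0] = 1.0  # single failing die in the top-left corner
print(tta_predict(top_half_score, x))  # [0.5 0.5]
```

The toy "model" is orientation-sensitive, so its single-view answer is [1, 0]; averaging over the eight views removes that bias, which is the kind of variance TTA is meant to smooth out.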


Grad-CAM spatial interpretability

Does the model key on the physically meaningful region? Three examples:

[Figure: Grad-CAM overlays for Scratch (99.99% confidence), Edge-Ring (100%), and Center (54%)]

Scratch: activation tightly follows the linear/arc streak — the model has learned the mechanical-damage spatial signature.

Edge-Ring: activation concentrates on the interior passing-die zone (the boundary between the intact die region and the failing ring). The model has learned "Edge-Ring = large passing interior bounded by a failing perimeter" — an inverted but valid representation.

Center: correct localisation to the lower-center cluster at 54% confidence, reflecting genuine ambiguity between Center and Loc.
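For reference, the Grad-CAM computation (Selvaraju et al., 2017) reduces to a few array operations once the target layer's activations and the class-score gradients have been captured, which explain.py does with hooks. The sketch below assumes those two arrays are already in hand and omits the final upsampling to 224×224 for the overlay; it follows the textbook formulation rather than reproducing explain.py.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Grad-CAM heat map from one target layer.

    activations: (K, H, W) feature maps A^k of the target layer for one input.
    gradients:   (K, H, W) gradients d(score_c)/dA^k for the class c being explained.
    Returns an (H, W) map scaled to [0, 1].
    """
    weights = gradients.mean(axis=(1, 2))             # alpha_k: spatially averaged gradients
    cam = np.tensordot(weights, activations, axes=1)  # sum_k alpha_k * A^k
    cam = np.maximum(cam, 0.0)                        # ReLU keeps only evidence *for* class c
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

A channel whose gradients are negative pushes the score away from class c; the ReLU zeroes out regions dominated by such channels, so the heat map highlights only the supporting evidence (the streak for Scratch, for example).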


Calibration & operating point

[Figures: confusion matrix, reliability diagram, threshold sensitivity]

Confusion matrix: errors concentrate in the expected tail-class confusions (Loc↔Center, Scratch↔Edge-Loc) rather than leaking into "none" — escapes stay rare.

Reliability diagram: after temperature scaling (T=0.6685) the curve tracks the diagonal closely. ECE is 0.0031, so the confidence is calibrated well enough to threshold on.

Threshold sensitivity: cost-weighted error across τ ∈ [0.05, 0.99] at a 10:1 escape/false-alarm ratio, locating the operating point used for per-class thresholds.
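The README does not state the exact decision rule behind τ. One plausible reading, used here purely as an assumption, is: flag a map as defective when its highest defect-class probability reaches τ, otherwise call it "none". Under that hypothetical rule, a sweep at the 10:1 cost ratio could look like this:

```python
import numpy as np


def sweep_thresholds(probs, y_true, none_idx, taus, escape_cost=10.0, fa_cost=1.0):
    """Cost-weighted error per tau under a hypothetical 'flag if max defect prob >= tau' rule.

    probs: (N, K) calibrated class probabilities; y_true: (N,) integer labels.
    Returns (array of costs, one per tau; tau with the lowest cost).
    """
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true)
    taus = np.asarray(taus, dtype=float)
    defect_probs = np.delete(probs, none_idx, axis=1)  # drop the "none" column
    max_defect_p = defect_probs.max(axis=1)
    is_defect = y_true != none_idx
    costs = []
    for tau in taus:
        flagged = max_defect_p >= tau
        escapes = np.sum(is_defect & ~flagged)       # real defect called "none"
        false_alarms = np.sum(~is_defect & flagged)  # "none" map flagged as a defect
        costs.append((escape_cost * escapes + fa_cost * false_alarms) / len(y_true))
    costs = np.array(costs)
    return costs, float(taus[np.argmin(costs)])
```

To match the plotted range, `taus` would be something like `np.linspace(0.05, 0.99, 95)`. Confusions between two defect classes are neither escapes nor false alarms under the definitions used here, so they do not enter this cost.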


Demo

pip install -r requirements.txt
pip install -e .

# Place LSWMD.pkl in data/raw/ (download from Kaggle: wafer-map-dataset)
# then train — config defaults reproduce the headline focal+CBAM recipe (~15 min on a 5090):
python -m wafer.train

# Run the Gradio demo:
python -m wafer.demo
# → http://localhost:7860

The demo loads 9 test-set examples (one per class) at startup. Click any example to see the predicted class, calibrated confidence, process-mode interpretation, and Grad-CAM overlay in one view.


Approach

Dataset: WM-811K (Wu et al., 2015) — 811k wafer maps, 172k labeled across 9 failure-pattern classes. Binned maps (0=outside, 1=pass, 2=fail). No optical/SEM imagery — spatial defect classification only.

Preprocessing: One-hot encode {0,1,2} into 3 channels (preserves discrete semantics; scalar normalisation would imply "fail = 2× pass"). Nearest-neighbour resize to 224×224 (preserves binary values; bilinear would create intermediate values that don't exist in the domain).
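A minimal NumPy sketch of this preprocessing, assuming the raw map arrives as a 2-D integer array with values in {0, 1, 2}; the function name is illustrative and not taken from data.py. Because nearest-neighbour resizing only copies existing pixels, the order of the one-hot and resize steps does not change the result; the sketch resizes the integer map first, which is cheaper.

```python
import numpy as np


def preprocess_wafer_map(wafer, size=224):
    """Binned wafer map (0=outside, 1=pass, 2=fail) -> (3, size, size) float32 one-hot tensor."""
    wafer = np.asarray(wafer)
    h, w = wafer.shape
    # Nearest-neighbour resize by integer index lookup: every output pixel copies an
    # existing die value, so no fractional "half-failing" values can appear.
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    resized = wafer[rows[:, None], cols[None, :]]
    # One-hot: channel k is 1.0 wherever the die code equals k, 0.0 elsewhere.
    codes = np.arange(3)[:, None, None]
    return (resized[None, :, :] == codes).astype(np.float32)
```

For any input size the output is a (3, 224, 224) array whose three channels sum to exactly 1 at every pixel.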

Architecture: ResNet-18 from scratch — research consistently shows ResNet-18 matches or outperforms ResNet-50 on this task, attributed to the relative simplicity of binned spatial patterns vs. natural images. Three-channel first conv reused unchanged; head replaced with a 9-class linear layer. The headline model appends CBAM channel + spatial attention after each ResNet stage (+43.9k parameters, ~0.4% overhead); the gains land mostly on the tail classes (Scratch, Loc, Random).
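The quoted overhead can be checked with back-of-envelope arithmetic. Assuming the standard CBAM layout (a shared two-layer channel MLP with reduction ratio r = 16 and a 7×7 spatial convolution over the stacked average- and max-pooled maps, both without biases) inserted after ResNet-18's four stages of 64, 128, 256 and 512 channels, the count lands on the stated +43.9k. The configuration is an assumption that happens to reproduce the README's figure, not something read from model.py.

```python
def cbam_params(channels, reduction=16, spatial_kernel=7):
    """Parameters added by one bias-free CBAM block per stage (assumed configuration)."""
    total = 0
    for c in channels:
        hidden = c // reduction
        total += 2 * c * hidden                       # shared MLP: c -> c/r -> c
        total += 2 * spatial_kernel * spatial_kernel  # spatial conv: 2 pooled maps -> 1 map
    return total


STAGE_CHANNELS = (64, 128, 256, 512)                # ResNet-18 stage widths
RESNET18_IMAGENET = 11_689_512                      # torchvision resnet18 with its 1000-class head
backbone = RESNET18_IMAGENET - (512 * 1000 + 1000)  # strip the 1000-class fc layer
base = backbone + (512 * 9 + 9)                     # add the 9-class linear head

extra = cbam_params(STAGE_CHANNELS)
print(extra, f"{100 * extra / base:.2f}%")          # 43912 0.39%
```

Almost all of the extra parameters sit in the 512-channel stage's MLP (32,768 of 43,912), so the attention cost scales with the square of the widest stage.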

Imbalance (85% "none"): Focal loss (γ=2, no class weights). The modulating factor down-weights the easy, dominant "none" class and focuses learning on hard tail-class examples. Combining focal loss with class weights double-penalizes rare classes and destabilizes training (negative result documented in docs/IMPROVEMENTS.md). The CE baseline instead uses class-weighted cross-entropy, with weights from scikit-learn's sklearn.utils.class_weight.compute_class_weight using class_weight='balanced'.
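To make the mechanism concrete, here is a NumPy sketch of the focal loss next to the 'balanced' class-weight formula that scikit-learn documents (n_samples / (n_classes · count_c)). Function names are illustrative and the repository's implementation may differ. The toy logits show how the (1 − p_t)^γ factor shrinks the loss of a confidently correct example far more than that of a misclassified one.

```python
import numpy as np


def log_softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # shift for numerical stability
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def focal_loss(logits, targets, gamma=2.0):
    """Mean focal loss -(1 - p_t)^gamma * log(p_t); gamma = 0 recovers cross-entropy."""
    logp = log_softmax(np.asarray(logits, dtype=float))
    logp_t = logp[np.arange(len(targets)), targets]
    p_t = np.exp(logp_t)
    return float(np.mean(-((1.0 - p_t) ** gamma) * logp_t))


def balanced_class_weights(y, n_classes):
    """n_samples / (n_classes * count_c), the 'balanced' heuristic; assumes every class occurs."""
    counts = np.bincount(y, minlength=n_classes)
    return len(y) / (n_classes * counts)


easy = np.array([[5.0, 0.0, 0.0]])  # confidently correct, like a typical "none" map
hard = np.array([[0.0, 2.0, 0.0]])  # wrong class favoured, like a hard tail-class map
target = np.array([0])
for name, logits in (("easy", easy), ("hard", hard)):
    ce = focal_loss(logits, target, gamma=0.0)
    fl = focal_loss(logits, target, gamma=2.0)
    print(f"{name}: CE {ce:.4f}, focal {fl:.6f}, focal/CE {fl / ce:.4f}")
```

The easy example keeps well under 0.1% of its cross-entropy loss while the hard one keeps about 80%. Multiplying by balanced class weights on top of that would amplify the tail classes a second time, which is the double penalty described above.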

Calibration: Temperature scaling (T=0.6685, fit on val set via LBFGS). Reduces ECE (0.0164 → 0.0031). T < 1 means the focal-trained model is underconfident — focal's down-weighting of easy examples suppresses peak confidence, so calibration sharpens rather than softens. (The CE baseline showed the usual from-scratch pattern instead: mildly overconfident, T=1.13.)
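A NumPy sketch of the two calibration pieces: equal-width-bin ECE (15 bins is an assumption; the README does not give its bin count) and temperature scaling with validation NLL as the objective. The README fits T with LBFGS; to stay dependency-free this sketch swaps in a plain grid search over T, which minimises the same assumed objective but is not the repository's optimiser. The synthetic logits are deliberately shrunk by half to mimic an underconfident model, so the fitted T should come out below 1, as it did for the focal-trained model.

```python
import numpy as np


def softmax(logits, temperature=1.0):
    z = np.asarray(logits, dtype=float) / temperature
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def expected_calibration_error(probs, labels, n_bins=15):
    """Bin by top-1 confidence; ECE = sum_b (|B_b| / N) * |accuracy(B_b) - confidence(B_b)|."""
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return float(ece)


def fit_temperature(logits, labels, grid=np.linspace(0.05, 5.0, 496)):
    """Grid-search the T that minimises mean NLL of softmax(logits / T) on held-out data."""
    rows = np.arange(len(labels))
    nll = [-np.mean(np.log(softmax(logits, t)[rows, labels] + 1e-12)) for t in grid]
    return float(grid[int(np.argmin(nll))])


# Synthetic check: labels drawn from a calibrated model, logits then shrunk 2x (underconfident).
rng = np.random.default_rng(0)
true_logits = 3.0 * rng.normal(size=(5000, 9))
cum = softmax(true_logits).cumsum(axis=1)
labels = np.minimum((cum < rng.random(5000)[:, None]).sum(axis=1), 8)
val_logits = 0.5 * true_logits

T = fit_temperature(val_logits, labels)
print(f"T = {T:.3f}")
print(f"ECE before {expected_calibration_error(softmax(val_logits), labels):.4f}, "
      f"after {expected_calibration_error(softmax(val_logits, T), labels):.4f}")
```

Dividing logits by T < 1 stretches them apart and sharpens the softmax, which is why an underconfident model is corrected by a temperature below 1 rather than above it.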

Cost-of-quality framing: Two error types carry different operational costs: an escape (defect predicted as none) and a false alarm (none predicted as defect). The focal+CBAM model with TTA + per-class thresholds sits at 275 escapes / 137 false alarms, a cost-weighted error of 0.0835 at a 10:1 escape/FA cost ratio. The CE baseline occupied a different operating point: 54 escapes / 990 false alarms (0.0442). In other words, focal loss bought +2pp macro-F1 by accepting more escapes in exchange for far fewer false alarms, and at 10:1 the CE baseline's operating point actually has the lower cost-weighted error. Which point is better depends on the fab's true cost ratio, and per-class thresholds are the lever for moving along that curve. The threshold-sensitivity plot shows the sweep across τ ∈ [0.05, 0.99].
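Both cost figures are consistent with a cost-weighted error of (10 · escapes + false alarms) / N, where N is the test-set size. The sketch assumes N = 34,590, i.e. 20% of WM-811K's 172,950 labeled maps; that value is inferred, not stated in the README. The helper that counts escapes and false alarms, and its "none" index, are likewise illustrative.

```python
import numpy as np


def escapes_and_false_alarms(y_true, y_pred, none_idx):
    """Escape: true defect predicted as "none". False alarm: true "none" predicted as a defect."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    escapes = int(np.sum((y_true != none_idx) & (y_pred == none_idx)))
    false_alarms = int(np.sum((y_true == none_idx) & (y_pred != none_idx)))
    return escapes, false_alarms


def cost_weighted_error(escapes, false_alarms, n, escape_cost=10.0, fa_cost=1.0):
    return (escape_cost * escapes + fa_cost * false_alarms) / n


N_TEST = 34_590  # assumed: 20% of 172,950 labeled maps
print(round(cost_weighted_error(275, 137, N_TEST), 4))  # 0.0835  (focal + CBAM)
print(round(cost_weighted_error(54, 990, N_TEST), 4))   # 0.0442  (CE baseline)
```

At 10:1 the baseline's 54 escapes cost 540 units against the focal model's 2,750, which outweighs the focal model's saving of 853 false alarms. The break-even ratio is where 221 extra escapes cost as much as 853 fewer false alarms, roughly 3.9:1.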


Repository layout

src/wafer/
  config.py      — WaferConfig dataclass, YAML + CLI merge
  data.py        — WM-811K loading, 70/10/20 stratified split, DataLoaders
  model.py       — ResNet-18 builder, optional CBAM attention
  train.py       — AdamW + CosineAnnealingLR + early stopping on val macro-F1
  evaluate.py    — Test-set metrics, per-class breakdown, confusion matrix
  calibrate.py   — Temperature scaling, ECE, reliability diagram, cost analysis
  explain.py     — Grad-CAM / Grad-CAM++ (hook-based, no extra deps), overlay figures
  demo.py        — Gradio demo
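The 70/10/20 stratified split in data.py can be sketched in NumPy as a per-class shuffle and cut, so every class keeps the same proportions in train, val and test. This is an illustration under that assumption, not the repository's code; the default seed=42 mirrors the seed noted under Limitations.

```python
import numpy as np


def stratified_split(labels, fractions=(0.7, 0.1, 0.2), seed=42):
    """Return (train, val, test) index arrays with per-class proportions preserved."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    parts = ([], [], [])
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))  # shuffle this class's indices
        n_train = int(round(fractions[0] * len(idx)))
        n_val = int(round(fractions[1] * len(idx)))
        parts[0].append(idx[:n_train])
        parts[1].append(idx[n_train:n_train + n_val])
        parts[2].append(idx[n_train + n_val:])  # remainder goes to test
    return tuple(np.sort(np.concatenate(p)) for p in parts)
```

Cutting each class separately matters most for the smallest classes: a purely random 20% cut could leave a rare class like Near-full badly under- or over-represented in test.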

docs/
  ANALYSIS.md         — full narrative for technical audience
  process_modes.md    — 9-class defect → process failure mode table
  phase1_results.md   — Phase 1 test metrics

configs/baseline.yaml — training hyperparameters

Limitations

  • Binned maps only: 0/1/2 per die, not pixel-level inspection images. Defect boundaries are at die pitch resolution.
  • No fab ground truth: process-mode interpretations are illustrative QE reasoning from spatial geometry, not cause-verified claims.
  • Near-full sample size: 30 test samples; precision/recall carry wide CIs.
  • Single seed: results reflect seed=42. Macro-F1 typically varies by ±0.01–0.02 across seeds.
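To put a rough number on "wide CIs": a 95% Wilson score interval for a recall measured on 30 maps. The count of 28 correct (0.93, matching the Near-full recall in the results table) is an assumption used for illustration.

```python
import math


def wilson_interval(successes, n, z=1.96):
    """Wilson score interval for a binomial proportion (z = 1.96 for ~95% coverage)."""
    p = successes / n
    denom = 1 + z**2 / n
    centre = (p + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return centre - half, centre + half


lo, hi = wilson_interval(28, 30)
print(f"recall 28/30 = {28 / 30:.2f}, 95% Wilson CI [{lo:.2f}, {hi:.2f}]")
# recall 28/30 = 0.93, 95% Wilson CI [0.79, 0.98]
```

An interval about 19 points wide means the Near-full row cannot distinguish between the two models in the results tables.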

Extensions

Two follow-up repos build on this pipeline:

wafer-ssl pushes it further on WM-811K: SimCLR contrastive pretraining on the 638k unlabeled maps plus a 4-model ensemble reaches test macro-F1 0.9423 (balanced accuracy 0.9427). An ablation against a from-scratch ensemble (0.9339) attributes ~+0.8pp to the self-supervised pretraining and the remainder to variance reduction from ensembling — which also quantifies the single-seed ±0.01–0.02 variance noted above.

wafer-mixed tests whether the pipeline generalizes: the same ResNet-18+CBAM, reframed as 8-way multi-label classification of mixed (superposed) defect patterns on MixedWM38, reaches test macro-F1 0.9846. A transfer study using this repo's checkpoint as donor shows pretraining pays only in the low-data regime (+8.8 macro-F1 points at 1% of training data, a wash at ≥10%), and per-label thresholds at the same 10:1 cost ratio used here cut label-level escapes by 36%.


References

Wu, M.-J., Jang, J.-S. R., Chen, J.-L. (2015). Wafer Map Failure Pattern Recognition and Similarity Ranking for Large-Scale Data Sets. IEEE Trans. Semiconductor Manufacturing, 28(1), 1–12.

Selvaraju, R. R., et al. (2017). Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization. ICCV 2017.
