Bayesian uncertainty estimation for multilabel chest-radiograph classification
on the TAIX-Ray cohort. The classifier is a DINOv2-ViT-S backbone with the
transformer's dropout layers replaced by Monte-Carlo dropout (kept active at
inference) so that the predictive variance over N stochastic forward passes
quantifies epistemic uncertainty for each finding.
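To make the variance computation concrete, here is a minimal NumPy sketch of how N stochastic passes become a per-finding mean and variance. It is not the MST model: a single hypothetical linear layer (mc_dropout_predict with toy weights) stands in for the DINOv2 backbone and head, with inverted dropout applied to its input on every pass. Only the aggregation step (mean and variance over passes) reflects the method described above.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def mc_dropout_predict(x, weights, bias, p=0.05, num_samples=30, seed=0):
    """Toy MC dropout: one linear layer (hypothetical stand-in for backbone + head)
    with inverted dropout on its input, kept active at inference. Returns the
    per-finding mean and variance of the sigmoid probabilities over the passes."""
    rng = np.random.default_rng(seed)
    probs = []
    for _ in range(num_samples):
        keep = rng.random(x.shape) >= p                 # fresh Bernoulli mask per pass
        x_drop = np.where(keep, x / (1.0 - p), 0.0)     # inverted-dropout scaling
        probs.append(sigmoid(x_drop @ weights + bias))
    probs = np.stack(probs)                             # (num_samples, batch, findings)
    return probs.mean(axis=0), probs.var(axis=0)


# Toy usage: 2 inputs with 16 features, 8 findings, N = 30 passes.
rng = np.random.default_rng(1)
mean, var = mc_dropout_predict(np.ones((2, 16)), rng.normal(size=(16, 8)), np.zeros(8))
```

With p=0 every pass is identical and the variance collapses to zero, which is why the spread across passes is read as model (epistemic) uncertainty.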
BayesianCXRAgent/
├── cxr/ # Installable package (model + data code)
│ ├── data/
│ │ ├── datasets/cxr_dataset.py # TAIX-Ray loader (PNG + CSV metadata)
│ │ └── datamodules/datamodule.py # PyTorch-Lightning DataModule
│ └── models/
│   ├── base_model.py # BasicClassifier / BasicRegression
│   ├── mst.py # DINOv2-ViT-S + MCDropout (`MST`, `MSTRegression`)
│   ├── resnet.py # ResNet baseline
│   └── utils/losses.py # CORN / CE multi-head losses
├── scripts/
│ ├── main_train.py # Train the classifier (single run or subset sweep)
│ ├── uncertainty_analysis.py # MC-dropout inference -> predictions+variance .npz
│ ├── uncertainty_plots.py # Pure plotting helpers (no torch dependency)
│ ├── plot_uncertainty_analysis.py # Regenerate plots from the cached .npz
│ ├── analyze_subset_results.py # `uncertainty_by_class_and_subset.png`
│ └── plot_subset_losses.py # `loss_and_variance_curves.png` (pulls from W&B)
├── data/TAIX-Ray/ # Drop the dataset here (expected layout described below)
├── pyproject.toml
└── requirements.txt
python -m venv .venv && source .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt
pip install -e .
The DINOv2 backbone is pulled from torch.hub on first run, so the machine that
first runs training or inference needs internet access (torch.hub caches it afterwards).
The code expects the TAIX-Ray release at data/TAIX-Ray/ (relative to the
repo root) with this layout:
data/TAIX-Ray/
├── metadata/
│ ├── split.csv # columns: UID, Fold, Split
│ └── annotation.csv # columns: UID, PatientID, ..., HeartSize, PulmonaryCongestion, PleuralEffusion_Left, ...
└── data_png_resize_512/
  └── <UID>.png # grayscale PNG, long edge resized to 512 px
The dataset is publicly available on Hugging Face at TLAIM/TAIX-Ray. The
included loader assumes the default (512 px) configuration. If you keep the
data elsewhere, point to it with the environment variable:
export TAIX_RAY_PATH=/path/to/TAIX-Ray
Each finding is annotated on a 0-4 ordinal scale (HeartSize: 0-3). For the
binary training objective the loader treats every grade > 0 as positive; the
raw grades stay unchanged on disk and are re-attached to the predictions at
analysis time, so that annotator-uncertain cases (grade 1) can be inspected
separately.
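A minimal sketch of the two loader conventions described above, the TAIX_RAY_PATH override and the > 0 binarisation, using only the standard library. The function names are hypothetical and this is not the loader in cxr_dataset.py. Only HeartSize, PulmonaryCongestion and PleuralEffusion_Left are confirmed column names; the other five are assumed to follow the same _Left/_Right pattern.

```python
import csv
import io
import os
from pathlib import Path

# HeartSize, PulmonaryCongestion and PleuralEffusion_Left appear in annotation.csv;
# the other five names are ASSUMED to follow the same _Left/_Right pattern.
FINDINGS = [
    "HeartSize", "PulmonaryCongestion",
    "PleuralEffusion_Left", "PleuralEffusion_Right",
    "PulmonaryOpacities_Left", "PulmonaryOpacities_Right",
    "Atelectasis_Left", "Atelectasis_Right",
]


def resolve_data_root(default="data/TAIX-Ray"):
    """TAIX_RAY_PATH, if set, overrides the repo-relative default."""
    return Path(os.environ.get("TAIX_RAY_PATH", default))


def binarize_annotations(fh):
    """Read annotation rows from a file-like object and return
    {UID: (raw_grades, binary_labels)}, with label = 1 iff grade > 0."""
    out = {}
    for row in csv.DictReader(fh):
        raw = [int(float(row[f])) for f in FINDINGS]  # tolerate "2" or "2.0"
        out[row["UID"]] = (raw, [int(g > 0) for g in raw])
    return out


# Real use (hypothetical): open(resolve_data_root() / "metadata" / "annotation.csv", newline="")
demo = io.StringIO("UID,PatientID," + ",".join(FINDINGS) + "\nu1,p1,0,1,2,0,0,4,0,3\n")
labels = binarize_annotations(demo)
```

Keeping the raw grades next to the binary labels is what lets grade-1 cases be pulled out again at analysis time.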
All scripts run from the repository root. Weights & Biases logging is wired
into the training script; log in once with wandb login, or set
WANDB_MODE=offline to keep runs local instead of syncing them to W&B.
Single full-data run (saves checkpoints and logs uncertainty on the validation split every 5 epochs):
python scripts/main_train.py --model MST --task binary --dropout_p 0.05
Subset sweep (drives the loss/variance figure; spawns one worker per subset
size and distributes the workers over --num_gpus GPUs):
python scripts/main_train.py \
--subset_training \
--subset_sizes 1.0 5.0 10.0 20.0 50.0 100.0 \
--num_gpus 3 \
--model MST --task binary --dropout_p 0.05
Outputs are written under runs/MST_<timestamp>/ (single run) or
runs/subset_experiment_<timestamp>/subset_<size>pct/ (sweep).
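How the sweep draws its subsets and places workers is not spelled out here. The sketch below shows one plausible scheme, assuming each subset size is a percentage of the training UIDs drawn with a fixed seed and that workers are assigned to GPUs round-robin. sample_subset and assign_gpus are hypothetical helpers, not functions from main_train.py.

```python
import math
import random


def sample_subset(uids, percent, seed=42):
    """Hypothetical: draw `percent` % of the training UIDs (at least one), reproducibly."""
    k = max(1, math.ceil(len(uids) * percent / 100.0))
    return random.Random(seed).sample(list(uids), k)


def assign_gpus(subset_sizes, num_gpus):
    """Hypothetical round-robin mapping of subset sizes (one worker each) to GPU indices."""
    return {size: i % num_gpus for i, size in enumerate(subset_sizes)}


sizes = [1.0, 5.0, 10.0, 20.0, 50.0, 100.0]
plan = assign_gpus(sizes, num_gpus=3)
# plan == {1.0: 0, 5.0: 1, 10.0: 2, 20.0: 0, 50.0: 1, 100.0: 2}
subset = sample_subset(range(1000), 5.0)  # 50 UIDs
```

A fixed seed keeps each subset identical across reruns, so differences between subset_<size>pct runs reflect the amount of data rather than which images were drawn.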
Compute MC-dropout predictive variance over the validation split and write the cache used by the plotting scripts:
python scripts/uncertainty_analysis.py \
--model_path runs/MST_<timestamp>/<best>.ckpt \
--num_samples 30 \
--max_images 100000 # >= |val| -> processes the full validation split
Outputs in results/uncertainty_analysis/:
- uncertainty_cache.npz (predictions, variances, raw grades, UIDs)
- reliability_diagram.png
- entropy_vs_error.png
- uncertainties_vs_standard_pred.png
- pred_variance_histogram.png
- miscalibrated_images/ (high-error / high-uncertainty cases)
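To help read the reliability diagram and the entropy plot, the sketch below computes a positive-class reliability curve with equal-width probability bins, the expected calibration error (ECE) derived from it, and the binary predictive entropy of the MC mean probabilities. The inputs are assumed to be the MC mean probabilities and binarised labels for one finding; the key names inside uncertainty_cache.npz are not documented here, and whether the repo's plots use the same binning is an assumption.

```python
import numpy as np


def binary_entropy(p, eps=1e-12):
    """Predictive entropy (in nats) of Bernoulli probabilities p."""
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)
    return -(p * np.log(p) + (1.0 - p) * np.log(1.0 - p))


def reliability_bins(probs, labels, n_bins=10):
    """Positive-class reliability curve with equal-width bins.
    Returns per-bin mean predicted probability, observed positive rate,
    bin counts, and the count-weighted expected calibration error (ECE)."""
    probs = np.asarray(probs, dtype=float).ravel()
    labels = np.asarray(labels, dtype=float).ravel()
    inner_edges = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
    idx = np.digitize(probs, inner_edges)  # bin index in 0..n_bins-1
    counts = np.bincount(idx, minlength=n_bins)
    mean_pred = np.full(n_bins, np.nan)
    frac_pos = np.full(n_bins, np.nan)
    for b in np.flatnonzero(counts):
        mean_pred[b] = probs[idx == b].mean()
        frac_pos[b] = labels[idx == b].mean()
    filled = counts > 0
    ece = np.sum(counts[filled] * np.abs(mean_pred[filled] - frac_pos[filled])) / probs.size
    return mean_pred, frac_pos, counts, ece


# Toy example: MC mean probabilities and binarised labels for one finding.
mean_prob = np.array([0.05, 0.15, 0.85, 0.95])
y = np.array([0, 0, 1, 1])
_, _, counts, ece = reliability_bins(mean_prob, y)  # ece ≈ 0.1
```

Plotting frac_pos against mean_pred gives the reliability curve; points below the diagonal indicate over-confident positive predictions.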
To restyle the plots without re-running inference, edit the helpers in
scripts/uncertainty_plots.py and run:
python scripts/plot_uncertainty_analysis.py
The two subset-sweep scripts below read the per-subset uncertainty CSVs that the
UncertaintyCallback writes during training. The loss curves are pulled from
W&B; edit EXPERIMENT_WANDB_RUNS at the top of plot_subset_losses.py to
reference your own run IDs (or rely on the fallback to the local wandb folder).
python scripts/analyze_subset_results.py runs/subset_experiment_<timestamp>/ --save_plots
python scripts/plot_subset_losses.py runs/subset_experiment_<timestamp>/ --save_plots
This produces uncertainty_by_class_and_subset.png and
loss_and_variance_curves.png next to the experiment.
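The class-by-subset figure essentially averages predictive variance per finding and per subset run. The sketch below assumes a hypothetical CSV schema (columns class and variance, one row per image and finding), since the UncertaintyCallback's actual format is not documented here; only the subset_<size>pct directory naming comes from the sweep output layout above. subset_size and mean_variance_table are hypothetical names.

```python
import csv
import io
import re
from collections import defaultdict
from statistics import mean

SUBSET_DIR = re.compile(r"subset_(\d+(?:\.\d+)?)pct$")


def subset_size(dirname):
    """Parse the percentage from a 'subset_<size>pct' run directory name."""
    m = SUBSET_DIR.search(dirname)
    if m is None:
        raise ValueError(f"not a subset run directory: {dirname}")
    return float(m.group(1))


def mean_variance_table(csv_by_dir):
    """ASSUMED schema: columns 'class' and 'variance', one row per (image, finding).
    Returns {class: {subset_pct: mean variance}}."""
    table = defaultdict(dict)
    for dirname, fh in csv_by_dir.items():
        per_class = defaultdict(list)
        for row in csv.DictReader(fh):
            per_class[row["class"]].append(float(row["variance"]))
        for cls, vals in per_class.items():
            table[cls][subset_size(dirname)] = mean(vals)
    return dict(table)


# Toy in-memory CSVs standing in for two subset runs.
demo = {
    "subset_5.0pct": io.StringIO("class,variance\nHeartSize,0.02\nHeartSize,0.04\n"),
    "subset_100.0pct": io.StringIO("class,variance\nHeartSize,0.01\n"),
}
table = mean_variance_table(demo)
```

Each inner dict maps subset percentage to mean variance, which is the shape needed to draw one curve or bar group per finding.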
- Backbone: facebookresearch/dinov2 -> dinov2_vits14 (fixed 448x448 input, grayscale replicated to 3 channels); the classifier head is a single linear layer on the CLS token.
- MC dropout: every nn.Dropout inside the transformer is swapped for an MCDropout module that stays active at inference when forward_bayesian(x, num_samples=N) is called. Default dropout_p=0.05.
- Targets: 8 ordinal findings: HeartSize (0-3); PulmonaryCongestion, PleuralEffusion {Left, Right}, PulmonaryOpacities {Left, Right}, Atelectasis {Left, Right} (each 0-4). Binarised at > 0 for training.
- Loss: BCEWithLogitsLoss for the binary head; CornLossMulti for the ordinal regression head (used by MSTRegression).
- Optimizer: AdamW, lr 1e-6, weight decay 1e-2, mixed-precision (fp16) training.
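The MC-dropout swap in these bullets can be sketched in PyTorch as follows. This is an illustration under assumptions, not the code in mst.py: MCDropout here subclasses nn.Dropout and calls F.dropout with training=True so that model.eval() does not disable it, while swap_dropout and the free function forward_bayesian are hypothetical stand-ins for the repo's model method forward_bayesian(x, num_samples=N). A toy MLP replaces the DINOv2 backbone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MCDropout(nn.Dropout):
    """Dropout that ignores model.eval(): a fresh mask is sampled on every call."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.dropout(x, self.p, training=True, inplace=self.inplace)


def swap_dropout(module: nn.Module) -> int:
    """Recursively replace every plain nn.Dropout with MCDropout (same p).
    Returns the number of layers replaced."""
    n = 0
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Dropout) and not isinstance(child, MCDropout):
            setattr(module, name, MCDropout(p=child.p, inplace=child.inplace))
            n += 1
        else:
            n += swap_dropout(child)
    return n


@torch.no_grad()
def forward_bayesian(model: nn.Module, x: torch.Tensor, num_samples: int = 30):
    """N stochastic passes -> per-finding mean and variance of sigmoid probabilities."""
    model.eval()  # all other layers in inference mode; MCDropout stays stochastic
    probs = torch.stack([torch.sigmoid(model(x)) for _ in range(num_samples)])
    return probs.mean(dim=0), probs.var(dim=0, unbiased=False)


# Toy usage: a small MLP stands in for the backbone + linear head (8 findings).
toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Dropout(0.05), nn.Linear(32, 8))
n_swapped = swap_dropout(toy)  # 1
mean, var = forward_bayesian(toy, torch.randn(4, 16), num_samples=30)
```

Calling model.eval() before sampling keeps every other layer deterministic, so the spread across passes comes only from the MCDropout masks.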
Dataset license follows TAIX-Ray (see the Hugging Face card). If you use this code, please cite the corresponding paper alongside the TAIX-Ray release.