
TruhnLab/BayesianCXRAgent

BayesianCXRAgent

Bayesian uncertainty estimation for multilabel chest-radiograph classification on the TAIX-Ray cohort. The classifier is a DINOv2-ViT-S backbone with the transformer's dropout layers replaced by Monte-Carlo dropout (kept active at inference) so that the predictive variance over N stochastic forward passes quantifies epistemic uncertainty for each finding.
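To make the mechanism concrete, here is a minimal, framework-free sketch of MC-dropout prediction for a single binary finding. It is a toy, not the repository's MST code. The logistic "head", its weights and the helper names (stochastic_forward, mc_dropout_predict) are hypothetical, and the population variance is used here (an unbiased estimate differs only by a factor N/(N-1)). Only the recipe follows the description above: keep dropout active, run N stochastic passes, and report the mean probability and its variance.

```python
import math
import random
import statistics


def sigmoid(z: float) -> float:
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


def stochastic_forward(x, weights, bias, p, rng):
    """One pass of a toy logistic head with dropout left ON.

    Each input feature is dropped with probability p; survivors are scaled by
    1 / (1 - p) (inverted dropout, the convention torch.nn.Dropout uses).
    """
    z = bias
    for xi, wi in zip(x, weights):
        if rng.random() >= p:  # feature kept in this pass
            z += wi * xi / (1.0 - p)
    return sigmoid(z)


def mc_dropout_predict(x, weights, bias, p=0.05, num_samples=30, seed=0):
    """Mean probability and predictive variance over num_samples stochastic passes."""
    rng = random.Random(seed)
    probs = [stochastic_forward(x, weights, bias, p, rng) for _ in range(num_samples)]
    return statistics.fmean(probs), statistics.pvariance(probs)


if __name__ == "__main__":
    mean, var = mc_dropout_predict([0.2, -1.3, 0.7], [1.5, 0.4, -2.0], bias=0.1)
    print(f"p(finding) ~ {mean:.3f}, MC-dropout variance = {var:.2e}")
```

With p=0 every pass is identical and the variance collapses to zero; the larger the dropout rate, the more the passes disagree, which is the signal the repository reads as epistemic uncertainty.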

Repository layout

BayesianCXRAgent/
├── cxr/                              # Installable package (model + data code)
│   ├── data/
│   │   ├── datasets/cxr_dataset.py   # TAIX-Ray loader (PNG + CSV metadata)
│   │   └── datamodules/datamodule.py # PyTorch-Lightning DataModule
│   └── models/
│       ├── base_model.py             # BasicClassifier / BasicRegression
│       ├── mst.py                    # DINOv2-ViT-S + MCDropout (`MST`, `MSTRegression`)
│       ├── resnet.py                 # ResNet baseline
│       └── utils/losses.py           # CORN / CE multi-head losses
├── scripts/
│   ├── main_train.py                 # Train the classifier (single run or subset sweep)
│   ├── uncertainty_analysis.py       # MC-dropout inference -> predictions+variance .npz
│   ├── uncertainty_plots.py          # Pure plotting helpers (no torch dependency)
│   ├── plot_uncertainty_analysis.py  # Regenerate plots from the cached .npz
│   ├── analyze_subset_results.py     # `uncertainty_by_class_and_subset.png`
│   └── plot_subset_losses.py         # `loss_and_variance_curves.png` (pulls from W&B)
├── data/TAIX-Ray/                    # Drop the dataset here (see "Data" below)
├── pyproject.toml
└── requirements.txt

Installation

python -m venv .venv && source .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt
pip install -e .

The DINOv2 backbone is pulled from torch.hub on first run, so the install host needs internet access.

Data

The code expects the TAIX-Ray release at data/TAIX-Ray/ (relative to the repo root) with this layout:

data/TAIX-Ray/
├── metadata/
│   ├── split.csv         # columns: UID, Fold, Split
│   └── annotation.csv    # columns: UID, PatientID, ..., HeartSize, PulmonaryCongestion, PleuralEffusion_Left, ...
└── data_png_resize_512/
    └── <UID>.png         # 512-px long-edge PNG, gray

The dataset is publicly available on Hugging Face at TLAIM/TAIX-Ray. The included loader assumes the default (512px) configuration. If you keep the data elsewhere, point to it with the environment variable:

export TAIX_RAY_PATH=/path/to/TAIX-Ray
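A hypothetical sketch of the path resolution described above: TAIX_RAY_PATH, when set, is assumed to replace the repo-relative default entirely, and the per-file paths follow the layout tree. The helper names (resolve_taix_ray_root, dataset_paths) are illustrative only and are not the functions in cxr/data/datasets/cxr_dataset.py.

```python
import os
from pathlib import Path

# Repo-relative default taken from this README.
DEFAULT_ROOT = Path("data/TAIX-Ray")


def resolve_taix_ray_root() -> Path:
    """Use $TAIX_RAY_PATH when set (and non-empty), otherwise fall back to data/TAIX-Ray."""
    env = os.environ.get("TAIX_RAY_PATH")
    return Path(env) if env else DEFAULT_ROOT


def dataset_paths(root: Path, uid: str) -> dict:
    """Paths implied by the README layout for one study UID."""
    return {
        "split_csv": root / "metadata" / "split.csv",
        "annotation_csv": root / "metadata" / "annotation.csv",
        "image": root / "data_png_resize_512" / f"{uid}.png",
    }


if __name__ == "__main__":
    print(dataset_paths(resolve_taix_ray_root(), "<UID>"))
```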

Each finding is graded on an ordinal 0-4 scale (HeartSize: 0-3). For the binary training objective the loader binarises every grade at > 0 (grade 0 = negative, grade >= 1 = positive). The raw grades are preserved on disk and re-attached to predictions at analysis time, so that annotator-uncertain cases (grade 1) can be inspected separately.
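The binarisation rule and the grade-1 bookkeeping can be sketched as follows. HeartSize, PulmonaryCongestion and PleuralEffusion_Left are column names listed for annotation.csv above; the other five names are an assumption that the same <Finding>_<Side> pattern is used. The helpers binarise and annotator_uncertain are hypothetical, not the loader's API.

```python
# Column names: the first three appear in the README; the rest are ASSUMED
# to follow the same <Finding>_<Side> naming pattern.
FINDINGS = [
    "HeartSize",
    "PulmonaryCongestion",
    "PleuralEffusion_Left",
    "PleuralEffusion_Right",
    "PulmonaryOpacities_Left",
    "PulmonaryOpacities_Right",
    "Atelectasis_Left",
    "Atelectasis_Right",
]


def binarise(grades: dict) -> dict:
    """Binary training targets: grade 0 -> 0, any grade > 0 -> 1."""
    return {name: int(grades[name] > 0) for name in FINDINGS}


def annotator_uncertain(grades: dict) -> list:
    """Findings graded exactly 1, kept aside for separate inspection."""
    return [name for name in FINDINGS if grades[name] == 1]


if __name__ == "__main__":
    row = {name: 0 for name in FINDINGS}
    row.update(HeartSize=2, PulmonaryCongestion=1)
    print(binarise(row))             # HeartSize and PulmonaryCongestion become 1
    print(annotator_uncertain(row))  # ['PulmonaryCongestion']
```

Because the binary target discards the grade, keeping the raw values is what lets the analysis separate a confident positive (grade 2+) from a borderline one (grade 1).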

All scripts are meant to be run from the repository root. Weights & Biases logging is built into the training script: log in once with wandb login, or set WANDB_MODE=offline to keep runs local without syncing them to W&B.

1. Train the classifier

Single full-data run (saves checkpoints, logs uncertainty on the validation split every 5 epochs):

python scripts/main_train.py --model MST --task binary --dropout_p 0.05

Subset sweep (produces the runs behind the loss/variance figure; spawns one worker per subset size, with sizes given as percentages of the training data, and distributes the workers across --num_gpus GPUs):

python scripts/main_train.py \
    --subset_training \
    --subset_sizes 1.0 5.0 10.0 20.0 50.0 100.0 \
    --num_gpus 3 \
    --model MST --task binary --dropout_p 0.05

Outputs are written under runs/MST_<timestamp>/ (single) or runs/subset_experiment_<timestamp>/subset_<size>pct/ (sweep).
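The scheduling side of the sweep can be pictured with the sketch below. Round-robin GPU placement and seeded random subsampling are assumptions made for illustration; the README only states that one worker is spawned per subset size and that workers are spread over --num_gpus GPUs. The helpers assign_to_gpus and sample_subset are hypothetical.

```python
import random


def assign_to_gpus(subset_sizes, num_gpus):
    """ASSUMED round-robin placement: the i-th subset size runs on GPU i % num_gpus."""
    return {size: i % num_gpus for i, size in enumerate(subset_sizes)}


def sample_subset(uids, pct, seed=0):
    """Keep round(pct% of uids), at least one, chosen reproducibly from a fixed seed."""
    uids = list(uids)
    k = max(1, round(len(uids) * pct / 100))
    return sorted(random.Random(seed).sample(uids, k))


if __name__ == "__main__":
    sizes = [1.0, 5.0, 10.0, 20.0, 50.0, 100.0]
    # {1.0: 0, 5.0: 1, 10.0: 2, 20.0: 0, 50.0: 1, 100.0: 2}
    print(assign_to_gpus(sizes, num_gpus=3))
```

With the six sizes and three GPUs from the example command, each GPU would host two workers under this placement.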

2. MC-dropout inference + reliability

Compute MC-dropout predictive variance over the validation split and write the cache used by the plotting scripts:

python scripts/uncertainty_analysis.py \
    --model_path runs/MST_<timestamp>/<best>.ckpt \
    --num_samples 30 \
    --max_images 100000          # >= |val| -> processes the full validation split

Output: results/uncertainty_analysis/

  • uncertainty_cache.npz (predictions, variances, raw grades, UIDs)
  • reliability_diagram.png
  • entropy_vs_error.png
  • uncertainties_vs_standard_pred.png
  • pred_variance_histogram.png
  • miscalibrated_images/ (high-error / high-uncertainty cases)
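For readers unfamiliar with the first two plots, this sketch shows the standard quantities behind a binary reliability diagram (mean predicted probability vs. observed positive rate per equal-width bin), the expected calibration error (ECE) that summarises it, and the binary entropy of a predicted probability. Whether uncertainty_analysis.py uses exactly this binning or this ECE definition is an assumption; the function names are hypothetical.

```python
import math


def reliability_bins(probs, labels, n_bins=10):
    """Per non-empty equal-width bin: (mean predicted prob, observed positive rate, count)."""
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        idx = min(int(p * n_bins), n_bins - 1)  # p == 1.0 falls into the last bin
        bins[idx].append((p, y))
    out = []
    for members in bins:
        if members:
            conf = sum(p for p, _ in members) / len(members)
            freq = sum(y for _, y in members) / len(members)
            out.append((conf, freq, len(members)))
    return out


def expected_calibration_error(probs, labels, n_bins=10):
    """Count-weighted mean gap between predicted probability and observed frequency."""
    n = len(probs)
    return sum(c / n * abs(conf - freq)
               for conf, freq, c in reliability_bins(probs, labels, n_bins))


def binary_entropy(p, eps=1e-12):
    """Entropy (in nats) of a Bernoulli prediction; clamped to avoid log(0)."""
    p = min(max(p, eps), 1 - eps)
    return -(p * math.log(p) + (1 - p) * math.log(1 - p))
```

A perfectly calibrated model puts every bin on the diagonal (ECE = 0); entropy peaks at log 2 for p = 0.5, which is why high-entropy predictions are expected to coincide with more errors.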

To restyle the plots without re-running inference, edit the helpers in scripts/uncertainty_plots.py and run:

python scripts/plot_uncertainty_analysis.py

3. Subset-sweep figures

These scripts read the per-subset uncertainty CSVs that the UncertaintyCallback writes during training. The loss curves are pulled from W&B: edit EXPERIMENT_WANDB_RUNS at the top of plot_subset_losses.py to reference your own run IDs, or rely on the fallback that reads the local wandb folder.

python scripts/analyze_subset_results.py runs/subset_experiment_<timestamp>/ --save_plots
python scripts/plot_subset_losses.py     runs/subset_experiment_<timestamp>/ --save_plots

This produces uncertainty_by_class_and_subset.png and loss_and_variance_curves.png next to the experiment.
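The per-class, per-subset summary can be thought of as a group-by over the callback's CSV rows. The column names subset_pct, finding and variance are hypothetical (the README does not document the CSV schema), and the demo rows are toy values, not results.

```python
import csv
import io
from collections import defaultdict


def mean_variance_table(rows):
    """Average MC-dropout variance per (subset %, finding).

    rows: dicts with the HYPOTHETICAL keys 'subset_pct', 'finding', 'variance'.
    """
    groups = defaultdict(list)
    for row in rows:
        key = (float(row["subset_pct"]), row["finding"])
        groups[key].append(float(row["variance"]))
    return {key: sum(vals) / len(vals) for key, vals in sorted(groups.items())}


if __name__ == "__main__":
    demo = io.StringIO(  # toy rows for illustration only
        "subset_pct,finding,variance\n"
        "1.0,HeartSize,0.04\n"
        "1.0,HeartSize,0.02\n"
        "100.0,HeartSize,0.01\n"
    )
    for (pct, finding), var in mean_variance_table(csv.DictReader(demo)).items():
        print(f"{pct:>6}% {finding:<20} {var:.4f}")
```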

Model details

  • Backbone: facebookresearch/dinov2 -> dinov2_vits14 with a fixed input size of 448x448 (grayscale images replicated to 3 channels); the classifier head is a single linear layer on the CLS token.
  • MC dropout: every nn.Dropout inside the transformer is swapped for an MCDropout module that stays active at inference when forward_bayesian(x, num_samples=N) is called. Default dropout_p=0.05.
  • Targets: 8 ordinal findings - HeartSize (0-3); PulmonaryCongestion, PleuralEffusion {Left, Right}, PulmonaryOpacities {Left, Right}, Atelectasis {Left, Right} (each 0-4). Binarised at > 0 for training.
  • Loss: BCEWithLogitsLoss for the binary head; CornLossMulti for the ordinal regression head (used by MSTRegression).
  • Optimizer: AdamW, lr 1e-6, weight decay 1e-2, mixed-precision (fp16).
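The ordinal head's target encoding can be sketched under the assumption that CornLossMulti follows the standard CORN (Conditional Ordinal Regression for Neural networks) formulation: a finding with K grades becomes K-1 binary tasks (3 for HeartSize, 4 for the other findings), task k is trained only on samples with grade >= k with target grade > k, and a grade is decoded from the running product of the conditional probabilities. The helpers corn_targets and corn_decode are hypothetical and do not reproduce the loss itself.

```python
def corn_targets(y: int, num_classes: int):
    """Per-task (mask, target) pairs for one grade y in {0, ..., num_classes - 1}.

    Task k only sees examples with y >= k (the conditional subset); its
    binary target is 1 when y > k.
    """
    return [(int(y >= k), int(y > k)) for k in range(num_classes - 1)]


def corn_decode(cond_probs):
    """cond_probs[k] = P(y > k | y > k - 1); grade = #k with P(y > k) > 0.5.

    P(y > k) is the running product of the conditional probabilities, so it
    never increases and the count equals the index where it first drops to <= 0.5.
    """
    grade, cum = 0, 1.0
    for p in cond_probs:
        cum *= p
        grade += int(cum > 0.5)
    return grade


if __name__ == "__main__":
    print(corn_targets(2, num_classes=5))       # [(1, 1), (1, 1), (1, 0), (0, 0)]
    print(corn_decode([0.9, 0.8, 0.3, 0.9]))    # 2
```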

License & citation

The dataset license is that of the TAIX-Ray release (see its Hugging Face dataset card). If you use this code, please cite the corresponding paper alongside the TAIX-Ray release.
