This repository contains a full pipeline for predicting 30-day readmission of diabetic patients using heterogeneous graph neural networks (HGT, R-GCN, GraphSAGE) implemented with PyTorch Geometric. The notebook main.ipynb orchestrates the end-to-end workflow while the supporting Python packages under data/, graph/, models/, train/, and evaluation/ let you script and reuse each stage independently.
Final reports and figures are saved in a dedicated reports/ folder (figures, CSVs, configs, and the final report).
Python 3.10+ is recommended. Install dependencies with:
python -m pip install -r requirements.txt
Key packages include:
- PyTorch + PyTorch Geometric (and optional DGL neighbor loaders). Uncomment the wheel index in requirements.txt that matches your CUDA/Torch combo before installing.
- Scientific Python stack: numpy, pandas, scipy, scikit-learn, matplotlib, tensorboard, numexpr, joblib, tqdm, tqdm-joblib.
- Configuration & utilities: pydantic (v1), omegaconf.
- Visualisation & data exploration: plotly, dcor, phik.
- Baseline/tabular models: xgboost, catboost.
Note: PyTorch Geometric requires CUDA-specific wheels. Follow https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html if you need GPU support.
- main.ipynb / main.py – orchestrate the full experiment (data prep → graph building → model training → evaluation).
- data/ – loaders, filters, preprocessing, vocabularies, and split helpers.
- graph/ – hetero graph builders plus inductive inference helpers.
- models/ – HGT, R-GCN, and GraphSAGE backbones plus classification heads.
- train/ – data loaders, optimizer utilities, and the AMP-aware training loop.
- evaluation/ – calibration, metrics, reporting utilities, and rich data exploration plots.
- benchmarks/ – fast tabular baselines (logistic regression, random forest, XGBoost, CatBoost, KNN).
- grid_search/ – reusable grid-search runners for graph models and tabular models.
- infer/ – helpers for batch inference on new encounters.
- utils/ – config schema, logging, artifact helpers, IO utilities, and system tuning helpers.
- tests/ – lightweight unit tests for data vocabularies, graph builders, sampling, and evaluation helpers.
- raw/ & data/ – expected locations for the CSV inputs and cached/intermediate artifacts.
- Download the public diabetes readmission dataset plus IDS_mapping.csv and place them under raw/ (see the defaults in config_dict inside main.ipynb).
- Adjust the configuration JSON in Section 2 of the notebook (or inside main.py) to point to your data files, tweak graph settings, and pick a backbone architecture.
- Run main.ipynb in Jupyter/Colab (or mirror the sequence inside your own script). The numbered cells:
  - Load the CSVs, filter encounters, and build train/validation/test splits that respect patient grouping.
  - Preprocess features (imputation, rare-category handling, one-hot encoding) and persist scalers/encoders.
  - Construct heterogeneous graphs via graph.builder.build_heterodata for each split.
  - Instantiate a GNN from models/ and train it with the AMP-aware train.loop.Trainer, logging metrics to TensorBoard.
  - Evaluate predictions with evaluation.evaluation (metrics, threshold tuning, calibration, subgroup reports) and generate interactive figures from evaluation.model_evaluator / evaluation.data_exploration.
  - Optionally train tabular baselines (benchmarks/train_and_eval_baselines) for comparison.
- Artifacts (model checkpoints, configs, scalers, calibration models, and plots) are stored under artifacts/. utils.artifacts.save_best_artifact keeps a latest pointer for quick reuse.
- To drive the pipeline from a pure Python script, import Config from utils.config, call the functions in data/ and graph/, and instantiate train.loop.Trainer with loaders from train/loader.py.
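The patient-grouped splitting mentioned in the steps above can be illustrated with a minimal sketch. It is not the repository's split helper: the function name grouped_split, the split fractions, and the synthetic patient ids are assumptions for illustration, and it relies only on scikit-learn's GroupShuffleSplit.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit


def grouped_split(patient_ids, val_frac=0.15, test_frac=0.15, seed=0):
    """Return train/val/test index arrays with no patient shared across splits.

    Hypothetical helper; fractions apply to the number of patients (groups),
    not to the number of encounters.
    """
    patient_ids = np.asarray(patient_ids)
    idx = np.arange(len(patient_ids))

    # First carve out the test patients.
    outer = GroupShuffleSplit(n_splits=1, test_size=test_frac, random_state=seed)
    trainval_pos, test_pos = next(outer.split(idx, groups=patient_ids))

    # Then split the remainder into train and validation, again by patient.
    inner = GroupShuffleSplit(
        n_splits=1, test_size=val_frac / (1.0 - test_frac), random_state=seed
    )
    train_pos, val_pos = next(
        inner.split(trainval_pos, groups=patient_ids[trainval_pos])
    )

    # inner positions index into trainval_pos, so map them back to row indices.
    return trainval_pos[train_pos], trainval_pos[val_pos], test_pos


# Synthetic example: 200 encounters spread over at most 60 patients.
rng = np.random.default_rng(0)
patient_ids = rng.integers(0, 60, size=200)
train_idx, val_idx, test_idx = grouped_split(patient_ids)
```

Splitting by patient rather than by encounter keeps repeat visits of the same person out of both training and evaluation sets, which would otherwise inflate the reported metrics.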
graph.inductive.build_star_graph_for_row reuses the same vocabularies to build a star graph around a single encounter. This lets you score unseen patients by loading the saved vocabs, creating the star graph on-the-fly, and running the trained GNN head. The infer/ package contains helpers for batching predictions and exporting CSVs.
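To make the star-graph idea concrete, here is a standard-library sketch. It does not reproduce graph.inductive.build_star_graph_for_row, whose signature is not shown here: the Vocab class, the build_star_edges function, the column names, and the relation names are all hypothetical. It only illustrates the pattern of mapping one encounter's attribute values through saved vocabularies (with an unknown-token fallback) and connecting them to a single hub node, with reverse edges added.

```python
from dataclasses import dataclass, field

UNK = "<UNK>"  # assumed unknown-token marker


@dataclass
class Vocab:
    """Maps tokens to integer ids; id 0 is reserved for unknown tokens."""

    token_to_id: dict = field(default_factory=lambda: {UNK: 0})

    def fit(self, tokens):
        for tok in tokens:
            # Assign the next free id only if the token is new.
            self.token_to_id.setdefault(tok, len(self.token_to_id))
        return self

    def lookup(self, token):
        return self.token_to_id.get(token, self.token_to_id[UNK])


def build_star_edges(row, vocabs):
    """Return {(src_type, relation, dst_type): (src_ids, dst_ids)} for one encounter.

    The encounter is the hub (local id 0); every attribute value is a spoke.
    """
    edges = {}
    for column, vocab in vocabs.items():
        values = row.get(column, [])
        if not isinstance(values, (list, tuple)):
            values = [values]
        dst = [vocab.lookup(v) for v in values]
        src = [0] * len(dst)
        edges[("encounter", f"has_{column}", column)] = (src, dst)
        # Reverse edges let messages flow back from attributes to the encounter.
        edges[(column, f"rev_has_{column}", "encounter")] = (dst, src)
    return edges


# Vocabularies as they might be restored from training artifacts (toy values).
vocabs = {
    "diagnosis": Vocab().fit(["250.00", "428", "401"]),
    "medication": Vocab().fit(["insulin", "metformin"]),
}
# "999.99" was never seen during training, so it falls back to the unknown id.
row = {"diagnosis": ["428", "999.99"], "medication": "insulin"}
edges = build_star_edges(row, vocabs)
```

In a PyTorch Geometric setting these id lists would typically be converted to edge_index tensors on a HeteroData object before running the trained model.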
- grid_search/tabular.py and grid_search/rgcn_fast.py provide joblib+tqdm powered sweeps with optional CUDA MPS helpers. Configure them by passing the Config object and a data factory callable.
- benchmarks/baselines.py trains Logistic Regression, Random Forest, XGBoost, CatBoost, and KNN references to contextualise GNN performance.
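The sweep pattern (a parameter grid, a data factory callable, and joblib for parallel trials) can be sketched as below. This is an illustration only and does not reflect the actual interfaces in grid_search/: the names data_factory and run_trial, the grid values, and the synthetic dataset are assumptions, and the tqdm progress bar and Config object are omitted for brevity.

```python
from itertools import product

from joblib import Parallel, delayed
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split


def data_factory(seed=0):
    """Hypothetical data factory: returns X_train, X_val, y_train, y_val."""
    X, y = make_classification(
        n_samples=600, n_features=12, weights=[0.85, 0.15], random_state=seed
    )
    return train_test_split(X, y, test_size=0.3, stratify=y, random_state=seed)


def run_trial(params, factory):
    """Train one configuration and report its validation ROC AUC."""
    X_tr, X_va, y_tr, y_va = factory()
    model = LogisticRegression(
        C=params["C"], class_weight=params["class_weight"], max_iter=1000
    )
    model.fit(X_tr, y_tr)
    auc = roc_auc_score(y_va, model.predict_proba(X_va)[:, 1])
    return {**params, "val_auc": float(auc)}


grid = {"C": [0.01, 0.1, 1.0], "class_weight": [None, "balanced"]}
# Expand the grid into one dict per combination.
trials = [dict(zip(grid, values)) for values in product(*grid.values())]

# Threads keep the sketch simple; process-based workers are the joblib default.
results = Parallel(n_jobs=2, prefer="threads")(
    delayed(run_trial)(p, data_factory) for p in trials
)
best = max(results, key=lambda r: r["val_auc"])
```

Passing a factory rather than the data itself lets each worker build its own copy of the inputs, which avoids shipping large arrays between processes.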
Basic unit tests are provided in the tests/ directory. You can run these tests to verify that:
- Unknown token handling in vocab works (test_vocabs.py).
- Graph builder creates expected nodes/edges and adds reverse edges correctly (test_graph_builder.py).
- Group splits have no patient overlap (test_splits.py).
Run the full suite with:
pytest tests
This notebook and codebase are designed for clarity and completeness; production use may require further optimization and tuning. All three model variants (HGT, R-GCN, GraphSAGE) are implemented, and you can switch between them by changing model.arch in the config. Calibration and threshold selection are performed on validation data to optimize final performance.
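As a rough illustration of validation-only calibration and threshold selection, the sketch below fits a calibrator and picks a decision threshold on the validation split, then applies both unchanged to the test split. The specific choices are assumptions, not the repository's method: isotonic regression as the calibrator, F1 as the threshold criterion, a logistic regression standing in for the trained GNN, and synthetic data.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced data standing in for readmission labels.
X, y = make_classification(
    n_samples=2000, n_features=15, weights=[0.88, 0.12], random_state=0
)
X_train, X_rest, y_train, y_rest = train_test_split(
    X, y, test_size=0.4, stratify=y, random_state=0
)
X_val, X_test, y_val, y_test = train_test_split(
    X_rest, y_rest, test_size=0.5, stratify=y_rest, random_state=0
)

# Stand-in for a trained GNN: any model that outputs scores works here.
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# 1) Fit the calibrator on validation scores only.
val_scores = model.predict_proba(X_val)[:, 1]
calibrator = IsotonicRegression(out_of_bounds="clip", y_min=0.0, y_max=1.0)
calibrator.fit(val_scores, y_val)
val_probs = calibrator.predict(val_scores)

# 2) Choose the threshold that maximizes F1 on validation.
precision, recall, thresholds = precision_recall_curve(y_val, val_probs)
# The last precision/recall pair has no matching threshold, so drop it.
denom = np.clip(precision[:-1] + recall[:-1], 1e-12, None)
f1 = 2 * precision[:-1] * recall[:-1] / denom
threshold = float(thresholds[np.argmax(f1)])

# 3) Apply the frozen calibrator and threshold to the test split.
test_probs = calibrator.predict(model.predict_proba(X_test)[:, 1])
test_pred = (test_probs >= threshold).astype(int)
```

Keeping both steps on validation data means the test metrics are computed with nothing tuned on the test set itself.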