Research code for PloidyNet, a multimodal neural network for non-invasive embryo ploidy prediction from blastocyst-stage embryo images and clinical data.
This repository contains model implementation, training scripts, evaluation utilities, and data-loading code used for the PloidyNet study. The underlying embryo images, clinical variables, PGT-A-derived labels, trained model weights, sample-level predictions, and train/validation/test split files are not included because the study data contain sensitive human embryo and clinical information and are restricted by ethics approval and institutional data-governance requirements.
project_fusion/
  datasets/
    clinical_builder.py     Build BASE11 clinical features from the label CSV
    clinical_dataset.py     PyTorch dataset wrapper for clinical features
    image_dataset.py        PyTorch dataset wrapper for embryo images
    indexer.py              Deterministic sample indexing and label alignment
  models/
    clinical_mlp.py         Clinical branch
    image_encoder.py        EfficientNet-B0 image branch
    fusion_gate_b.py        Evaluation-time fusion gate
  train/
    train_clinical.py       Train the clinical branch
    train_image.py          Train the image branch
  eval/
    eval_fusion_B0plus.py   Evaluate fused predictions using the fusion gate
  scripts/
    check_label_csv.py      Validate label CSV schema and values
    check_valid_mask.py     Check image valid-mask statistics
  utils/
    mask_circle.py          Embryo-focused soft masking utilities
    metrics.py              Binary classification metrics
    split.py                Stratified split utilities
    valid_ratio.py          Valid-mask ratio utilities
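How utils/mask_circle.py builds its embryo-focused soft mask is not documented here. As a hypothetical illustration only, the NumPy sketch below assumes a centred circular mask that is 1 inside a radius and ramps linearly to 0 across a soft band. The names soft_circle_mask, apply_mask, radius_frac and softness_frac are illustrative and are not the repository's API.

```python
import numpy as np


def soft_circle_mask(height: int, width: int,
                     radius_frac: float = 0.45,
                     softness_frac: float = 0.05) -> np.ndarray:
    """Hypothetical soft mask: 1.0 inside a centred circle, fading linearly
    to 0.0 over a band of width softness_frac * min(height, width)."""
    yy, xx = np.mgrid[0:height, 0:width]
    cy, cx = (height - 1) / 2.0, (width - 1) / 2.0
    dist = np.hypot(yy - cy, xx - cx)          # distance of each pixel from the centre
    scale = min(height, width)
    radius = radius_frac * scale
    band = max(softness_frac * scale, 1e-6)    # avoid division by zero
    # (dist - radius) / band is <= 0 inside the circle and >= 1 beyond the band,
    # so clipping gives 1 inside, 0 outside, and a linear ramp in between.
    return np.clip(1.0 - (dist - radius) / band, 0.0, 1.0)


def apply_mask(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Weight an (H, W) or (H, W, C) image by an (H, W) mask."""
    if image.ndim == 3:
        return image * mask[..., np.newaxis]
    return image * mask
```

For an image array img, apply_mask(img.astype(np.float32), soft_circle_mask(*img.shape[:2])) would return the weighted image. A linear ramp is used here instead of a hard cut-off so that the mask boundary does not add a sharp artificial edge; the falloff actually used by the --use_mask_circle option may differ.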
The study dataset is not publicly available. To run the code, users must provide their own data organized according to the expected schema below.
Do not commit private data or derived sample-level artifacts to a public repository. In particular, do not upload raw embryo images, clinical CSV files, split_ids.json, preds_*.csv, valid_ratio.json, model_best.pth, manifest.json, stdout.log, or any other file containing sample identifiers, labels, predictions, paths to private data, or trained weights.
The label CSV must contain the following columns:
- embryo_number
- ploidy
- mother_age
- father_age
- expansion
- ICM_A
- ICM_B
- ICM_C
- TE_A
- TE_B
- TE_C
- risk_bad_pregnancy_history
- risk_recurrent_miscarriage
- TB
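The checks performed by scripts/check_label_csv.py are not described here. As a rough, hypothetical sketch of a schema check for the columns above, the code below verifies that every required column is present and that embryo_number is non-empty and unique. It also assumes, without confirmation from the source, that the binary ploidy label is stored as 0 or 1. The function names are illustrative.

```python
import csv
from typing import Iterable, List, Mapping, Optional

REQUIRED_COLUMNS = [
    "embryo_number", "ploidy", "mother_age", "father_age", "expansion",
    "ICM_A", "ICM_B", "ICM_C", "TE_A", "TE_B", "TE_C",
    "risk_bad_pregnancy_history", "risk_recurrent_miscarriage", "TB",
]


def check_label_rows(rows: Iterable[Mapping[str, Optional[str]]],
                     fieldnames: List[str]) -> List[str]:
    """Return a list of human-readable problems (empty if none were found)."""
    missing = [c for c in REQUIRED_COLUMNS if c not in fieldnames]
    if missing:
        return [f"missing columns: {missing}"]
    problems: List[str] = []
    seen = set()
    # Line numbers assume line 1 is the header and no field spans several lines.
    for line_no, row in enumerate(rows, start=2):
        eid = (row.get("embryo_number") or "").strip()
        if not eid:
            problems.append(f"line {line_no}: empty embryo_number")
        elif eid in seen:
            problems.append(f"line {line_no}: duplicate embryo_number {eid!r}")
        seen.add(eid)
        # Assumption: the binary ploidy label is encoded as 0 or 1.
        ploidy = (row.get("ploidy") or "").strip()
        if ploidy not in {"0", "1"}:
            problems.append(f"line {line_no}: ploidy {ploidy!r} is not 0 or 1")
    return problems


def check_label_csv(path: str) -> List[str]:
    """Read a label CSV and run check_label_rows on it."""
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        return check_label_rows(reader, list(reader.fieldnames or []))
```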
Validate a CSV file with:
python scripts/check_label_csv.py --csv data/labels.csv
The image loader expects embryo images under a root image directory:
data/images/
  <mother_or_group_dir>/
    <embryo_dir_name_like_...-SAMPLE-01>/
      T0_analysis.csv
      F-15/
        *_RUN003.JPG
      F0/
        *_RUN003.JPG
      F15/
        *_RUN003.JPG
Embryo identifiers are parsed from embryo directory names by taking the final two hyphen-separated tokens when the last token is numeric. For example, ...-SAMPLE-01 is parsed as SAMPLE-01. These identifiers must match the embryo_number values in the label CSV.
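A minimal standard-library sketch of the identifier rule and directory layout above. parse_embryo_id follows the stated rule. Returning None when the last token is not numeric, rejecting duplicate IDs, and the helper names find_focal_images and index_embryos are assumptions for illustration, not the behaviour of datasets/indexer.py.

```python
from pathlib import Path
from typing import Dict, List, Optional

FOCAL_PLANES = ("F-15", "F0", "F15")  # sub-directories shown in the layout above


def parse_embryo_id(dir_name: str) -> Optional[str]:
    """'...-SAMPLE-01' -> 'SAMPLE-01' when the last hyphen-separated token is numeric."""
    tokens = dir_name.split("-")
    if len(tokens) >= 2 and tokens[-1].isdigit():
        return "-".join(tokens[-2:])
    return None  # assumption: directories that do not match are skipped


def find_focal_images(embryo_dir: Path) -> Dict[str, List[Path]]:
    """Collect *_RUN003.JPG files per focal-plane sub-directory (empty list if absent)."""
    return {plane: sorted((embryo_dir / plane).glob("*_RUN003.JPG"))
            for plane in FOCAL_PLANES}


def index_embryos(image_root: Path) -> Dict[str, Path]:
    """Map embryo IDs to directories laid out as <image_root>/<group>/<embryo>/."""
    index: Dict[str, Path] = {}
    for embryo_dir in sorted(p for p in image_root.glob("*/*") if p.is_dir()):
        eid = parse_embryo_id(embryo_dir.name)
        if eid is None:
            continue
        if eid in index:  # assumption: IDs must be unique across group directories
            raise ValueError(f"duplicate embryo ID {eid!r}: {index[eid]} and {embryo_dir}")
        index[eid] = embryo_dir
    return index
```

The keys returned by index_embryos(Path("data/images")) would then be compared against the embryo_number column of the label CSV to find unmatched samples.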
Python 3.10 or newer is recommended.
python -m venv .venv
source .venv/bin/activate # Windows: .venv\Scripts\activate
pip install -r requirements.txt
For GPU training, install the PyTorch build that matches your CUDA version following the official PyTorch instructions, then install the remaining requirements.
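Train the clinical branch. In the commands in this section, $s is a shell variable holding the random seed (for example, s=0), and the absolute /project_fusion paths are examples; point them at your own checkout and data: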
python /project_fusion/train/train_clinical.py \
--seed $s \
--root_dir /project_fusion/data/data_crop \
--label_csv /project_fusion/data/label_tiny.csv \
--out_dir /project_fusion/runs/clinical
Train the image branch with embryo-focused soft masking:
python /project_fusion/train/train_image.py \
--seed $s \
--root_dir /project_fusion/data/data_crop \
--label_csv /project_fusion/data/label_tiny.csv \
--out_dir /project_fusion/runs/image \
--use_mask_circle \
--unfreeze_last_k 2
Evaluate the fusion gate:
python /project_fusion/eval/eval_fusion_B0plus.py \
--seed $s \
--clinical_run_dir /project_fusion/runs/clinical/seed_$s \
--image_run_dir /project_fusion/runs/image/seed_$s \
--data_root_dir /project_fusion/data/data_crop \
--label_csv /project_fusion/data/label_tiny.csv \
--b_min -1 --b_max 1 --b_step 1 \
--w_conf_min -6 --w_conf_max 0 --w_conf_step 2 \
--w_dis_min -8 --w_dis_max 0 --w_dis_step 2 \
--w_v_min 0 --w_v_max 0 --w_v_step 1
The scripts write run artifacts under runs/ by default. These artifacts may contain sample identifiers, labels, predictions, and private path information; they are ignored by .gitignore and should not be made public when generated from restricted data.
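The --*_min, --*_max and --*_step flags of the fusion evaluation command appear to define a search grid over the gate parameters b, w_conf, w_dis and w_v. The sketch below assumes the ranges are inclusive (otherwise --w_v_min 0 --w_v_max 0 would give an empty axis) and that the best grid point is chosen by a validation score. The scoring function is a placeholder because the gate formula and the selection metric used in eval_fusion_B0plus.py are not described here; inclusive_range, build_grid and select_on_validation are hypothetical names.

```python
import itertools
import math
from typing import Callable, Dict, List, Tuple

# (min, max, step) per gate parameter, copied from the example command.
EVAL_RANGES: Dict[str, Tuple[float, float, float]] = {
    "b": (-1, 1, 1),
    "w_conf": (-6, 0, 2),
    "w_dis": (-8, 0, 2),
    "w_v": (0, 0, 1),
}


def inclusive_range(start: float, stop: float, step: float) -> List[float]:
    """Return start, start + step, ... up to and including stop (assumed inclusive)."""
    if step <= 0:
        raise ValueError("step must be positive")
    # Small tolerance so float steps such as 0.1 still reach `stop`.
    n = math.floor((stop - start) / step + 1e-9)
    return [start + i * step for i in range(n + 1)]


def build_grid(ranges: Dict[str, Tuple[float, float, float]]) -> List[Dict[str, float]]:
    """Cartesian product of all parameter axes, one dict per candidate setting."""
    names = list(ranges)
    axes = [inclusive_range(*ranges[name]) for name in names]
    return [dict(zip(names, combo)) for combo in itertools.product(*axes)]


def select_on_validation(
    grid: List[Dict[str, float]],
    val_score: Callable[[Dict[str, float]], float],
) -> Dict[str, float]:
    """Pick the setting with the highest validation score; ties keep the earliest.

    The chosen setting would then be applied, unchanged, to the test set.
    """
    if not grid:
        raise ValueError("empty grid")
    return max(grid, key=val_score)


if __name__ == "__main__":
    grid = build_grid(EVAL_RANGES)
    print(len(grid))  # 60
```

Under the inclusive-range assumption, the ranges in the command give 3 × 4 × 5 × 1 = 60 candidate settings, small enough to evaluate exhaustively on the validation set.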
- Splits are deterministic for a given seed and are stratified by the binary ploidy label.
- The current splitting utility performs embryo-level splitting, matching the study implementation.
- The image branch uses EfficientNet-B0 from torchvision.
- The fusion-gate parameters are selected on the validation set and then applied unchanged to the test set; the gate involves no gradient-based training.
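A minimal sketch of a seeded, embryo-level split stratified by the binary ploidy label, as described in the notes above. The 70/15/15 fractions, the function name stratified_split and the per-class rounding are assumptions, not the settings used in utils/split.py.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def stratified_split(ids: Sequence[str], labels: Sequence[int], seed: int,
                     fractions: Tuple[float, float, float] = (0.7, 0.15, 0.15),
                     ) -> Dict[str, List[str]]:
    """Seeded train/val/test split of embryo IDs, stratified by binary label."""
    if len(ids) != len(labels):
        raise ValueError("ids and labels must have the same length")
    if len(set(ids)) != len(ids):
        raise ValueError("embryo IDs must be unique for an embryo-level split")
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")

    by_label: Dict[int, List[str]] = defaultdict(list)
    for eid, y in zip(ids, labels):
        by_label[y].append(eid)

    rng = random.Random(seed)  # private RNG: global random state is not touched
    splits: Dict[str, List[str]] = {"train": [], "val": [], "test": []}
    for y in sorted(by_label):
        members = sorted(by_label[y])  # sort first so input order does not matter
        rng.shuffle(members)
        n = len(members)
        n_train = round(fractions[0] * n)
        n_val = round(fractions[1] * n)
        splits["train"] += members[:n_train]
        splits["val"] += members[n_train:n_train + n_val]
        splits["test"] += members[n_train + n_val:]  # remainder goes to test
    return splits
```

Sorting each class before shuffling makes the result depend only on the seed and the set of (embryo, label) pairs, which is one way to obtain the seed-level determinism the notes describe.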
This repository is provided for research transparency and reproducibility. It is not a medical device and is not intended for clinical decision-making without appropriate validation, governance, and regulatory review.
This project is released under the MIT License. See LICENSE for details.