
ABChaha11/PloidyNet

PloidyNet

Research code for PloidyNet, a multimodal neural network for non-invasive embryo ploidy prediction from blastocyst-stage embryo images and clinical data.

This repository contains the model implementation, training scripts, evaluation utilities, and data-loading code used for the PloidyNet study. The underlying embryo images, clinical variables, PGT-A-derived labels, trained model weights, sample-level predictions, and train/validation/test split files are not included because the study data contain sensitive human embryo and clinical information and are restricted by ethics approval and institutional data-governance requirements.

Repository Contents

project_fusion/
  datasets/
    clinical_builder.py      Build BASE11 clinical features from the label CSV
    clinical_dataset.py      PyTorch dataset wrapper for clinical features
    image_dataset.py         PyTorch dataset wrapper for embryo images
    indexer.py               Deterministic sample indexing and label alignment

  models/
    clinical_mlp.py          Clinical branch
    image_encoder.py         EfficientNet-B0 image branch
    fusion_gate_b.py         Evaluation-time fusion gate

  train/
    train_clinical.py        Train the clinical branch
    train_image.py           Train the image branch

  eval/
    eval_fusion_B0plus.py    Evaluate fused predictions using the fusion gate

  scripts/
    check_label_csv.py       Validate label CSV schema and values
    check_valid_mask.py      Check image valid-mask statistics

  utils/
    mask_circle.py           Embryo-focused soft masking utilities
    metrics.py               Binary classification metrics
    split.py                 Stratified split utilities
    valid_ratio.py           Valid-mask ratio utilities

Data Availability

The study dataset is not publicly available. To run the code, users must provide their own data organized according to the expected schema below.

Do not commit private data or derived sample-level artifacts to a public repository. In particular, do not upload raw embryo images, clinical CSV files, split_ids.json, preds_*.csv, valid_ratio.json, model_best.pth, manifest.json, stdout.log, or any other file containing sample identifiers, labels, predictions, paths to private data, or trained weights.
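
A minimal pre-commit check along the following lines can catch the artifacts listed above before they are pushed. This is a hypothetical helper, not part of the repository; the extra *.csv and *.jpg patterns are assumptions intended to also catch clinical CSV files and raw embryo images.

```python
from fnmatch import fnmatch
from pathlib import Path

# File names and patterns from the list above, plus hypothetical
# extension patterns for clinical CSVs and raw embryo images.
RESTRICTED_PATTERNS = (
    "split_ids.json",
    "preds_*.csv",
    "valid_ratio.json",
    "model_best.pth",
    "manifest.json",
    "stdout.log",
    "*.csv",  # assumption: treat every CSV as potentially clinical or sample-level
    "*.jpg",  # assumption: raw embryo images (matched case-insensitively)
)


def find_restricted_files(root):
    """Return sorted paths under `root` whose file name matches a restricted pattern."""
    hits = []
    for path in Path(root).rglob("*"):
        if not path.is_file():
            continue
        name = path.name.lower()  # lower-case so "*.jpg" also matches "*_RUN003.JPG"
        if any(fnmatch(name, pattern.lower()) for pattern in RESTRICTED_PATTERNS):
            hits.append(path)
    return sorted(hits)
```

For example, calling find_restricted_files(".") from the repository root before git add lists every file that should stay local.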

Expected Label CSV

The label CSV must contain the following columns:

embryo_number
ploidy
mother_age
father_age
expansion
ICM_A
ICM_B
ICM_C
TE_A
TE_B
TE_C
risk_bad_pregnancy_history
risk_recurrent_miscarriage
TB

Validate a CSV file with:

python scripts/check_label_csv.py --csv data/labels.csv
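
The validator itself is not reproduced here. As a rough sketch of the kind of checks such a script might perform (required columns present, non-empty and unique embryo_number, binary ploidy), assuming ploidy is encoded as 0/1; the function name is hypothetical:

```python
import csv

# Column names copied from the "Expected Label CSV" section.
REQUIRED_COLUMNS = [
    "embryo_number", "ploidy", "mother_age", "father_age", "expansion",
    "ICM_A", "ICM_B", "ICM_C", "TE_A", "TE_B", "TE_C",
    "risk_bad_pregnancy_history", "risk_recurrent_miscarriage", "TB",
]


def validate_label_csv(path):
    """Return a list of problems found in the label CSV (empty if none)."""
    problems = []
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        missing = [c for c in REQUIRED_COLUMNS if c not in (reader.fieldnames or [])]
        if missing:
            return [f"missing columns: {missing}"]
        seen = set()
        for line_no, row in enumerate(reader, start=2):  # line 1 is the header
            embryo = (row["embryo_number"] or "").strip()
            if not embryo:
                problems.append(f"line {line_no}: empty embryo_number")
            elif embryo in seen:
                problems.append(f"line {line_no}: duplicate embryo_number {embryo}")
            seen.add(embryo)
            # Assumption: the binary ploidy label is stored as 0 or 1.
            ploidy = (row["ploidy"] or "").strip()
            if ploidy not in {"0", "1"}:
                problems.append(f"line {line_no}: ploidy must be 0 or 1, got {ploidy!r}")
    return problems
```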

Expected Image Directory Structure

The image loader expects embryo images under a root image directory:

data/images/
  <mother_or_group_dir>/
    <embryo_dir_name_like_...-SAMPLE-01>/
      T0_analysis.csv
      F-15/
        *_RUN003.JPG
      F0/
        *_RUN003.JPG
      F15/
        *_RUN003.JPG

Embryo identifiers are parsed from embryo directory names by taking the final two hyphen-separated tokens when the last token is numeric. For example, ...-SAMPLE-01 is parsed as SAMPLE-01. These identifiers must match the embryo_number values in the label CSV.
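
A minimal sketch of the identifier rule above, together with a helper that gathers the *_RUN003.JPG files from each focal-plane directory. Both functions are hypothetical illustrations rather than the code in datasets/indexer.py or datasets/image_dataset.py, and returning None for directory names whose last token is not numeric is an assumption.

```python
from pathlib import Path

FOCAL_PLANES = ("F-15", "F0", "F15")


def parse_embryo_id(dir_name):
    """Return the last two hyphen-separated tokens if the last one is numeric, else None."""
    tokens = dir_name.split("-")
    if len(tokens) >= 2 and tokens[-1].isdigit():
        return "-".join(tokens[-2:])
    return None  # assumption: such directories are skipped


def collect_focal_images(embryo_dir):
    """Map each focal-plane subdirectory to its sorted *_RUN003.JPG files."""
    embryo_dir = Path(embryo_dir)
    return {
        plane: sorted((embryo_dir / plane).glob("*_RUN003.JPG"))
        for plane in FOCAL_PLANES
    }
```

With a hypothetical directory name, parse_embryo_id("batch7-SAMPLE-01") returns "SAMPLE-01", which must then appear in the embryo_number column of the label CSV.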

Installation

Python 3.10 or newer is recommended.

python -m venv .venv
source .venv/bin/activate  # Windows: .venv\Scripts\activate
pip install -r requirements.txt

For GPU training, install the PyTorch build that matches your CUDA version following the official PyTorch instructions, then install the remaining requirements.

Example Workflow

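The commands below use the shell variable $s for the random seed, so set it first (for example, s=0). Paths are written as absolute paths under /project_fusion; adjust them to your own checkout and data locations. In these examples, --root_dir and --data_root_dir point at the image root directory (data_crop here) and --label_csv points at the label CSV.

Train the clinical branch: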

python /project_fusion/train/train_clinical.py \
    --seed $s \
    --root_dir /project_fusion/data/data_crop \
    --label_csv /project_fusion/data/label_tiny.csv \
    --out_dir /project_fusion/runs/clinical

Train the image branch with embryo-focused soft masking:

python /project_fusion/train/train_image.py \
    --seed $s \
    --root_dir /project_fusion/data/data_crop \
    --label_csv /project_fusion/data/label_tiny.csv \
    --out_dir /project_fusion/runs/image \
    --use_mask_circle \
    --unfreeze_last_k 2
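
The exact masking in utils/mask_circle.py is not documented here. One common form of soft circular mask, sketched below as an assumption, keeps pixels inside a centred circle at full weight and ramps linearly to zero over a narrow band outside it; soft_circle_mask, apply_mask, radius_frac and feather_frac are hypothetical names.

```python
import numpy as np


def soft_circle_mask(height, width, radius_frac=0.45, feather_frac=0.05):
    """Return a (height, width) mask: 1 inside a centred circle, 0 far outside,
    with a linear ramp of width feather_frac * min(height, width) in between."""
    yy, xx = np.mgrid[0:height, 0:width].astype(np.float32)
    cy, cx = (height - 1) / 2.0, (width - 1) / 2.0
    dist = np.sqrt((yy - cy) ** 2 + (xx - cx) ** 2)
    size = min(height, width)
    radius = radius_frac * size
    feather = max(feather_frac * size, 1e-6)
    # 1 for dist <= radius, falling linearly to 0 at dist = radius + feather
    return np.clip((radius + feather - dist) / feather, 0.0, 1.0)


def apply_mask(image, mask):
    """Multiply an (H, W) or (H, W, C) image by the mask, broadcasting over channels."""
    return image * (mask[..., None] if image.ndim == 3 else mask)
```

A soft edge rather than a hard cutoff avoids introducing a sharp artificial boundary that a convolutional encoder could latch onto.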

Evaluate the fusion gate:

python /project_fusion/eval/eval_fusion_B0plus.py \
    --seed $s \
    --clinical_run_dir /project_fusion/runs/clinical/seed_$s \
    --image_run_dir /project_fusion/runs/image/seed_$s \
    --data_root_dir /project_fusion/data/data_crop \
    --label_csv /project_fusion/data/label_tiny.csv \
    --b_min -1 --b_max 1 --b_step 1 \
    --w_conf_min -6 --w_conf_max 0 --w_conf_step 2 \
    --w_dis_min -8 --w_dis_max 0 --w_dis_step 2 \
    --w_v_min 0 --w_v_max 0 --w_v_step 1
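
The --*_min, --*_max and --*_step flags describe a grid over four gate parameters. Assuming the max is inclusive, the settings above give 3 values of b, 4 of w_conf, 5 of w_dis and 1 of w_v, i.e. 60 candidate gates. The sketch below enumerates that grid and keeps the setting with the best validation AUC. The gate form (a sigmoid of b + w_conf*conf + w_dis*dis + w_v*v mixing the two branch probabilities), the definitions of conf, dis and v, and AUC as the selection criterion are all assumptions for illustration; the actual definition lives in models/fusion_gate_b.py and eval/eval_fusion_B0plus.py.

```python
import itertools
import numpy as np


def inclusive_range(lo, hi, step):
    """lo, lo + step, ..., hi, assuming the max flag is inclusive."""
    n = int(round((hi - lo) / step)) + 1
    return [lo + i * step for i in range(n)]


def auc(y, p):
    """ROC AUC via the Mann-Whitney statistic (ties count as 0.5)."""
    y, p = np.asarray(y), np.asarray(p, dtype=float)
    diff = p[y == 1][:, None] - p[y == 0][None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def fuse(p_img, p_clin, v, b, w_conf, w_dis, w_v):
    """Hypothetical gate: g weights the image branch, 1 - g the clinical branch."""
    conf = np.abs(2.0 * p_img - 1.0)  # assumed image-branch confidence in [0, 1]
    dis = np.abs(p_img - p_clin)      # assumed branch disagreement in [0, 1]
    g = 1.0 / (1.0 + np.exp(-(b + w_conf * conf + w_dis * dis + w_v * v)))
    return g * p_img + (1.0 - g) * p_clin


def select_gate(p_img, p_clin, v, y, grids):
    """Exhaustive grid search; returns (best params, best validation AUC)."""
    best_params, best_auc = None, -1.0
    for params in itertools.product(*grids):
        score = auc(y, fuse(p_img, p_clin, v, *params))
        if score > best_auc:  # strict '>' keeps the first best setting in grid order
            best_params, best_auc = params, score
    return best_params, best_auc


# Grids matching the example flags above.
GRIDS = [
    inclusive_range(-1, 1, 1),  # b
    inclusive_range(-6, 0, 2),  # w_conf
    inclusive_range(-8, 0, 2),  # w_dis
    inclusive_range(0, 0, 1),   # w_v
]
```

The selected parameters would then be applied unchanged to the test-set predictions, e.g. fuse(p_img_test, p_clin_test, v_test, *best_params), with no gradient-based training of the gate.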

The scripts write run artifacts under runs/ by default. These artifacts may contain sample identifiers, labels, predictions, and private path information; they are ignored by .gitignore and should not be made public when generated from restricted data.

Reproducibility Notes

  • Splits are deterministic for a given seed and are stratified by the binary ploidy label.
  • The current splitting utility performs embryo-level splitting, matching the study implementation.
  • The image branch uses EfficientNet-B0 from torchvision.
  • The fusion gate is selected on the validation set and applied to the test set without gradient-based training.
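
A minimal sketch of a seeded, label-stratified, embryo-level split as described above. The 70/15/15 fractions, the rounding, and the function name are assumptions; utils/split.py may use different proportions. Sorting each label group before shuffling with random.Random(seed) makes the result independent of input order and repeatable for a given seed.

```python
import random
from collections import defaultdict


def stratified_split(ids, labels, seed, frac=(0.7, 0.15, 0.15)):
    """Split embryo ids into train/val/test, stratified by the binary label."""
    by_label = defaultdict(list)
    for embryo_id, label in zip(ids, labels):
        by_label[label].append(embryo_id)
    rng = random.Random(seed)
    splits = {"train": [], "val": [], "test": []}
    for label in sorted(by_label):
        group = sorted(by_label[label])  # sort first so input order does not matter
        rng.shuffle(group)
        n_train = int(round(frac[0] * len(group)))
        n_val = int(round(frac[1] * len(group)))
        splits["train"] += group[:n_train]
        splits["val"] += group[n_train:n_train + n_val]
        splits["test"] += group[n_train + n_val:]
    return splits
```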

Responsible Use

This repository is provided for research transparency and reproducibility. It is not a medical device and is not intended for clinical decision-making without appropriate validation, governance, and regulatory review.

License

This project is released under the MIT License. See LICENSE for details.
