Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,4 @@ Available Datasets
datasets/pyhealth.datasets.TCGAPRADDataset
datasets/pyhealth.datasets.splitter
datasets/pyhealth.datasets.utils
datasets/pyhealth.datasets.ptbxl
7 changes: 7 additions & 0 deletions docs/api/datasets/pyhealth.datasets.ptbxl.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.datasets.ptbxl
=======================

.. autoclass:: pyhealth.datasets.PTBXLDataset
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,4 @@ Available Tasks
Mutation Pathogenicity (COSMIC) <tasks/pyhealth.tasks.MutationPathogenicityPrediction>
Cancer Survival Prediction (TCGA) <tasks/pyhealth.tasks.CancerSurvivalPrediction>
Cancer Mutation Burden (TCGA) <tasks/pyhealth.tasks.CancerMutationBurden>
PTB-XL MI Classification <tasks/pyhealth.tasks.ptbxl_mi_classification>
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.ptbxl_mi_classification
======================================

.. autoclass:: pyhealth.tasks.PTBXLMIClassificationTask
:members:
:undoc-members:
:show-inheritance:
27 changes: 27 additions & 0 deletions examples/ptbxl_mi_classification_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from pyhealth.datasets import PTBXLDataset
from pyhealth.tasks import PTBXLMIClassificationTask
import os

def main():
    """Minimal end-to-end demo: build the PTB-XL dataset and run the MI task."""
    data_root = os.path.expanduser(
        "~/Downloads/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3"
    )

    ecg_dataset = PTBXLDataset(
        root=data_root,
        dev=True,
        # False -> records100 (100 Hz), True -> records500 (500 Hz)
        use_high_resolution=False,
    )

    mi_task = PTBXLMIClassificationTask(
        root=data_root,
        # 10 seconds at 100 Hz
        signal_length=1000,
        normalize=True,
    )
    labeled_samples = ecg_dataset.set_task(mi_task)

    print(labeled_samples[0])
    print(f"Number of samples: {len(labeled_samples)}")


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,4 @@ def __init__(self, *args, **kwargs):
save_processors,
)
from .collate import collate_temporal
from .ptbxl import PTBXLDataset
49 changes: 49 additions & 0 deletions pyhealth/datasets/ptbxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
from typing import Optional

import dask.dataframe as dd
import pandas as pd

from pyhealth.datasets import BaseDataset


class PTBXLDataset(BaseDataset):
    """PTB-XL ECG dataset represented as an event table.

    Each row of ``ptbxl_database.csv`` becomes one ``ptbxl`` event carrying
    the ECG id, the (resolution-dependent) record path, and the raw SCP
    code string.
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = "PTBXL",
        dev: bool = False,
        cache_dir: Optional[str] = None,
        num_workers: int = 1,
        use_high_resolution: bool = False,
    ):
        """Initialize the dataset.

        Args:
            root: Directory containing ``ptbxl_database.csv`` and the
                ``records100`` / ``records500`` folders.
            dataset_name: Display name passed to the base class.
            dev: Development-mode flag forwarded to the base class.
            cache_dir: Optional cache directory forwarded to the base class.
            num_workers: Worker count forwarded to the base class.
            use_high_resolution: If True use the 500 Hz records
                (``filename_hr``); otherwise the 100 Hz records
                (``filename_lr``).
        """
        # Must be set before super().__init__, which may trigger load_data().
        self.use_high_resolution = use_high_resolution
        super().__init__(
            root=root,
            tables=["ptbxl"],
            dataset_name=dataset_name,
            cache_dir=cache_dir,
            num_workers=num_workers,
            dev=dev,
        )

    def load_data(self) -> dd.DataFrame:
        """Build the flat event table from the PTB-XL metadata CSV.

        Returns:
            A single-partition dask DataFrame with one row per ECG record.
        """
        metadata = pd.read_csv(os.path.join(self.root, "ptbxl_database.csv"))

        # Resolution selects which WFDB record path column to expose.
        path_column = "filename_hr" if self.use_high_resolution else "filename_lr"

        events = pd.DataFrame()
        events["patient_id"] = metadata["patient_id"].astype(str)
        events["event_type"] = "ptbxl"
        # PTB-XL events are not timestamped in this representation.
        events["timestamp"] = pd.NaT
        events["ptbxl/ecg_id"] = metadata["ecg_id"]
        events["ptbxl/record_path"] = metadata[path_column]
        events["ptbxl/scp_codes"] = metadata["scp_codes"]

        return dd.from_pandas(events, npartitions=1)
1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,4 @@
VariantClassificationClinVar,
)
from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task
from .ptbxl_mi_classification import PTBXLMIClassificationTask
135 changes: 135 additions & 0 deletions pyhealth/tasks/ptbxl_mi_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""PTBXL MI classification task for PyHealth.

This module defines a task that loads PTB-XL ECG records, maps SCP
diagnostic codes to myocardial infarction (MI) labels, and returns one
binary-labeled sample per record.
"""

import ast
import os
from typing import Dict, List

import numpy as np
import pandas as pd
import wfdb

from pyhealth.tasks import BaseTask


class PTBXLMIClassificationTask(BaseTask):
    """Task for classifying myocardial infarction (MI) in PTB-XL ECG records.

    This task converts the PTB-XL SCP diagnostic codes into a binary MI label
    and loads the corresponding ECG signal for each record.

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): Input schema mapping signal to tensor.
        output_schema (Dict[str, str]): Output schema mapping label to binary.
    """

    task_name = "ptbxl_mi_classification"
    input_schema = {
        "signal": "tensor",
    }
    output_schema = {
        "label": "binary",
    }

    def __init__(
        self,
        root: str,
        signal_length: int = 1000,
        normalize: bool = True,
    ):
        """Initialize the PTBXL MI classification task.

        Args:
            root: PTB-XL dataset root directory containing `scp_statements.csv`.
            signal_length: Number of samples to use for each ECG signal.
            normalize: Whether to z-score normalize each ECG channel.

        Raises:
            FileNotFoundError: If `scp_statements.csv` is missing from `root`.
        """
        self.root = root
        self.signal_length = signal_length
        self.normalize = normalize

        # Collect the SCP codes whose diagnostic_class is "MI"; membership in
        # this set defines the positive class.
        scp_path = os.path.join(self.root, "scp_statements.csv")
        scp_df = pd.read_csv(scp_path, index_col=0)
        self.mi_codes = set(
            scp_df[scp_df["diagnostic_class"] == "MI"].index.astype(str).tolist()
        )

    @staticmethod
    def _parse_scp_codes(raw_label) -> Dict:
        """Parse a raw scp_codes cell into a dict of codes.

        PTB-XL stores scp_codes as a Python-dict literal string, e.g.
        "{'IMI': 100.0}". A missing cell (pandas NaN float), a malformed
        string, or any non-dict value yields an empty dict instead of
        raising, so a bad row simply gets label 0.
        """
        if isinstance(raw_label, dict):
            return raw_label
        if isinstance(raw_label, str):
            try:
                parsed = ast.literal_eval(raw_label)
            except (ValueError, SyntaxError):
                return {}
            return parsed if isinstance(parsed, dict) else {}
        # NaN / None / anything else: no diagnostic codes.
        return {}

    def _load_ecg_signal(self, record_rel_path: str) -> np.ndarray:
        """Loads a PTB-XL WFDB record and returns shape (12, signal_length).

        The signal is optionally z-score normalized per channel, then
        truncated or zero-padded along time to exactly `signal_length`.
        """
        record_path = os.path.join(self.root, record_rel_path)

        # WFDB expects the record path without file extension.
        signal, _ = wfdb.rdsamp(record_path)

        # rdsamp returns shape (num_samples, num_channels)
        signal = signal.T.astype(np.float32)  # -> (channels, time)

        if self.normalize:
            mean = signal.mean(axis=1, keepdims=True)
            std = signal.std(axis=1, keepdims=True)
            # Guard against division by ~zero on flat (constant) channels.
            std = np.where(std < 1e-6, 1.0, std)
            signal = (signal - mean) / std

        current_len = signal.shape[1]
        if current_len >= self.signal_length:
            signal = signal[:, : self.signal_length]
        else:
            pad_width = self.signal_length - current_len
            signal = np.pad(signal, ((0, 0), (0, pad_width)), mode="constant")

        return signal

    def __call__(self, patient) -> List[Dict]:
        """Generate PTB-XL MI samples from a patient record.

        Args:
            patient: Patient object containing PTB-XL event data.

        Returns:
            A list of sample dictionaries with keys:
                - patient_id
                - visit_id
                - record_id
                - signal
                - label
        """
        samples = []

        rows = patient.data_source.to_dicts()

        for idx, row in enumerate(rows):
            record_rel_path = row["ptbxl/record_path"]

            # Robust parse: previously a NaN or non-dict scp_codes cell
            # crashed on .keys(); now it just yields no codes (label 0).
            scp_codes = self._parse_scp_codes(row["ptbxl/scp_codes"])

            label = 1 if any(code in self.mi_codes for code in scp_codes) else 0
            signal = self._load_ecg_signal(record_rel_path)

            samples.append(
                {
                    "patient_id": patient.patient_id,
                    "visit_id": str(row["ptbxl/ecg_id"]),
                    "record_id": idx + 1,  # 1-based index within this patient
                    "signal": signal.tolist(),
                    "label": label,
                }
            )

        return samples
36 changes: 36 additions & 0 deletions tests/core/test_ptbxl_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
import tempfile
import unittest

from pyhealth.datasets import PTBXLDataset


class TestPTBXLDataset(unittest.TestCase):
    def test_load_data_dev_mode(self):
        """load_data should yield one flat event row per CSV record."""
        with tempfile.TemporaryDirectory() as tmpdir:
            lines = [
                "ecg_id,patient_id,filename_lr,filename_hr,scp_codes\n",
                '1,100,records100/00000/00001_lr,records500/00000/00001_hr,"{\'MI\': 1}"\n',
                '2,101,records100/00000/00002_lr,records500/00000/00002_hr,"{\'NORM\': 1}"\n',
            ]
            with open(os.path.join(tmpdir, "ptbxl_database.csv"), "w") as f:
                f.writelines(lines)

            dataset = PTBXLDataset(root=tmpdir, dev=True)
            frame = dataset.load_data().compute()

            self.assertEqual(len(frame), 2)
            for column in (
                "patient_id",
                "event_type",
                "ptbxl/ecg_id",
                "ptbxl/record_path",
                "ptbxl/scp_codes",
            ):
                self.assertIn(column, frame.columns)
            self.assertEqual(str(frame.iloc[0]["patient_id"]), "100")
            self.assertEqual(frame.iloc[0]["event_type"], "ptbxl")


if __name__ == "__main__":
    unittest.main()
63 changes: 63 additions & 0 deletions tests/core/test_ptbxl_mi_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os
import tempfile
import unittest
from unittest.mock import patch

import numpy as np
import pandas as pd
import polars as pl

from pyhealth.tasks.ptbxl_mi_classification import PTBXLMIClassificationTask
from pyhealth.data import Patient


class TestPTBXLTask(unittest.TestCase):
    @patch.object(PTBXLMIClassificationTask, "_load_ecg_signal")
    def test_mi_label_extraction(self, mock_load_signal):
        """An MI SCP code should yield label 1; other codes yield label 0."""
        mock_load_signal.return_value = np.zeros((12, 1000), dtype=np.float32)

        with tempfile.TemporaryDirectory() as tmpdir:
            # minimal synthetic SCP mapping
            pd.DataFrame(
                {"diagnostic_class": ["MI", "NORM"]},
                index=["IMI", "NORM"],
            ).to_csv(os.path.join(tmpdir, "scp_statements.csv"))

            events = pd.DataFrame(
                {
                    "patient_id": ["1", "1"],
                    "event_type": ["ptbxl", "ptbxl"],
                    "timestamp": [None, None],
                    "ptbxl/ecg_id": [100, 101],
                    "ptbxl/record_path": [
                        "records100/00000/00001_lr",
                        "records100/00000/00002_lr",
                    ],
                    "ptbxl/scp_codes": [
                        "{'IMI': 1}",
                        "{'NORM': 1}",
                    ],
                }
            )
            patient = Patient(patient_id="1", data_source=pl.from_pandas(events))

            task = PTBXLMIClassificationTask(root=tmpdir)
            samples = task(patient)

            self.assertEqual(len(samples), 2)
            self.assertEqual(samples[0]["label"], 1)
            self.assertEqual(samples[1]["label"], 0)
            self.assertEqual(np.array(samples[0]["signal"]).shape, (12, 1000))


if __name__ == "__main__":
    unittest.main()