From f8caba4e8405b56dd34b14c0469e12576bdfd83c Mon Sep 17 00:00:00 2001 From: zaidalkhatib Date: Sun, 5 Apr 2026 23:30:48 -0700 Subject: [PATCH 1/4] Add PTB-XL dataset and MI classification task --- docs/api/datasets.rst | 1 + docs/api/datasets/pyhealth.datasets.ptbxl.rst | 7 ++ docs/api/tasks.rst | 1 + ...pyhealth.tasks.ptbxl_mi_classification.rst | 7 ++ examples/ptbxl_mi_classification_cnn.py | 19 +++++ pyhealth/datasets/__init__.py | 1 + pyhealth/datasets/ptbxl.py | 51 ++++++++++++++ pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/ptbxl_mi_classification.py | 69 +++++++++++++++++++ tests/core/test_ptbxl_dataset.py | 36 ++++++++++ tests/core/test_ptbxl_mi_classification.py | 42 +++++++++++ 11 files changed, 235 insertions(+) create mode 100644 docs/api/datasets/pyhealth.datasets.ptbxl.rst create mode 100644 docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst create mode 100644 examples/ptbxl_mi_classification_cnn.py create mode 100644 pyhealth/datasets/ptbxl.py create mode 100644 pyhealth/tasks/ptbxl_mi_classification.py create mode 100644 tests/core/test_ptbxl_dataset.py create mode 100644 tests/core/test_ptbxl_mi_classification.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..33cacc504 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -245,3 +245,4 @@ Available Datasets datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils + pyhealth.datasets.ptbxl diff --git a/docs/api/datasets/pyhealth.datasets.ptbxl.rst b/docs/api/datasets/pyhealth.datasets.ptbxl.rst new file mode 100644 index 000000000..dc43ce9ee --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.ptbxl.rst @@ -0,0 +1,7 @@ +pyhealth.datasets.ptbxl +======================= + +.. autoclass:: pyhealth.datasets.PTBXLDataset + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..f63df4596 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -229,3 +229,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + PTB-XL MI Classification \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst b/docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst new file mode 100644 index 000000000..a4495c3be --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.ptbxl_mi_classification +====================================== + +.. autoclass:: pyhealth.tasks.PTBXLMIClassificationTask + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/ptbxl_mi_classification_cnn.py b/examples/ptbxl_mi_classification_cnn.py new file mode 100644 index 000000000..f6962c4da --- /dev/null +++ b/examples/ptbxl_mi_classification_cnn.py @@ -0,0 +1,19 @@ +from pyhealth.datasets import PTBXLDataset +from pyhealth.tasks import PTBXLMIClassificationTask + + +def main(): + dataset = PTBXLDataset( + root="/Users/zaidalkhatib/Downloads/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3", + dev=True, + ) + + task = PTBXLMIClassificationTask() + task_dataset = dataset.set_task(task) + + print(task_dataset[0]) + print(f"Number of samples: {len(task_dataset)}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..e00bb968c 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -90,3 +90,4 @@ def __init__(self, *args, **kwargs): save_processors, ) from .collate import collate_temporal +from .ptbxl import PTBXLDataset \ No newline at end of file diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py new file mode 100644 index 000000000..753a8d4e0 --- /dev/null +++ b/pyhealth/datasets/ptbxl.py @@ -0,0 +1,51 @@ +import ast +import os +from typing import Optional + +import dask.dataframe as dd +import pandas as pd + +from pyhealth.datasets import BaseDataset + + +class PTBXLDataset(BaseDataset): + """PTB-XL ECG dataset represented as an event table.""" + + def __init__( + self, + root: str, + dataset_name: Optional[str] = "PTBXL", + dev: bool = False, + cache_dir: Optional[str] = None, + num_workers: int = 1, + ): + super().__init__( + root=root, + tables=["ptbxl"], + dataset_name=dataset_name, + cache_dir=cache_dir, + num_workers=num_workers, + dev=dev, + ) + + def load_data(self) -> dd.DataFrame: + metadata_path = os.path.join(self.root, "ptbxl_database.csv") + df = pd.read_csv(metadata_path) + + if self.dev: + df = df.head(10) + + # Keep only the fields we need for the task + event_df = pd.DataFrame( + { + "patient_id": df["patient_id"].astype(str), + "event_type": "ptbxl", + "timestamp": pd.NaT, + "ptbxl/ecg_id": df["ecg_id"], + "ptbxl/filename_lr": df["filename_lr"], + "ptbxl/filename_hr": df["filename_hr"], + "ptbxl/scp_codes": df["scp_codes"], + } + ) + + return dd.from_pandas(event_df, npartitions=1) \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..0e9b70b15 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .ptbxl_mi_classification import PTBXLMIClassificationTask \ No newline at end of file diff --git a/pyhealth/tasks/ptbxl_mi_classification.py b/pyhealth/tasks/ptbxl_mi_classification.py new file mode 100644 index 000000000..f49f0e69a --- /dev/null +++ b/pyhealth/tasks/ptbxl_mi_classification.py @@ -0,0 +1,69 @@ +import ast +import os +import pickle +from typing import Dict, List + +import numpy as np + +from pyhealth.tasks import BaseTask + + +class PTBXLMIClassificationTask(BaseTask): + task_name = "ptbxl_mi_classification" + input_schema = { + "signal": "timeseries", + } + output_schema = { + "label": "binary", + } + + def __call__(self, patient) -> List[Dict]: + samples = [] + + patient_df = patient.data_source + rows = patient_df.to_dicts() + + for idx, row in enumerate(rows): + raw_label = row["ptbxl/scp_codes"] + + try: + scp_codes = ( + ast.literal_eval(raw_label) + if isinstance(raw_label, str) + else raw_label + ) + except (ValueError, SyntaxError): + scp_codes = {} + + label = 1 if "MI" in scp_codes else 0 + + signal = np.zeros((12, 1000), dtype=np.float32) + + visit_id = str(row["ptbxl/ecg_id"]) + cache_dir = os.path.join("/tmp", "ptbxl_task_cache") + os.makedirs(cache_dir, exist_ok=True) + save_file_path = os.path.join( + cache_dir, f"{patient.patient_id}-MI-{visit_id}.pkl" + ) + + with open(save_file_path, "wb") as f: + pickle.dump( + { + "signal": signal, + "label": label, + }, + f, + ) + + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": visit_id, + "record_id": idx + 1, + "signal": signal.tolist(), + "label": label, + "epoch_path": save_file_path, + } + ) + + return samples \ No newline at end of file diff --git a/tests/core/test_ptbxl_dataset.py b/tests/core/test_ptbxl_dataset.py new file mode 100644 index 000000000..5fbe0d43b --- /dev/null +++ b/tests/core/test_ptbxl_dataset.py @@ -0,0 +1,36 @@ +import os +import tempfile +import unittest + +from pyhealth.datasets import PTBXLDataset + + +class TestPTBXLDataset(unittest.TestCase): + def test_load_data_dev_mode(self): + with tempfile.TemporaryDirectory() as tmpdir: + csv_path = os.path.join(tmpdir, "ptbxl_database.csv") + + with open(csv_path, "w") as f: + f.write("ecg_id,patient_id,filename_lr,filename_hr,scp_codes\n") + f.write('1,100,records100/00000/00001_lr,records500/00000/00001_hr,"{\'MI\': 1}"\n') + f.write('2,101,records100/00000/00002_lr,records500/00000/00002_hr,"{\'NORM\': 1}"\n') + + dataset = PTBXLDataset( + root=tmpdir, + dev=True, + ) + + df = dataset.load_data().compute() + + self.assertEqual(len(df), 2) + self.assertIn("patient_id", df.columns) + self.assertIn("event_type", df.columns) + self.assertIn("ptbxl/ecg_id", df.columns) + self.assertIn("ptbxl/filename_lr", df.columns) + self.assertIn("ptbxl/scp_codes", df.columns) + self.assertEqual(str(df.iloc[0]["patient_id"]), "100") + self.assertEqual(df.iloc[0]["event_type"], "ptbxl") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/core/test_ptbxl_mi_classification.py b/tests/core/test_ptbxl_mi_classification.py new file mode 100644 index 000000000..b72901f43 --- /dev/null +++ b/tests/core/test_ptbxl_mi_classification.py @@ -0,0 +1,42 @@ +import unittest +import pandas as pd +import polars as pl + +from pyhealth.tasks.ptbxl_mi_classification import PTBXLMIClassificationTask +from pyhealth.data import Patient + + +class TestPTBXLTask(unittest.TestCase): + + def test_mi_label_extraction(self): + # synthetic patient data + df = pd.DataFrame({ + "patient_id": ["1", "1"], + "event_type": ["ptbxl", "ptbxl"], + "timestamp": [None, None], + "ptbxl/ecg_id": [100, 101], + "ptbxl/filename_lr": ["a", "b"], + "ptbxl/filename_hr": ["a", "b"], + "ptbxl/scp_codes": [ + "{'MI': 1}", # should be label = 1 + "{'NORM': 1}" # should be label = 0 + ], + }) + + pl_df = pl.from_pandas(df) + + patient = Patient( + patient_id="1", + data_source=pl_df + ) + + task = PTBXLMIClassificationTask() + samples = task(patient) + + self.assertEqual(len(samples), 2) + self.assertEqual(samples[0]["label"], 1) + self.assertEqual(samples[1]["label"], 0) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 64e9baee5563babbbe995c1ce45eef8a73afb534 Mon Sep 17 00:00:00 2001 From: zaidalkhatib Date: Mon, 6 Apr 2026 21:23:36 -0700 Subject: [PATCH 2/4] Add PTB-XL dataset and MI classification task --- examples/ptbxl_mi_classification_cnn.py | 6 ++++-- pyhealth/datasets/ptbxl.py | 5 ----- pyhealth/tasks/ptbxl_mi_classification.py | 25 +++++++++++++---------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/ptbxl_mi_classification_cnn.py b/examples/ptbxl_mi_classification_cnn.py index f6962c4da..76fa1b680 100644 --- a/examples/ptbxl_mi_classification_cnn.py +++ b/examples/ptbxl_mi_classification_cnn.py @@ -3,12 +3,14 @@ def main(): + root = "/Users/zaidalkhatib/Downloads/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3" + dataset = PTBXLDataset( - root="/Users/zaidalkhatib/Downloads/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3", + root=root, dev=True, ) - task = PTBXLMIClassificationTask() + task = PTBXLMIClassificationTask(root=root) task_dataset = dataset.set_task(task) print(task_dataset[0]) diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py index 753a8d4e0..fe0b167b0 100644 --- a/pyhealth/datasets/ptbxl.py +++ b/pyhealth/datasets/ptbxl.py @@ -1,4 +1,3 @@ -import ast import os from typing import Optional @@ -32,10 +31,6 @@ def load_data(self) -> dd.DataFrame: metadata_path = os.path.join(self.root, "ptbxl_database.csv") df = pd.read_csv(metadata_path) - if self.dev: - df = df.head(10) - - # Keep only the fields we need for the task event_df = pd.DataFrame( { "patient_id": df["patient_id"].astype(str), diff --git a/pyhealth/tasks/ptbxl_mi_classification.py b/pyhealth/tasks/ptbxl_mi_classification.py index f49f0e69a..83afc0c39 100644 --- a/pyhealth/tasks/ptbxl_mi_classification.py +++ b/pyhealth/tasks/ptbxl_mi_classification.py @@ -4,6 +4,7 @@ from typing import Dict, List import numpy as np +import pandas as pd from pyhealth.tasks import BaseTask @@ -11,17 +12,25 @@ class PTBXLMIClassificationTask(BaseTask): task_name = "ptbxl_mi_classification" input_schema = { - "signal": "timeseries", + "signal": "tensor", } output_schema = { "label": "binary", } + def __init__(self, root: str): + self.root = root + + scp_path = os.path.join(self.root, "scp_statements.csv") + scp_df = pd.read_csv(scp_path, index_col=0) + self.mi_codes = set( + scp_df[scp_df["diagnostic_class"] == "MI"].index.astype(str).tolist() + ) + def __call__(self, patient) -> List[Dict]: samples = [] - patient_df = patient.data_source - rows = patient_df.to_dicts() + rows = patient.data_source.to_dicts() for idx, row in enumerate(rows): raw_label = row["ptbxl/scp_codes"] @@ -35,7 +44,7 @@ def __call__(self, patient) -> List[Dict]: except (ValueError, SyntaxError): scp_codes = {} - label = 1 if "MI" in scp_codes else 0 + label = 1 if any(code in self.mi_codes for code in scp_codes.keys()) else 0 signal = np.zeros((12, 1000), dtype=np.float32) @@ -47,13 +56,7 @@ def __call__(self, patient) -> List[Dict]: ) with open(save_file_path, "wb") as f: - pickle.dump( - { - "signal": signal, - "label": label, - }, - f, - ) + pickle.dump({"signal": signal, "label": label}, f) samples.append( { From 40febc2dbd61cacf5e82e4c731819d4ecca3ebf9 Mon Sep 17 00:00:00 2001 From: zaidalkhatib Date: Sun, 12 Apr 2026 14:28:18 -0700 Subject: [PATCH 3/4] Update PTB-XL tests for waveform-loading task --- examples/ptbxl_mi_classification_cnn.py | 7 +- pyhealth/datasets/ptbxl.py | 7 +- pyhealth/tasks/ptbxl_mi_classification.py | 49 ++++++++++---- tests/core/test_ptbxl_dataset.py | 2 +- tests/core/test_ptbxl_mi_classification.py | 79 ++++++++++++++-------- 5 files changed, 98 insertions(+), 46 deletions(-) diff --git a/examples/ptbxl_mi_classification_cnn.py b/examples/ptbxl_mi_classification_cnn.py index 76fa1b680..1e4ff000b 100644 --- a/examples/ptbxl_mi_classification_cnn.py +++ b/examples/ptbxl_mi_classification_cnn.py @@ -8,9 +8,14 @@ def main(): dataset = PTBXLDataset( root=root, dev=True, + use_high_resolution=False, # False -> records100, True -> records500 ) - task = PTBXLMIClassificationTask(root=root) + task = PTBXLMIClassificationTask( + root=root, + signal_length=1000, # 10 seconds at 100 Hz + normalize=True, + ) task_dataset = dataset.set_task(task) print(task_dataset[0]) diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py index fe0b167b0..00102d948 100644 --- a/pyhealth/datasets/ptbxl.py +++ b/pyhealth/datasets/ptbxl.py @@ -17,7 +17,9 @@ def __init__( dev: bool = False, cache_dir: Optional[str] = None, num_workers: int = 1, + use_high_resolution: bool = False, ): + self.use_high_resolution = use_high_resolution super().__init__( root=root, tables=["ptbxl"], @@ -31,14 +33,15 @@ def load_data(self) -> dd.DataFrame: metadata_path = os.path.join(self.root, "ptbxl_database.csv") df = pd.read_csv(metadata_path) + record_path_col = "filename_hr" if self.use_high_resolution else "filename_lr" + event_df = pd.DataFrame( { "patient_id": df["patient_id"].astype(str), "event_type": "ptbxl", "timestamp": pd.NaT, "ptbxl/ecg_id": df["ecg_id"], - "ptbxl/filename_lr": df["filename_lr"], - "ptbxl/filename_hr": df["filename_hr"], + "ptbxl/record_path": df[record_path_col], "ptbxl/scp_codes": df["scp_codes"], } ) diff --git a/pyhealth/tasks/ptbxl_mi_classification.py b/pyhealth/tasks/ptbxl_mi_classification.py index 83afc0c39..fa3fcae6a 100644 --- a/pyhealth/tasks/ptbxl_mi_classification.py +++ b/pyhealth/tasks/ptbxl_mi_classification.py @@ -1,10 +1,10 @@ import ast import os -import pickle from typing import Dict, List import numpy as np import pandas as pd +import wfdb from pyhealth.tasks import BaseTask @@ -18,8 +18,15 @@ class PTBXLMIClassificationTask(BaseTask): "label": "binary", } - def __init__(self, root: str): + def __init__( + self, + root: str, + signal_length: int = 1000, + normalize: bool = True, + ): self.root = root + self.signal_length = signal_length + self.normalize = normalize scp_path = os.path.join(self.root, "scp_statements.csv") scp_df = pd.read_csv(scp_path, index_col=0) @@ -27,6 +34,31 @@ def __init__(self, root: str): scp_df[scp_df["diagnostic_class"] == "MI"].index.astype(str).tolist() ) + def _load_ecg_signal(self, record_rel_path: str) -> np.ndarray: + """Loads a PTB-XL WFDB record and returns shape (12, signal_length).""" + record_path = os.path.join(self.root, record_rel_path) + + # WFDB expects the record path without file extension. + signal, _ = wfdb.rdsamp(record_path) + + # rdsamp returns shape (num_samples, num_channels) + signal = signal.T.astype(np.float32) # -> (channels, time) + + if self.normalize: + mean = signal.mean(axis=1, keepdims=True) + std = signal.std(axis=1, keepdims=True) + std = np.where(std < 1e-6, 1.0, std) + signal = (signal - mean) / std + + current_len = signal.shape[1] + if current_len >= self.signal_length: + signal = signal[:, : self.signal_length] + else: + pad_width = self.signal_length - current_len + signal = np.pad(signal, ((0, 0), (0, pad_width)), mode="constant") + + return signal + def __call__(self, patient) -> List[Dict]: samples = [] @@ -34,6 +66,7 @@ def __call__(self, patient) -> List[Dict]: for idx, row in enumerate(rows): raw_label = row["ptbxl/scp_codes"] + record_rel_path = row["ptbxl/record_path"] try: scp_codes = ( @@ -45,18 +78,9 @@ def __call__(self, patient) -> List[Dict]: scp_codes = {} label = 1 if any(code in self.mi_codes for code in scp_codes.keys()) else 0 - - signal = np.zeros((12, 1000), dtype=np.float32) + signal = self._load_ecg_signal(record_rel_path) visit_id = str(row["ptbxl/ecg_id"]) - cache_dir = os.path.join("/tmp", "ptbxl_task_cache") - os.makedirs(cache_dir, exist_ok=True) - save_file_path = os.path.join( - cache_dir, f"{patient.patient_id}-MI-{visit_id}.pkl" - ) - - with open(save_file_path, "wb") as f: - pickle.dump({"signal": signal, "label": label}, f) samples.append( { @@ -65,7 +89,6 @@ def __call__(self, patient) -> List[Dict]: "record_id": idx + 1, "signal": signal.tolist(), "label": label, - "epoch_path": save_file_path, } ) diff --git a/tests/core/test_ptbxl_dataset.py b/tests/core/test_ptbxl_dataset.py index 5fbe0d43b..111599494 100644 --- a/tests/core/test_ptbxl_dataset.py +++ b/tests/core/test_ptbxl_dataset.py @@ -26,7 +26,7 @@ def test_load_data_dev_mode(self): self.assertIn("patient_id", df.columns) self.assertIn("event_type", df.columns) self.assertIn("ptbxl/ecg_id", df.columns) - self.assertIn("ptbxl/filename_lr", df.columns) + self.assertIn("ptbxl/record_path", df.columns) self.assertIn("ptbxl/scp_codes", df.columns) self.assertEqual(str(df.iloc[0]["patient_id"]), "100") self.assertEqual(df.iloc[0]["event_type"], "ptbxl") diff --git a/tests/core/test_ptbxl_mi_classification.py b/tests/core/test_ptbxl_mi_classification.py index b72901f43..dee838d19 100644 --- a/tests/core/test_ptbxl_mi_classification.py +++ b/tests/core/test_ptbxl_mi_classification.py @@ -1,4 +1,9 @@ +import os +import tempfile import unittest +from unittest.mock import patch + +import numpy as np import pandas as pd import polars as pl @@ -7,35 +12,51 @@ class TestPTBXLTask(unittest.TestCase): - - def test_mi_label_extraction(self): - # synthetic patient data - df = pd.DataFrame({ - "patient_id": ["1", "1"], - "event_type": ["ptbxl", "ptbxl"], - "timestamp": [None, None], - "ptbxl/ecg_id": [100, 101], - "ptbxl/filename_lr": ["a", "b"], - "ptbxl/filename_hr": ["a", "b"], - "ptbxl/scp_codes": [ - "{'MI': 1}", # should be label = 1 - "{'NORM': 1}" # should be label = 0 - ], - }) - - pl_df = pl.from_pandas(df) - - patient = Patient( - patient_id="1", - data_source=pl_df - ) - - task = PTBXLMIClassificationTask() - samples = task(patient) - - self.assertEqual(len(samples), 2) - self.assertEqual(samples[0]["label"], 1) - self.assertEqual(samples[1]["label"], 0) + @patch.object(PTBXLMIClassificationTask, "_load_ecg_signal") + def test_mi_label_extraction(self, mock_load_signal): + mock_load_signal.return_value = np.zeros((12, 1000), dtype=np.float32) + + with tempfile.TemporaryDirectory() as tmpdir: + scp_path = os.path.join(tmpdir, "scp_statements.csv") + + # minimal synthetic SCP mapping + scp_df = pd.DataFrame( + { + "diagnostic_class": ["MI", "NORM"], + }, + index=["IMI", "NORM"], + ) + scp_df.to_csv(scp_path) + + df = pd.DataFrame( + { + "patient_id": ["1", "1"], + "event_type": ["ptbxl", "ptbxl"], + "timestamp": [None, None], + "ptbxl/ecg_id": [100, 101], + "ptbxl/record_path": [ + "records100/00000/00001_lr", + "records100/00000/00002_lr", + ], + "ptbxl/scp_codes": [ + "{'IMI': 1}", + "{'NORM': 1}", + ], + } + ) + + patient = Patient( + patient_id="1", + data_source=pl.from_pandas(df), + ) + + task = PTBXLMIClassificationTask(root=tmpdir) + samples = task(patient) + + self.assertEqual(len(samples), 2) + self.assertEqual(samples[0]["label"], 1) + self.assertEqual(samples[1]["label"], 0) + self.assertEqual(np.array(samples[0]["signal"]).shape, (12, 1000)) if __name__ == "__main__": From 9339cb859ce3a3b0fcc0264ea7036f154a1f3a26 Mon Sep 17 00:00:00 2001 From: zaidalkhatib Date: Sun, 12 Apr 2026 14:44:27 -0700 Subject: [PATCH 4/4] fix dataset root --- examples/ptbxl_mi_classification_cnn.py | 7 ++-- pyhealth/tasks/ptbxl_mi_classification.py | 40 +++++++++++++++++++++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/examples/ptbxl_mi_classification_cnn.py b/examples/ptbxl_mi_classification_cnn.py index 1e4ff000b..28f3d0723 100644 --- a/examples/ptbxl_mi_classification_cnn.py +++ b/examples/ptbxl_mi_classification_cnn.py @@ -1,10 +1,11 @@ from pyhealth.datasets import PTBXLDataset from pyhealth.tasks import PTBXLMIClassificationTask - +import os def main(): - root = "/Users/zaidalkhatib/Downloads/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3" - + root = os.path.expanduser( + "~/Downloads/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3" + ) dataset = PTBXLDataset( root=root, dev=True, diff --git a/pyhealth/tasks/ptbxl_mi_classification.py b/pyhealth/tasks/ptbxl_mi_classification.py index fa3fcae6a..552563fa8 100644 --- a/pyhealth/tasks/ptbxl_mi_classification.py +++ b/pyhealth/tasks/ptbxl_mi_classification.py @@ -1,3 +1,10 @@ +"""PTBXL MI classification task for PyHealth. + +This module defines a task that loads PTB-XL ECG records, maps SCP +diagnostic codes to myocardial infarction (MI) labels, and returns one +binary-labeled sample per record. +""" + import ast import os from typing import Dict, List @@ -10,6 +17,17 @@ class PTBXLMIClassificationTask(BaseTask): + """Task for classifying myocardial infarction (MI) in PTB-XL ECG records. + + This task converts the PTB-XL SCP diagnostic codes into a binary MI label + and loads the corresponding ECG signal for each record. + + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): Input schema mapping signal to tensor. + output_schema (Dict[str, str]): Output schema mapping label to binary. + """ + task_name = "ptbxl_mi_classification" input_schema = { "signal": "tensor", @@ -24,6 +42,14 @@ def __init__( signal_length: int = 1000, normalize: bool = True, ): + """Initialize the PTBXL MI classification task. + + Args: + root: PTB-XL dataset root directory containing `scp_statements.csv`. + signal_length: Number of samples to use for each ECG signal. + normalize: Whether to z-score normalize each ECG channel. + """ + self.root = root self.signal_length = signal_length self.normalize = normalize @@ -60,6 +86,20 @@ def _load_ecg_signal(self, record_rel_path: str) -> np.ndarray: return signal def __call__(self, patient) -> List[Dict]: + """Generate PTB-XL MI samples from a patient record. + + Args: + patient: Patient object containing PTB-XL event data. + + Returns: + A list of sample dictionaries with keys: + - patient_id + - visit_id + - record_id + - signal + - label + """ + samples = [] rows = patient.data_source.to_dicts()