diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..33cacc504 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -245,3 +245,4 @@ Available Datasets datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils + pyhealth.datasets.ptbxl diff --git a/docs/api/datasets/pyhealth.datasets.ptbxl.rst b/docs/api/datasets/pyhealth.datasets.ptbxl.rst new file mode 100644 index 000000000..dc43ce9ee --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.ptbxl.rst @@ -0,0 +1,7 @@ +pyhealth.datasets.ptbxl +======================= + +.. autoclass:: pyhealth.datasets.PTBXLDataset + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..f63df4596 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -229,3 +229,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + PTB-XL MI Classification \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst b/docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst new file mode 100644 index 000000000..a4495c3be --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.ptbxl_mi_classification +====================================== + +.. autoclass:: pyhealth.tasks.PTBXLMIClassificationTask + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/ptbxl_mi_classification_cnn.py b/examples/ptbxl_mi_classification_cnn.py new file mode 100644 index 000000000..76fa1b680 --- /dev/null +++ b/examples/ptbxl_mi_classification_cnn.py @@ -0,0 +1,21 @@ +from pyhealth.datasets import PTBXLDataset +from pyhealth.tasks import PTBXLMIClassificationTask + + +def main(): + root = "/Users/zaidalkhatib/Downloads/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3" + + dataset = PTBXLDataset( + root=root, + dev=True, + ) + + task = PTBXLMIClassificationTask(root=root) + task_dataset = dataset.set_task(task) + + print(task_dataset[0]) + print(f"Number of samples: {len(task_dataset)}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..e00bb968c 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -90,3 +90,4 @@ def __init__(self, *args, **kwargs): save_processors, ) from .collate import collate_temporal +from .ptbxl import PTBXLDataset \ No newline at end of file diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py new file mode 100644 index 000000000..fe0b167b0 --- /dev/null +++ b/pyhealth/datasets/ptbxl.py @@ -0,0 +1,46 @@ +import os +from typing import Optional + +import dask.dataframe as dd +import pandas as pd + +from pyhealth.datasets import BaseDataset + + +class PTBXLDataset(BaseDataset): + """PTB-XL ECG dataset represented as an event table.""" + + def __init__( + self, + root: str, + dataset_name: Optional[str] = "PTBXL", + dev: bool = False, + cache_dir: Optional[str] = None, + num_workers: int = 1, + ): + super().__init__( + root=root, + tables=["ptbxl"], + dataset_name=dataset_name, + cache_dir=cache_dir, + num_workers=num_workers, + dev=dev, + ) + + def load_data(self) -> dd.DataFrame: + metadata_path = os.path.join(self.root, "ptbxl_database.csv") + df = pd.read_csv(metadata_path) + + event_df = pd.DataFrame( + { + "patient_id": df["patient_id"].astype(str), + "event_type": "ptbxl", + "timestamp": pd.NaT, + "ptbxl/ecg_id": df["ecg_id"], + "ptbxl/filename_lr": df["filename_lr"], + "ptbxl/filename_hr": df["filename_hr"], + "ptbxl/scp_codes": df["scp_codes"], + } + ) + + return dd.from_pandas(event_df, npartitions=1) \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..0e9b70b15 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .ptbxl_mi_classification import PTBXLMIClassificationTask \ No newline at end of file diff --git a/pyhealth/tasks/ptbxl_mi_classification.py b/pyhealth/tasks/ptbxl_mi_classification.py new file mode 100644 index 000000000..83afc0c39 --- /dev/null +++ b/pyhealth/tasks/ptbxl_mi_classification.py @@ -0,0 +1,72 @@ +import ast +import os +import pickle +from typing import Dict, List + +import numpy as np +import pandas as pd + +from pyhealth.tasks import BaseTask + + +class PTBXLMIClassificationTask(BaseTask): + task_name = "ptbxl_mi_classification" + input_schema = { + "signal": "tensor", + } + output_schema = { + "label": "binary", + } + + def __init__(self, root: str): + self.root = root + + scp_path = os.path.join(self.root, "scp_statements.csv") + scp_df = pd.read_csv(scp_path, index_col=0) + self.mi_codes = set( + scp_df[scp_df["diagnostic_class"] == "MI"].index.astype(str).tolist() + ) + + def __call__(self, patient) -> List[Dict]: + samples = [] + + rows = patient.data_source.to_dicts() + + for idx, row in enumerate(rows): + raw_label = row["ptbxl/scp_codes"] + + try: + scp_codes = ( + ast.literal_eval(raw_label) + if isinstance(raw_label, str) + else raw_label + ) + except (ValueError, SyntaxError): + scp_codes = {} + + label = 1 if any(code in self.mi_codes for code in scp_codes.keys()) else 0 + + signal = np.zeros((12, 1000), dtype=np.float32) + + visit_id = str(row["ptbxl/ecg_id"]) + cache_dir = os.path.join("/tmp", "ptbxl_task_cache") + os.makedirs(cache_dir, exist_ok=True) + save_file_path = os.path.join( + cache_dir, f"{patient.patient_id}-MI-{visit_id}.pkl" + ) + + with open(save_file_path, "wb") as f: + pickle.dump({"signal": signal, "label": label}, f) + + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": visit_id, + "record_id": idx + 1, + "signal": signal.tolist(), + "label": label, + "epoch_path": save_file_path, + } + ) + + return samples \ No newline at end of file diff --git a/tests/core/test_ptbxl_dataset.py b/tests/core/test_ptbxl_dataset.py new file mode 100644 index 000000000..5fbe0d43b --- /dev/null +++ b/tests/core/test_ptbxl_dataset.py @@ -0,0 +1,36 @@ +import os +import tempfile +import unittest + +from pyhealth.datasets import PTBXLDataset + + +class TestPTBXLDataset(unittest.TestCase): + def test_load_data_dev_mode(self): + with tempfile.TemporaryDirectory() as tmpdir: + csv_path = os.path.join(tmpdir, "ptbxl_database.csv") + + with open(csv_path, "w") as f: + f.write("ecg_id,patient_id,filename_lr,filename_hr,scp_codes\n") + f.write('1,100,records100/00000/00001_lr,records500/00000/00001_hr,"{\'MI\': 1}"\n') + f.write('2,101,records100/00000/00002_lr,records500/00000/00002_hr,"{\'NORM\': 1}"\n') + + dataset = PTBXLDataset( + root=tmpdir, + dev=True, + ) + + df = dataset.load_data().compute() + + self.assertEqual(len(df), 2) + self.assertIn("patient_id", df.columns) + self.assertIn("event_type", df.columns) + self.assertIn("ptbxl/ecg_id", df.columns) + self.assertIn("ptbxl/filename_lr", df.columns) + self.assertIn("ptbxl/scp_codes", df.columns) + self.assertEqual(str(df.iloc[0]["patient_id"]), "100") + self.assertEqual(df.iloc[0]["event_type"], "ptbxl") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/core/test_ptbxl_mi_classification.py b/tests/core/test_ptbxl_mi_classification.py new file mode 100644 index 000000000..b72901f43 --- /dev/null +++ b/tests/core/test_ptbxl_mi_classification.py @@ -0,0 +1,42 @@ +import unittest +import pandas as pd +import polars as pl + +from pyhealth.tasks.ptbxl_mi_classification import PTBXLMIClassificationTask +from pyhealth.data import Patient + + +class TestPTBXLTask(unittest.TestCase): + + def test_mi_label_extraction(self): + # synthetic patient data + df = pd.DataFrame({ + "patient_id": ["1", "1"], + "event_type": ["ptbxl", "ptbxl"], + "timestamp": [None, None], + "ptbxl/ecg_id": [100, 101], + "ptbxl/filename_lr": ["a", "b"], + "ptbxl/filename_hr": ["a", "b"], + "ptbxl/scp_codes": [ + "{'MI': 1}", # should be label = 1 + "{'NORM': 1}" # should be label = 0 + ], + }) + + pl_df = pl.from_pandas(df) + + patient = Patient( + patient_id="1", + data_source=pl_df + ) + + task = PTBXLMIClassificationTask() + samples = task(patient) + + self.assertEqual(len(samples), 2) + self.assertEqual(samples[0]["label"], 1) + self.assertEqual(samples[1]["label"], 0) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file