Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,4 @@ Available Datasets
datasets/pyhealth.datasets.TCGAPRADDataset
datasets/pyhealth.datasets.splitter
datasets/pyhealth.datasets.utils
datasets/pyhealth.datasets.ptbxl
7 changes: 7 additions & 0 deletions docs/api/datasets/pyhealth.datasets.ptbxl.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.datasets.ptbxl
=======================

.. autoclass:: pyhealth.datasets.PTBXLDataset
    :members:
    :undoc-members:
    :show-inheritance:
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,4 @@ Available Tasks
Mutation Pathogenicity (COSMIC) <tasks/pyhealth.tasks.MutationPathogenicityPrediction>
Cancer Survival Prediction (TCGA) <tasks/pyhealth.tasks.CancerSurvivalPrediction>
Cancer Mutation Burden (TCGA) <tasks/pyhealth.tasks.CancerMutationBurden>
PTB-XL MI Classification <tasks/pyhealth.tasks.ptbxl_mi_classification>
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.ptbxl_mi_classification
======================================

.. autoclass:: pyhealth.tasks.PTBXLMIClassificationTask
    :members:
    :undoc-members:
    :show-inheritance:
21 changes: 21 additions & 0 deletions examples/ptbxl_mi_classification_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pyhealth.datasets import PTBXLDataset
from pyhealth.tasks import PTBXLMIClassificationTask


def main(root=None):
    """Build the PTB-XL dataset in dev mode and print one MI-task sample.

    Args:
        root: Directory of the extracted PTB-XL release; it must contain
            ``ptbxl_database.csv`` and ``scp_statements.csv``. When omitted,
            the path is taken from the ``--root`` command-line option, which
            itself defaults to the ``PTBXL_ROOT`` environment variable.
    """
    # Local imports keep this fix self-contained within the example script.
    import argparse
    import os

    if root is None:
        parser = argparse.ArgumentParser(
            description="PTB-XL myocardial infarction classification example."
        )
        parser.add_argument(
            "--root",
            default=os.environ.get("PTBXL_ROOT"),
            help="Path to the extracted PTB-XL dataset directory "
            "(defaults to the PTBXL_ROOT environment variable).",
        )
        root = parser.parse_args().root
        if not root:
            # Fail with a usage message instead of silently relying on a
            # path that only exists on one developer's machine.
            parser.error("pass --root or set the PTBXL_ROOT environment variable")

    dataset = PTBXLDataset(
        root=root,
        dev=True,
    )

    task = PTBXLMIClassificationTask(root=root)
    task_dataset = dataset.set_task(task)

    print(task_dataset[0])
    print(f"Number of samples: {len(task_dataset)}")


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,4 @@ def __init__(self, *args, **kwargs):
save_processors,
)
from .collate import collate_temporal
from .ptbxl import PTBXLDataset
46 changes: 46 additions & 0 deletions pyhealth/datasets/ptbxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
from typing import Optional

import dask.dataframe as dd
import pandas as pd

from pyhealth.datasets import BaseDataset


class PTBXLDataset(BaseDataset):
    """PTB-XL ECG dataset represented as an event table.

    Every row of ``ptbxl_database.csv`` under ``root`` becomes one event of
    type ``"ptbxl"`` carrying the record id, the low- and high-resolution
    waveform file names, and the raw SCP code annotation string.

    Args:
        root: Directory containing ``ptbxl_database.csv``.
        dataset_name: Name forwarded to ``BaseDataset``.
        dev: Development-mode flag forwarded to ``BaseDataset``.
        cache_dir: Cache directory forwarded to ``BaseDataset``.
        num_workers: Worker count forwarded to ``BaseDataset``.
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = "PTBXL",
        dev: bool = False,
        cache_dir: Optional[str] = None,
        num_workers: int = 1,
    ):
        super().__init__(
            root=root,
            tables=["ptbxl"],
            dataset_name=dataset_name,
            cache_dir=cache_dir,
            num_workers=num_workers,
            dev=dev,
        )

    @staticmethod
    def _format_patient_id(value) -> str:
        """Render a raw patient id as a string without a spurious ``.0``."""
        # pandas parses an id column as float when it is written like
        # ``15709.0`` (or contains gaps); ``str`` alone would keep the ``.0``.
        if isinstance(value, float) and value.is_integer():
            return str(int(value))
        return str(value)

    def load_data(self) -> dd.DataFrame:
        """Read the PTB-XL metadata CSV and return it as an event table.

        Returns:
            A single-partition dask DataFrame with the columns
            ``patient_id``, ``event_type``, ``timestamp``, ``ptbxl/ecg_id``,
            ``ptbxl/filename_lr``, ``ptbxl/filename_hr`` and
            ``ptbxl/scp_codes``.

        Raises:
            FileNotFoundError: If ``ptbxl_database.csv`` is missing in ``root``.
            KeyError: If a required column is missing from the CSV.
        """
        metadata_path = os.path.join(self.root, "ptbxl_database.csv")
        df = pd.read_csv(metadata_path)

        event_df = pd.DataFrame(
            {
                "patient_id": df["patient_id"].map(self._format_patient_id),
                "event_type": "ptbxl",
                # Events are emitted without a timestamp.
                "timestamp": pd.NaT,
                "ptbxl/ecg_id": df["ecg_id"],
                "ptbxl/filename_lr": df["filename_lr"],
                "ptbxl/filename_hr": df["filename_hr"],
                "ptbxl/scp_codes": df["scp_codes"],
            }
        )

        return dd.from_pandas(event_df, npartitions=1)
1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,4 @@
VariantClassificationClinVar,
)
from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task
from .ptbxl_mi_classification import PTBXLMIClassificationTask
72 changes: 72 additions & 0 deletions pyhealth/tasks/ptbxl_mi_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import ast
import os
import pickle
import tempfile
from typing import Dict, List, Optional

import numpy as np
import pandas as pd

from pyhealth.tasks import BaseTask


class PTBXLMIClassificationTask(BaseTask):
    """Binary myocardial-infarction (MI) labelling task for PTB-XL records.

    A record is labelled ``1`` when any of its SCP codes is listed in
    ``scp_statements.csv`` with ``diagnostic_class == "MI"``, otherwise ``0``.

    NOTE(review): the emitted ``signal`` is currently an all-zero
    ``(12, 1000)`` placeholder; the waveform files referenced by the dataset
    are not read by this task yet.

    Args:
        root: Directory containing ``scp_statements.csv``.
        cache_dir: Directory where one pickle per record is written. Defaults
            to ``<system temp dir>/ptbxl_task_cache``.
    """

    task_name = "ptbxl_mi_classification"
    input_schema = {
        "signal": "tensor",
    }
    output_schema = {
        "label": "binary",
    }

    def __init__(self, root: str, cache_dir: Optional[str] = None):
        # Let the base class perform whatever set-up it defines.
        super().__init__()
        self.root = root
        # Use the platform temp directory rather than a hard-coded "/tmp",
        # which does not exist on every operating system.
        self.cache_dir = (
            cache_dir
            if cache_dir is not None
            else os.path.join(tempfile.gettempdir(), "ptbxl_task_cache")
        )

        scp_path = os.path.join(self.root, "scp_statements.csv")
        scp_df = pd.read_csv(scp_path, index_col=0)
        # The CSV index holds the SCP code; keep those in the MI class.
        self.mi_codes = set(
            scp_df[scp_df["diagnostic_class"] == "MI"].index.astype(str).tolist()
        )

    @staticmethod
    def _parse_scp_codes(raw_label) -> Dict:
        """Return the SCP-code mapping for a record, or ``{}`` if unusable."""
        if isinstance(raw_label, str):
            try:
                raw_label = ast.literal_eval(raw_label)
            except (ValueError, SyntaxError, TypeError):
                return {}
        # Null values or literals that are not mappings carry no codes.
        return raw_label if isinstance(raw_label, dict) else {}

    def __call__(self, patient) -> List[Dict]:
        """Create one sample per PTB-XL record of ``patient``.

        Args:
            patient: Object exposing ``patient_id`` and a ``data_source``
                whose ``to_dicts()`` yields rows with ``ptbxl/scp_codes`` and
                ``ptbxl/ecg_id`` keys.

        Returns:
            A list of sample dicts with the keys ``patient_id``,
            ``visit_id``, ``record_id`` (1-based row position), ``signal``,
            ``label`` and ``epoch_path`` (path of the pickle written for the
            record).
        """
        # Create the cache directory once instead of once per record.
        os.makedirs(self.cache_dir, exist_ok=True)

        samples = []
        rows = patient.data_source.to_dicts()

        for record_id, row in enumerate(rows, start=1):
            scp_codes = self._parse_scp_codes(row["ptbxl/scp_codes"])
            label = int(any(code in self.mi_codes for code in scp_codes))

            # NOTE(review): placeholder signal; the real ECG is not loaded.
            signal = np.zeros((12, 1000), dtype=np.float32)

            visit_id = str(row["ptbxl/ecg_id"])
            save_file_path = os.path.join(
                self.cache_dir, f"{patient.patient_id}-MI-{visit_id}.pkl"
            )

            with open(save_file_path, "wb") as f:
                pickle.dump({"signal": signal, "label": label}, f)

            samples.append(
                {
                    "patient_id": patient.patient_id,
                    "visit_id": visit_id,
                    "record_id": record_id,
                    "signal": signal.tolist(),
                    "label": label,
                    "epoch_path": save_file_path,
                }
            )

        return samples
36 changes: 36 additions & 0 deletions tests/core/test_ptbxl_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
import tempfile
import unittest

from pyhealth.datasets import PTBXLDataset


class TestPTBXLDataset(unittest.TestCase):
    """Checks that PTBXLDataset turns the metadata CSV into an event table."""

    def test_load_data_dev_mode(self):
        """Two metadata rows must yield two events with the expected columns."""
        expected_columns = (
            "patient_id",
            "event_type",
            "ptbxl/ecg_id",
            "ptbxl/filename_lr",
            "ptbxl/scp_codes",
        )

        with tempfile.TemporaryDirectory() as workdir:
            metadata_file = os.path.join(workdir, "ptbxl_database.csv")

            # Minimal stand-in for the PTB-XL metadata file: header + 2 rows.
            with open(metadata_file, "w") as handle:
                handle.write("ecg_id,patient_id,filename_lr,filename_hr,scp_codes\n")
                handle.write('1,100,records100/00000/00001_lr,records500/00000/00001_hr,"{\'MI\': 1}"\n')
                handle.write('2,101,records100/00000/00002_lr,records500/00000/00002_hr,"{\'NORM\': 1}"\n')

            ptbxl = PTBXLDataset(
                root=workdir,
                dev=True,
            )

            events = ptbxl.load_data().compute()

            self.assertEqual(len(events), 2)
            for column in expected_columns:
                self.assertIn(column, events.columns)

            first_event = events.iloc[0]
            self.assertEqual(str(first_event["patient_id"]), "100")
            self.assertEqual(first_event["event_type"], "ptbxl")


if __name__ == "__main__":
    unittest.main()
42 changes: 42 additions & 0 deletions tests/core/test_ptbxl_mi_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import unittest
import pandas as pd
import polars as pl

from pyhealth.tasks.ptbxl_mi_classification import PTBXLMIClassificationTask
from pyhealth.data import Patient


class TestPTBXLTask(unittest.TestCase):
    """Checks the MI label derived by PTBXLMIClassificationTask."""

    def test_mi_label_extraction(self):
        """A record with an MI-class code gets label 1, any other gets 0."""
        # Local imports: this module's import block does not provide them.
        import os
        import tempfile

        # synthetic patient data
        df = pd.DataFrame({
            "patient_id": ["1", "1"],
            "event_type": ["ptbxl", "ptbxl"],
            "timestamp": [None, None],
            "ptbxl/ecg_id": [100, 101],
            "ptbxl/filename_lr": ["a", "b"],
            "ptbxl/filename_hr": ["a", "b"],
            "ptbxl/scp_codes": [
                "{'MI': 1}",  # should be label = 1
                "{'NORM': 1}"  # should be label = 0
            ],
        })

        patient = Patient(
            patient_id="1",
            data_source=pl.from_pandas(df),
        )

        with tempfile.TemporaryDirectory() as root:
            # The task requires ``root`` and reads scp_statements.csv from it
            # (code in the index column, class in ``diagnostic_class``), so
            # provide a minimal fixture instead of calling it without args.
            pd.DataFrame(
                {"diagnostic_class": ["MI", "NORM"]},
                index=["MI", "NORM"],
            ).to_csv(os.path.join(root, "scp_statements.csv"))

            task = PTBXLMIClassificationTask(root=root)
            samples = task(patient)

        self.assertEqual(len(samples), 2)
        self.assertEqual(samples[0]["label"], 1)
        self.assertEqual(samples[1]["label"], 0)


if __name__ == "__main__":
    unittest.main()