diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..7a4075c3b 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -229,3 +229,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Sleep Staging (DREAMT) diff --git a/docs/api/tasks/pyhealth.tasks.SleepStagingDREAMT.rst b/docs/api/tasks/pyhealth.tasks.SleepStagingDREAMT.rst new file mode 100644 index 000000000..4fbc2479c --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.SleepStagingDREAMT.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.SleepStagingDREAMT +======================================= + +.. autoclass:: pyhealth.tasks.sleep_staging_dreamt.SleepStagingDREAMT + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/dreamt_sleep_staging_rnn.py b/examples/dreamt_sleep_staging_rnn.py new file mode 100644 index 000000000..c3127076b --- /dev/null +++ b/examples/dreamt_sleep_staging_rnn.py @@ -0,0 +1,518 @@ +"""DREAMT Sleep Staging — Ablation Study with PyHealth RNN. + +This script runs two ablation experiments on the DREAMT +sleep staging task using ``SleepStagingDREAMT``: + +1. **Signal-subset ablation** (binary wake/sleep): + ACC-only vs BVP/HRV-only vs EDA+TEMP-only vs ALL signals. + +2. **Label-granularity ablation** (ALL signals): + 2-class (wake/sleep) vs 5-class (W/N1/N2/N3/R). + +The model is PyHealth's built-in ``RNN`` (LSTM variant), trained +using PyHealth's ``Trainer`` with patient-level data splits. + +Each 30-second epoch's raw multi-channel signal is reduced to +per-channel statistics (mean, std, min, max) to form a compact +feature vector fed to the model. + +Usage — full DREAMT run:: + + python dreamt_sleep_staging_rnn.py --root /path/to/dreamt + +Usage — synthetic demo (no dataset required):: + + python dreamt_sleep_staging_rnn.py --demo + +Metrics: F1 (macro), Accuracy, Cohen's Kappa. 
+ +Results / Findings +------------------ + +**Demo mode** (synthetic data, 6 patients, 2 training epochs): + +Results are non-meaningful and serve only to verify that the +full pipeline (epoching -> feature extraction -> PyHealth RNN +training -> evaluation) runs end-to-end without error. Expected +output is near-random performance. + +**Paper reference** (Wang et al. CHIL 2024, Table 2): + +The original paper reports wake/sleep (2-class) detection on +80 participants (after artifact QC) using LightGBM / GPBoost +with hand-crafted features from all 8 E4 channels and 5-fold +participant-level CV. Key results from Table 2: + +- Baseline LightGBM: F1 = 0.777, Acc = 0.816, Kappa = 0.605 +- Best (GPBoost + Apnea Severity RE + LSTM post-processing): + F1 = 0.823, Acc = 0.857, Kappa = 0.694 + +The paper does not report per-signal-subset ablations; those +are original to this script. This ablation also uses a simpler +feature set (4 summary stats per channel) and a neural model +(LSTM via PyHealth RNN) with a 70/10/20 patient-level split, +so results are expected to differ from the paper. + +Reference: + Wang et al. "Addressing wearable sleep tracking inequity: + a new dataset and novel methods for a population with sleep + disorders." CHIL 2024, PMLR 248:380-396. 
+""" + +import argparse +import os +import tempfile +import warnings +from typing import Any, Dict, List, Optional + +import numpy as np + +from pyhealth.datasets import ( + create_sample_dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import RNN +from pyhealth.tasks.sleep_staging_dreamt import ( + ALL_SIGNAL_COLUMNS, + SleepStagingDREAMT, +) +from pyhealth.trainer import Trainer + +warnings.filterwarnings("ignore", category=FutureWarning) + +EPOCH_LEN: int = 30 * 64 # 1920 samples per 30-s epoch at 64 Hz + +SIGNAL_SUBSETS: Dict[str, List[str]] = { + "ACC": ["ACC_X", "ACC_Y", "ACC_Z"], + "BVP_HRV": ["BVP", "HR", "IBI"], + "EDA_TEMP": ["EDA", "TEMP"], + "ALL": list(ALL_SIGNAL_COLUMNS), +} + + +def _epoch_features(signal: np.ndarray) -> List[float]: + """Convert a raw epoch signal to a compact feature vector. + + Computes mean, std, min, and max per channel. + + Args: + signal: Array of shape ``(n_channels, epoch_len)``. + + Returns: + Flat list of length ``4 * n_channels``. + """ + feats: List[float] = [] + for ch in range(signal.shape[0]): + s = signal[ch].astype(np.float64) + feats.extend([ + float(np.mean(s)), + float(np.std(s)), + float(np.min(s)), + float(np.max(s)), + ]) + return feats + + +def _prepare_samples( + raw_samples: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Transform task-output samples into feature-vector samples. + + Each raw sample has ``signal`` (n_channels, epoch_len). + This replaces it with a flat feature vector suitable for + PyHealth's ``TensorProcessor``. + + Args: + raw_samples: Output of ``SleepStagingDREAMT(patient)``. + + Returns: + List of dicts with ``patient_id``, ``features``, ``label``. 
+ """ + return [ + { + "patient_id": s["patient_id"], + "features": _epoch_features(s["signal"]), + "label": s["label"], + } + for s in raw_samples + ] + + +def run_config( + raw_samples: List[Dict[str, Any]], + num_classes: int, + device: str = "cpu", + epochs: int = 5, + hidden_dim: int = 64, + split_ratios: Optional[List[float]] = None, +) -> Dict[str, float]: + """Run one ablation configuration with PyHealth RNN and Trainer. + + Args: + raw_samples: Epoch samples from ``SleepStagingDREAMT``. + num_classes: Number of classification classes (2 or 5). + device: Torch device string. + epochs: Training epochs. + hidden_dim: LSTM hidden dimension. + split_ratios: Patient-level train/val/test ratios. + + Returns: + Dictionary of evaluation metric scores. + """ + if split_ratios is None: + split_ratios = [0.7, 0.1, 0.2] + + samples = _prepare_samples(raw_samples) + + dataset = create_sample_dataset( + samples=samples, + input_schema={"features": "tensor"}, + output_schema={"label": "multiclass"}, + dataset_name="dreamt", + task_name="sleep_staging", + ) + + train_ds, val_ds, test_ds = split_by_patient( + dataset, split_ratios, seed=42, + ) + + train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False) + + model = RNN( + dataset=dataset, + embedding_dim=hidden_dim, + hidden_dim=hidden_dim, + rnn_type="LSTM", + num_layers=1, + dropout=0.0, + ) + + trainer = Trainer( + model=model, + metrics=["accuracy", "f1_macro", "cohen_kappa"], + device=device, + enable_logging=False, + ) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + ) + + results = trainer.evaluate(test_loader) + return results + + +# ----------------------------------------------------------- +# Demo data generation +# ----------------------------------------------------------- + + +def _generate_demo_csv( + tmpdir: 
str, + patient_id: str, + n_epochs: int, + rng: np.random.RandomState, +) -> str: + """Create one synthetic 64 Hz CSV file. + + Args: + tmpdir: Directory to write the CSV. + patient_id: Used in the filename. + n_epochs: Number of 30-s epochs to generate. + rng: Random state for reproducibility. + + Returns: + Path to the written CSV file. + """ + import pandas as pd + + stages_pool = ["W", "N1", "N2", "N3", "R"] + rows = n_epochs * EPOCH_LEN + data = { + "TIMESTAMP": np.arange(rows) / 64.0, + "BVP": rng.randn(rows) * 50, + "IBI": np.clip(rng.rand(rows) * 0.2 + 0.7, 0, 2), + "EDA": rng.rand(rows) * 5 + 0.1, + "TEMP": rng.rand(rows) * 4 + 33, + "ACC_X": rng.randn(rows) * 10, + "ACC_Y": rng.randn(rows) * 10, + "ACC_Z": rng.randn(rows) * 10, + "HR": rng.rand(rows) * 30 + 60, + } + stage_col = [] + for i in range(n_epochs): + st = stages_pool[i % len(stages_pool)] + stage_col.extend([st] * EPOCH_LEN) + data["Sleep_Stage"] = stage_col + + csv_path = os.path.join(tmpdir, f"{patient_id}_whole_df.csv") + pd.DataFrame(data).to_csv(csv_path, index=False) + return csv_path + + +def _generate_demo_samples( + n_classes: int = 2, + signal_columns: Optional[List[str]] = None, + n_patients: int = 6, + epochs_per_patient: int = 15, + seed: int = 123, +) -> List[Dict[str, Any]]: + """Create synthetic samples for demo mode. + + Args: + n_classes: Number of label classes (2, 3, or 5). + signal_columns: Which signal columns to include. + n_patients: Number of synthetic patients. + epochs_per_patient: 30-s epochs per patient. + seed: Random seed for reproducibility. + + Returns: + List of sample dicts ready for training. 
+ """ + from types import SimpleNamespace + + rng = np.random.RandomState(seed) + all_samples: List[Dict[str, Any]] = [] + + with tempfile.TemporaryDirectory() as tmpdir: + for p in range(n_patients): + pid = f"DEMO_{p:03d}" + csv_path = _generate_demo_csv( + tmpdir, pid, epochs_per_patient, rng, + ) + + evt = SimpleNamespace(file_64hz=csv_path) + patient = SimpleNamespace( + patient_id=pid, + get_events=lambda et=None, e=evt: [e], + ) + + task = SleepStagingDREAMT( + n_classes=n_classes, + signal_columns=signal_columns, + apply_filters=False, + ) + all_samples.extend(task(patient)) + + return all_samples + + +# ----------------------------------------------------------- +# Ablation runners +# ----------------------------------------------------------- + +DEFAULT_ROOT = os.path.expanduser("~/.pyhealth/dreamt") + + +def _resolve_root(root_arg: Optional[str]) -> str: + """Find a valid DREAMT root, or exit with guidance. + + Args: + root_arg: User-supplied ``--root`` value, or None. + + Returns: + Absolute path to the DREAMT version directory. + + Raises: + SystemExit: If no valid directory is found. 
+ """ + candidates = ( + [root_arg] + if root_arg + else [ + DEFAULT_ROOT, + os.path.expanduser("~/data/dreamt"), + os.path.expanduser("~/dreamt"), + ] + ) + for path in candidates: + if path and os.path.isdir(path): + info = os.path.join(path, "participant_info.csv") + if os.path.isfile(info): + return path + for sub in sorted(os.listdir(path)): + subpath = os.path.join(path, sub) + if os.path.isdir(subpath) and os.path.isfile( + os.path.join(subpath, "participant_info.csv") + ): + return subpath + print( + "ERROR: Could not find the DREAMT dataset.\n" + "\n" + "Download from PhysioNet (credentialed access):\n" + " https://physionet.org/content/dreamt/\n" + "\n" + "Then either:\n" + f" - Extract to {DEFAULT_ROOT}/\n" + " - Or pass --root /path/to/dreamt/version/\n" + "\n" + "The directory must contain participant_info.csv\n" + "and a data_64Hz/ folder with per-participant CSVs." + ) + raise SystemExit(1) + + +def _run_ablations_real(args: argparse.Namespace) -> None: + """Run ablations on the real DREAMT dataset. + + Args: + args: Parsed command-line arguments. 
+ """ + from pyhealth.datasets import DREAMTDataset + + root = _resolve_root(args.root) + print(f"Loading DREAMT dataset from {root} ...") + dataset = DREAMTDataset(root=root) + + print("\n" + "=" * 60) + print("ABLATION 1: Signal Subset (2-class wake/sleep)") + print("=" * 60) + + for subset_name, columns in SIGNAL_SUBSETS.items(): + print(f"\n--- Signal subset: {subset_name} ---") + task = SleepStagingDREAMT( + n_classes=2, + signal_columns=columns, + ) + sample_ds = dataset.set_task(task) + raw = [sample_ds[i] for i in range(len(sample_ds))] + print(f" Total samples: {len(raw)}") + results = run_config( + raw, + num_classes=2, + epochs=args.epochs, + hidden_dim=args.hidden_dim, + device=args.device, + ) + print(f" Results: {results}") + + print("\n" + "=" * 60) + print("ABLATION 2: Label Granularity (ALL signals)") + print("=" * 60) + + for nc in [2, 5]: + print(f"\n--- {nc}-class ---") + task = SleepStagingDREAMT(n_classes=nc) + sd = dataset.set_task(task) + raw = [sd[i] for i in range(len(sd))] + print(f" Total samples: {len(raw)}") + results = run_config( + raw, + num_classes=nc, + epochs=args.epochs, + hidden_dim=args.hidden_dim, + device=args.device, + ) + print(f" Results: {results}") + + +def _run_ablations_demo(args: argparse.Namespace) -> None: + """Run ablations on synthetic demo data. + + Args: + args: Parsed command-line arguments. 
+ """ + print("=== DEMO MODE (synthetic data) ===\n") + print("Generating 6 synthetic patients (15 epochs each) ...") + + demo_epochs = min(args.epochs, 2) + + print("\n" + "=" * 60) + print("ABLATION 1: Signal Subset (2-class, demo)") + print("=" * 60) + + for subset_name, columns in SIGNAL_SUBSETS.items(): + print(f"\n--- Signal subset: {subset_name} ---") + seed = abs(hash(subset_name)) % (2**31) + raw = _generate_demo_samples( + n_classes=2, + signal_columns=columns, + n_patients=6, + seed=seed, + ) + print(f" Total samples: {len(raw)}") + results = run_config( + raw, + num_classes=2, + epochs=demo_epochs, + hidden_dim=args.hidden_dim, + device=args.device, + split_ratios=[0.5, 0.17, 0.33], + ) + print(f" Results: {results}") + + print("\n" + "=" * 60) + print("ABLATION 2: Label Granularity (demo)") + print("=" * 60) + + for nc in [2, 5]: + print(f"\n--- {nc}-class ---") + raw = _generate_demo_samples( + n_classes=nc, n_patients=6, seed=123, + ) + print(f" Total samples: {len(raw)}") + results = run_config( + raw, + num_classes=nc, + epochs=demo_epochs, + hidden_dim=args.hidden_dim, + device=args.device, + split_ratios=[0.5, 0.17, 0.33], + ) + print(f" Results: {results}") + + print("\nDemo complete.") + + +def main() -> None: + """Entry point for the DREAMT sleep staging ablation study.""" + parser = argparse.ArgumentParser( + description="DREAMT sleep staging ablation (PyHealth RNN)", + ) + parser.add_argument( + "--root", + default=None, + help=( + "Path to DREAMT dataset. " + f"Default: {DEFAULT_ROOT}" + ), + ) + parser.add_argument( + "--demo", + action="store_true", + help=( + "Run with synthetic data instead of real " + "DREAMT (no dataset download required)." 
+ ), + ) + parser.add_argument( + "--epochs", + type=int, + default=30, + help="Training epochs (default: 30)", + ) + parser.add_argument( + "--hidden_dim", + type=int, + default=64, + help="RNN hidden dimension (default: 64)", + ) + parser.add_argument( + "--device", + default="cpu", + help="Device: cpu or cuda (default: cpu)", + ) + args = parser.parse_args() + + if args.demo: + _run_ablations_demo(args) + else: + _run_ablations_real(args) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..eee3f8fef 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .sleep_staging_dreamt import SleepStagingDREAMT diff --git a/pyhealth/tasks/sleep_staging_dreamt.py b/pyhealth/tasks/sleep_staging_dreamt.py new file mode 100644 index 000000000..21477e8cd --- /dev/null +++ b/pyhealth/tasks/sleep_staging_dreamt.py @@ -0,0 +1,379 @@ +"""Sleep staging task for the DREAMT dataset. + +This task processes overnight Empatica E4 wearable recordings from the +DREAMT dataset into per-epoch raw signal windows with sleep stage labels, +suitable for deep learning models. + +The preprocessing follows the methodology described in: + + Wang et al. "Addressing wearable sleep tracking inequity: a new + dataset and novel methods for a population with sleep disorders." + CHIL 2024, PMLR 248:380-396. + +Signal-specific preprocessing: + - **ACC (ACC_X, ACC_Y, ACC_Z)**: 5th-order Butterworth bandpass + filter, 3-11 Hz (Altini & Kinnunen 2021). + - **BVP**: Chebyshev Type II bandpass filter, 0.5-20 Hz. + - **TEMP**: Winsorized (clipped) to [31, 40] degrees C. + - **EDA, HR, IBI**: No additional filtering. 
+""" + +import logging +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +from scipy.signal import butter, cheby2, filtfilt + +from pyhealth.tasks.base_task import BaseTask + +logger = logging.getLogger(__name__) + +# Label mappings per classification granularity +LABEL_MAP_5CLASS: Dict[str, int] = { + "W": 0, + "N1": 1, + "N2": 2, + "N3": 3, + "R": 4, +} + +LABEL_MAP_3CLASS: Dict[str, int] = { + "W": 0, + "N1": 1, + "N2": 1, + "N3": 1, + "R": 2, +} + +LABEL_MAP_2CLASS: Dict[str, int] = { + "W": 0, + "N1": 1, + "N2": 1, + "N3": 1, + "R": 1, +} + +LABEL_MAPS: Dict[int, Dict[str, int]] = { + 5: LABEL_MAP_5CLASS, + 3: LABEL_MAP_3CLASS, + 2: LABEL_MAP_2CLASS, +} + +ALL_SIGNAL_COLUMNS: List[str] = [ + "BVP", + "ACC_X", + "ACC_Y", + "ACC_Z", + "EDA", + "TEMP", + "HR", + "IBI", +] + +# Excluded sleep stages +EXCLUDED_STAGES = {"P", "Missing"} + + +def _apply_butterworth_bandpass( + signal: np.ndarray, + lowcut: float, + highcut: float, + fs: int, + order: int = 5, +) -> np.ndarray: + """Apply a Butterworth bandpass filter. + + Args: + signal: 1-D array of signal values. + lowcut: Lower cutoff frequency in Hz. + highcut: Upper cutoff frequency in Hz. + fs: Sampling rate in Hz. + order: Filter order. + + Returns: + Filtered signal array. + """ + nyq = 0.5 * fs + b, a = butter(order, [lowcut / nyq, highcut / nyq], btype="band") + try: + return filtfilt(b, a, signal) + except ValueError: + logger.warning( + "Butterworth filtfilt failed (likely too few samples); " + "returning raw signal." + ) + return signal + + +def _apply_chebyshev_bandpass( + signal: np.ndarray, + lowcut: float, + highcut: float, + fs: int, + order: int = 4, + rs: float = 40.0, +) -> np.ndarray: + """Apply a Chebyshev Type II bandpass filter. + + Args: + signal: 1-D array of signal values. + lowcut: Lower cutoff frequency in Hz. + highcut: Upper cutoff frequency in Hz. + fs: Sampling rate in Hz. + order: Filter order. + rs: Minimum stopband attenuation in dB. 
+ + Returns: + Filtered signal array. + """ + nyq = 0.5 * fs + b, a = cheby2(order, rs, [lowcut / nyq, highcut / nyq], btype="band") + try: + return filtfilt(b, a, signal) + except ValueError: + logger.warning( + "Chebyshev filtfilt failed (likely too few samples); " + "returning raw signal." + ) + return signal + + +def _apply_filters( + epoch: np.ndarray, + columns: List[str], + fs: int, +) -> np.ndarray: + """Apply signal-specific filters to one epoch. + + Args: + epoch: Array of shape ``(n_channels, epoch_len)``. + columns: Column names corresponding to each channel. + fs: Sampling rate in Hz. + + Returns: + Filtered epoch array of the same shape. + """ + filtered = epoch.copy() + for i, col in enumerate(columns): + if col in ("ACC_X", "ACC_Y", "ACC_Z"): + filtered[i] = _apply_butterworth_bandpass( + filtered[i], lowcut=3.0, highcut=11.0, fs=fs + ) + elif col == "BVP": + filtered[i] = _apply_chebyshev_bandpass( + filtered[i], lowcut=0.5, highcut=20.0, fs=fs + ) + elif col == "TEMP": + filtered[i] = np.clip(filtered[i], 31.0, 40.0) + # EDA, HR, IBI: no additional filtering + return filtered + + +class SleepStagingDREAMT(BaseTask): + """Sleep staging task for the DREAMT dataset. + + Transforms each participant's overnight Empatica E4 recording into + non-overlapping 30-second epochs of raw multi-channel signal data + with integer sleep stage labels. Supports 5-class, 3-class, and + 2-class classification granularities. + + Signal-specific preprocessing (from Wang et al. CHIL 2024): + + - **ACC**: 5th-order Butterworth bandpass, 3-11 Hz + - **BVP**: Chebyshev Type II bandpass, 0.5-20 Hz + - **TEMP**: Winsorized to [31, 40] C + + Epochs labeled ``"P"`` or ``"Missing"`` are excluded. + + Attributes: + task_name: ``"SleepStagingDREAMT"`` + input_schema: ``{"signal": "tensor"}`` + output_schema: ``{"label": "multiclass"}`` + + Args: + n_classes: Number of classification classes. Must be one of + ``{2, 3, 5}``. Default ``5``. 
+ + - **5-class**: W=0, N1=1, N2=2, N3=3, R=4 + - **3-class**: W=0, NREM(N1/N2/N3)=1, REM=2 + - **2-class**: W=0, Sleep(N1/N2/N3/R)=1 + + signal_columns: List of signal column names to include. + Default includes all 8 channels: ``["BVP", "ACC_X", + "ACC_Y", "ACC_Z", "EDA", "TEMP", "HR", "IBI"]``. + epoch_seconds: Duration of each epoch in seconds. + Default ``30.0``. + sampling_rate: Sampling rate in Hz. Default ``64``. + apply_filters: Whether to apply signal-specific filters. + Default ``True``. + + Examples: + >>> from pyhealth.datasets import DREAMTDataset + >>> ds = DREAMTDataset(root="/path/to/dreamt/2.1.0") + >>> task = SleepStagingDREAMT(n_classes=5) + >>> sample_ds = ds.set_task(task) + >>> sample = sample_ds.samples[0] + >>> sample.keys() + dict_keys(['patient_id', 'signal', 'label', 'epoch_index']) + >>> sample["signal"].shape # (8, 1920) for 8 channels, 30s * 64Hz + (8, 1920) + + >>> # 2-class (wake vs sleep) with ACC channels only + >>> task_binary = SleepStagingDREAMT( + ... n_classes=2, + ... signal_columns=["ACC_X", "ACC_Y", "ACC_Z"], + ... 
) + >>> sample_ds = ds.set_task(task_binary) + >>> sample_ds.samples[0]["signal"].shape + (3, 1920) + """ + + task_name: str = "SleepStagingDREAMT" + input_schema: Dict[str, str] = {"signal": "tensor"} + output_schema: Dict[str, str] = {"label": "multiclass"} + + def __init__( + self, + n_classes: int = 5, + signal_columns: Optional[List[str]] = None, + epoch_seconds: float = 30.0, + sampling_rate: int = 64, + apply_filters: bool = True, + ) -> None: + if n_classes not in {2, 3, 5}: + raise ValueError( + f"n_classes must be one of {{2, 3, 5}}, got {n_classes}" + ) + self.n_classes = n_classes + self.signal_columns = ( + list(signal_columns) if signal_columns is not None + else list(ALL_SIGNAL_COLUMNS) + ) + self.epoch_seconds = epoch_seconds + self.sampling_rate = sampling_rate + self.apply_filters = apply_filters + self.epoch_len = int(epoch_seconds * sampling_rate) + self.label_map = LABEL_MAPS[n_classes] + super().__init__() + + def __call__( + self, + patient: Any, + ) -> List[Dict[str, Any]]: + """Process one DREAMT patient into epoch samples. + + Args: + patient: A ``Patient`` object from ``DREAMTDataset``. + + Returns: + List of sample dicts, each containing: + + - ``patient_id`` (str): Patient identifier. + - ``signal`` (np.ndarray): Shape ``(n_channels, epoch_len)``, + dtype ``float32``. + - ``label`` (int): Integer sleep stage label. + - ``epoch_index`` (int): Sequential epoch position within + the recording. 
+ """ + pid: str = patient.patient_id + + try: + events = patient.get_events(event_type="dreamt_sleep") + except (TypeError, KeyError): + events = patient.get_events() + + if not events: + return [] + + event = events[0] + file_path = getattr(event, "file_64hz", None) + if file_path is None or ( + isinstance(file_path, str) and file_path.lower() == "none" + ): + return [] + + try: + df = pd.read_csv(str(file_path)) + except (FileNotFoundError, pd.errors.EmptyDataError, OSError) as exc: + logger.warning("Could not read %s: %s", file_path, exc) + return [] + + # Build case-insensitive column mapping + col_map: Dict[str, str] = { + col.lower(): col for col in df.columns + } + + # Check required columns exist + required = set(c.lower() for c in self.signal_columns) | { + "sleep_stage" + } + missing = [c for c in required if c not in col_map] + if missing: + logger.warning( + "Patient %s missing columns: %s", pid, missing + ) + return [] + + # Resolve actual column names + stage_col = col_map["sleep_stage"] + signal_col_names = [ + col_map[c.lower()] for c in self.signal_columns + ] + + # Drop excluded stages + mask = ~df[stage_col].isin(EXCLUDED_STAGES) + df = df.loc[mask].reset_index(drop=True) + + if len(df) < self.epoch_len: + return [] + + n_epochs = len(df) // self.epoch_len + samples: List[Dict[str, Any]] = [] + epoch_counter = 0 + + for i in range(n_epochs): + start = i * self.epoch_len + end = start + self.epoch_len + epoch_df = df.iloc[start:end] + + # Label from the middle of the epoch + mid = start + self.epoch_len // 2 + stage = str(df[stage_col].iloc[mid]).strip() + + if stage not in self.label_map: + continue + + label = self.label_map[stage] + + # Extract signal channels as (n_channels, epoch_len) + signal = np.stack( + [ + epoch_df[col].values.astype(np.float64) + for col in signal_col_names + ], + axis=0, + ) + + # Replace NaN with 0 + np.nan_to_num(signal, nan=0.0, copy=False) + + # Apply signal-specific filters + if self.apply_filters: + signal = 
_apply_filters( + signal, self.signal_columns, self.sampling_rate + ) + + signal = signal.astype(np.float32) + + samples.append( + { + "patient_id": pid, + "signal": signal, + "label": label, + "epoch_index": epoch_counter, + } + ) + epoch_counter += 1 + + return samples diff --git a/tests/core/test_sleep_staging_dreamt.py b/tests/core/test_sleep_staging_dreamt.py new file mode 100644 index 000000000..c2c8671ab --- /dev/null +++ b/tests/core/test_sleep_staging_dreamt.py @@ -0,0 +1,437 @@ +"""Tests for SleepStagingDREAMT task. + +All tests use in-memory fake patients with small temporary CSV files. +No real DREAMT data is required. Tests complete in milliseconds. +""" + +import os +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import pytest + +from pyhealth.tasks.sleep_staging_dreamt import ( + ALL_SIGNAL_COLUMNS, + SleepStagingDREAMT, +) + +EPOCH_LEN = 30 * 64 # 1920 samples + + +# ----------------------------------------------------------- +# Helpers +# ----------------------------------------------------------- + + +def _make_csv( + n_epochs: int, + stages: List[str], + tmpdir: str, + patient_id: str = "S001", +) -> str: + """Create a synthetic 64 Hz CSV with ``n_epochs`` epochs. + + Args: + n_epochs: Number of 30-second epochs to generate. + stages: Sleep stage labels, one per epoch (cycled). + tmpdir: Directory to write the CSV into. + patient_id: Used in the filename. + + Returns: + Absolute path to the written CSV file. 
+ """ + rng = np.random.RandomState(42) + rows = n_epochs * EPOCH_LEN + data = { + "TIMESTAMP": np.arange(rows) / 64.0, + "BVP": rng.randn(rows) * 50, + "IBI": np.clip(rng.rand(rows) * 0.2 + 0.7, 0, 2), + "EDA": rng.rand(rows) * 5 + 0.1, + "TEMP": rng.rand(rows) * 15 + 28, # range includes <31 and >40 + "ACC_X": rng.randn(rows) * 10, + "ACC_Y": rng.randn(rows) * 10, + "ACC_Z": rng.randn(rows) * 10, + "HR": rng.rand(rows) * 30 + 60, + } + + stage_col = [] + for i in range(n_epochs): + stage = stages[i % len(stages)] + stage_col.extend([stage] * EPOCH_LEN) + data["Sleep_Stage"] = stage_col + + df = pd.DataFrame(data) + path = os.path.join(tmpdir, f"{patient_id}_whole_df.csv") + df.to_csv(path, index=False) + return path + + +def _make_patient( + file_path: Optional[str], + patient_id: str = "S001", +) -> SimpleNamespace: + """Build a mock Patient mimicking DREAMTDataset. + + Args: + file_path: Path to the CSV, or None for an empty patient. + patient_id: Identifier for the mock patient. + + Returns: + SimpleNamespace matching the DREAMT patient contract. 
+ """ + event = SimpleNamespace(file_64hz=file_path) + patient = SimpleNamespace( + patient_id=patient_id, + get_events=lambda event_type=None: [event], + ) + return patient + + +# ----------------------------------------------------------- +# Tests — Initialization +# ----------------------------------------------------------- + + +class TestInit: + """Task initialization tests.""" + + def test_default_params(self): + """Default init uses 5 classes and all 8 channels.""" + task = SleepStagingDREAMT() + assert task.n_classes == 5 + assert task.signal_columns == list(ALL_SIGNAL_COLUMNS) + assert task.epoch_seconds == 30.0 + assert task.sampling_rate == 64 + assert task.apply_filters is True + assert task.epoch_len == 1920 + + def test_custom_params(self): + """Custom init parameters are stored correctly.""" + task = SleepStagingDREAMT( + n_classes=2, + signal_columns=["ACC_X", "ACC_Y"], + epoch_seconds=15.0, + sampling_rate=32, + apply_filters=False, + ) + assert task.n_classes == 2 + assert task.signal_columns == ["ACC_X", "ACC_Y"] + assert task.epoch_seconds == 15.0 + assert task.sampling_rate == 32 + assert task.apply_filters is False + assert task.epoch_len == 480 + + def test_invalid_n_classes_raises(self): + """n_classes not in {2, 3, 5} raises ValueError.""" + with pytest.raises(ValueError, match="n_classes"): + SleepStagingDREAMT(n_classes=4) + + def test_invalid_n_classes_other(self): + """n_classes=1 also raises ValueError.""" + with pytest.raises(ValueError): + SleepStagingDREAMT(n_classes=1) + + def test_class_attributes(self): + """Task has correct class-level attributes.""" + task = SleepStagingDREAMT() + assert task.task_name == "SleepStagingDREAMT" + assert task.input_schema == {"signal": "tensor"} + assert task.output_schema == {"label": "multiclass"} + + +# ----------------------------------------------------------- +# Tests — 5-class +# ----------------------------------------------------------- + + +class TestFiveClass: + """5-class sleep staging 
tests."""
+
+    def test_sample_count(self, tmp_path):
+        """Correct number of valid epochs returned."""
+        stages = ["W", "N1", "N2", "N3", "R"]
+        csv = _make_csv(5, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(n_classes=5, apply_filters=False)
+        samples = task(patient)
+        assert len(samples) == 5
+
+    def test_label_mapping(self, tmp_path):
+        """5-class maps W=0, N1=1, N2=2, N3=3, R=4."""
+        stages = ["W", "N1", "N2", "N3", "R"]
+        csv = _make_csv(5, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(n_classes=5, apply_filters=False)
+        samples = task(patient)
+        labels = [s["label"] for s in samples]
+        assert labels == [0, 1, 2, 3, 4]
+
+    def test_signal_shape(self, tmp_path):
+        """Signal shape is (n_channels, 1920)."""
+        stages = ["W", "N2"]
+        csv = _make_csv(2, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(n_classes=5, apply_filters=False)
+        samples = task(patient)
+        assert samples[0]["signal"].shape == (8, 1920)
+        assert samples[0]["signal"].dtype == np.float32
+
+    def test_patient_id(self, tmp_path):
+        """Samples carry the correct patient_id."""
+        csv = _make_csv(
+            2, ["W", "N1"], str(tmp_path), patient_id="S042"
+        )
+        patient = _make_patient(csv, patient_id="S042")
+        task = SleepStagingDREAMT(apply_filters=False)
+        samples = task(patient)
+        assert all(s["patient_id"] == "S042" for s in samples)
+
+    def test_epoch_indices_sequential(self, tmp_path):
+        """epoch_index is sequential starting from 0."""
+        stages = ["W", "N1", "N2", "N3", "R"] * 3
+        csv = _make_csv(15, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(apply_filters=False)
+        samples = task(patient)
+        indices = [s["epoch_index"] for s in samples]
+        assert indices == list(range(len(samples)))
+
+
+# -----------------------------------------------------------
+# Tests — 3-class
+# -----------------------------------------------------------
+
+
+class TestThreeClass:
+    """3-class sleep staging tests."""
+
+    def test_label_mapping(self, tmp_path):
+        """3-class maps W=0, NREM=1, REM=2."""
+        stages = ["W", "N1", "N2", "N3", "R"]
+        csv = _make_csv(5, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(n_classes=3, apply_filters=False)
+        samples = task(patient)
+        labels = [s["label"] for s in samples]
+        assert labels == [0, 1, 1, 1, 2]
+
+
+# -----------------------------------------------------------
+# Tests — 2-class
+# -----------------------------------------------------------
+
+
+class TestTwoClass:
+    """2-class (wake vs sleep) tests."""
+
+    def test_label_mapping(self, tmp_path):
+        """2-class maps W=0, all sleep=1."""
+        stages = ["W", "N1", "N2", "N3", "R"]
+        csv = _make_csv(5, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(n_classes=2, apply_filters=False)
+        samples = task(patient)
+        labels = [s["label"] for s in samples]
+        assert labels == [0, 1, 1, 1, 1]
+
+
+# -----------------------------------------------------------
+# Tests — Stage exclusion
+# -----------------------------------------------------------
+
+
+class TestStageExclusion:
+    """P and Missing stage exclusion tests."""
+
+    def test_p_stage_excluded(self, tmp_path):
+        """Epochs with P (preparation) stage are dropped."""
+        stages = ["P", "P", "W", "N1"]
+        csv = _make_csv(4, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(apply_filters=False)
+        samples = task(patient)
+        labels = [s["label"] for s in samples]
+        # Guard against a vacuous pass: all() is True for an empty list.
+        assert labels, "W/N1 epochs should survive P-stage removal"
+        assert all(lbl in {0, 1, 2, 3, 4} for lbl in labels)
+
+    def test_missing_stage_excluded(self, tmp_path):
+        """Epochs with Missing stage are dropped."""
+        stages = ["W", "Missing", "N2"]
+        csv = _make_csv(3, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(apply_filters=False)
+        samples = task(patient)
+        # Non-empty check first: all() is vacuously True for [].
+        assert samples, "W/N2 epochs should survive Missing removal"
+        assert all(s["label"] in {0, 1, 2, 3, 4} for s in samples)
+
+    def test_all_p_returns_empty(self, tmp_path):
+        """Patient with only P stages returns empty list."""
+        stages = ["P", "P", "P"]
+        csv = _make_csv(3, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(apply_filters=False)
+        samples = task(patient)
+        assert samples == []
+
+    def test_all_missing_returns_empty(self, tmp_path):
+        """Patient with only Missing stages returns empty list."""
+        stages = ["Missing", "Missing"]
+        csv = _make_csv(2, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(apply_filters=False)
+        samples = task(patient)
+        assert samples == []
+
+
+# -----------------------------------------------------------
+# Tests — Signal column subsetting
+# -----------------------------------------------------------
+
+
+class TestSignalSubset:
+    """Signal column selection tests."""
+
+    def test_acc_only(self, tmp_path):
+        """Selecting ACC channels gives shape (3, 1920)."""
+        stages = ["W", "N2"]
+        csv = _make_csv(2, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(
+            signal_columns=["ACC_X", "ACC_Y", "ACC_Z"],
+            apply_filters=False,
+        )
+        samples = task(patient)
+        assert samples[0]["signal"].shape == (3, 1920)
+
+    def test_single_channel(self, tmp_path):
+        """Single channel gives shape (1, 1920)."""
+        stages = ["W"]
+        csv = _make_csv(1, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(
+            signal_columns=["HR"],
+            apply_filters=False,
+        )
+        samples = task(patient)
+        assert samples[0]["signal"].shape == (1, 1920)
+
+    def test_bvp_temp(self, tmp_path):
+        """Custom subset of BVP + TEMP gives shape (2, 1920)."""
+        stages = ["N3"]
+        csv = _make_csv(1, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(
+            signal_columns=["BVP", "TEMP"],
+            apply_filters=False,
+        )
+        samples = task(patient)
+        assert samples[0]["signal"].shape == (2, 1920)
+
+
+# -----------------------------------------------------------
+# Tests — Filtering
+# -----------------------------------------------------------
+
+
+class TestFiltering:
+    """Signal filtering tests."""
+
+    def test_filters_run_without_error(self, tmp_path):
+        """Filters execute without raising exceptions."""
+        stages = ["W", "N2", "R"]
+        csv = _make_csv(3, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(apply_filters=True)
+        samples = task(patient)
+        assert len(samples) > 0
+        assert samples[0]["signal"].shape == (8, 1920)
+
+    def test_temp_winsorization(self, tmp_path):
+        """TEMP values are clipped to [31, 40] after filtering."""
+        stages = ["W"]
+        csv = _make_csv(1, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(
+            signal_columns=["TEMP"],
+            apply_filters=True,
+        )
+        samples = task(patient)
+        temp_signal = samples[0]["signal"][0]
+        assert np.all(temp_signal >= 31.0)
+        assert np.all(temp_signal <= 40.0)
+
+    def test_filters_disabled(self, tmp_path):
+        """With apply_filters=False, TEMP is not clipped."""
+        # Generate TEMP data that will have values outside [31, 40]
+        stages = ["W"]
+        csv = _make_csv(1, stages, str(tmp_path))
+        patient = _make_patient(csv)
+        task = SleepStagingDREAMT(
+            signal_columns=["TEMP"],
+            apply_filters=False,
+        )
+        samples = task(patient)
+        temp_signal = samples[0]["signal"][0]
+        # Synthetic data has range ~[28, 43], so some should be outside
+        has_below = np.any(temp_signal < 31.0)
+        has_above = np.any(temp_signal > 40.0)
+        assert has_below or has_above
+
+
+# -----------------------------------------------------------
+# Tests — Edge cases
+# -----------------------------------------------------------
+
+
+class TestEdgeCases:
+    """Edge case handling tests."""
+
+    def test_empty_patient_no_file(self):
+        """Patient with no file returns empty list."""
+        patient = _make_patient(None, patient_id="S_EMPTY")
+        task = SleepStagingDREAMT()
+        samples = task(patient)
+        assert samples == []
+
+    def test_empty_patient_none_string(self):
+        """Patient with file_64hz='None' returns empty list."""
+        patient = _make_patient("None", patient_id="S_NONE")
+        task = SleepStagingDREAMT()
+        samples = task(patient)
+        assert samples == []
+
+    def test_multi_patient_isolation(self, tmp_path):
+        """Each patient's samples reference only its own id."""
+        csv_a = _make_csv(
+            3, ["W", "N1", "N2"], str(tmp_path), patient_id="P_A"
+        )
+        csv_b = _make_csv(
+            2, ["N3", "R"], str(tmp_path), patient_id="P_B"
+        )
+        patient_a = _make_patient(csv_a, patient_id="P_A")
+        patient_b = _make_patient(csv_b, patient_id="P_B")
+
+        task = SleepStagingDREAMT(apply_filters=False)
+        samples_a = task(patient_a)
+        samples_b = task(patient_b)
+
+        # Both patients must yield samples, else the checks are vacuous.
+        assert samples_a and samples_b
+        assert all(s["patient_id"] == "P_A" for s in samples_a)
+        assert all(s["patient_id"] == "P_B" for s in samples_b)
+
+        # epoch_index restarts at 0 for each patient
+        assert samples_a[0]["epoch_index"] == 0
+        assert samples_b[0]["epoch_index"] == 0
+
+    def test_nonexistent_file(self, tmp_path):
+        """Nonexistent CSV path returns empty list."""
+        fake_path = os.path.join(str(tmp_path), "does_not_exist.csv")
+        patient = _make_patient(fake_path)
+        task = SleepStagingDREAMT()
+        samples = task(patient)
+        assert samples == []