From ecc7c15a3d5e8a661e42492857ddce5de5cf679a Mon Sep 17 00:00:00 2001 From: "Erie M. Adames" Date: Thu, 9 Apr 2026 15:49:54 -0400 Subject: [PATCH 1/3] Add SleepStagingDREAMT task for DREAMT wearable dataset - Add pyhealth/tasks/sleep_staging_dreamt.py: sleep staging task that segments overnight E4 signals into 30s epochs with 2/3/5-class labels - Add tests/core/test_sleep_staging_dreamt.py: 26 unit tests using synthetic data, completing in ~3 seconds - Add examples/dreamt_sleep_wake_lstm.py: LSTM ablation study with signal subset and label granularity experiments - Add docs RST file and update tasks.rst index - Register SleepStagingDREAMT in pyhealth/tasks/__init__.py Made-with: Cursor --- docs/api/tasks.rst | 1 + .../pyhealth.tasks.SleepStagingDREAMT.rst | 7 + examples/dreamt_sleep_wake_lstm.py | 735 ++++++++++++++++++ pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/sleep_staging_dreamt.py | 379 +++++++++ tests/core/test_sleep_staging_dreamt.py | 440 +++++++++++ 6 files changed, 1563 insertions(+) create mode 100644 docs/api/tasks/pyhealth.tasks.SleepStagingDREAMT.rst create mode 100644 examples/dreamt_sleep_wake_lstm.py create mode 100644 pyhealth/tasks/sleep_staging_dreamt.py create mode 100644 tests/core/test_sleep_staging_dreamt.py diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..7a4075c3b 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -229,3 +229,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Sleep Staging (DREAMT) diff --git a/docs/api/tasks/pyhealth.tasks.SleepStagingDREAMT.rst b/docs/api/tasks/pyhealth.tasks.SleepStagingDREAMT.rst new file mode 100644 index 000000000..4fbc2479c --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.SleepStagingDREAMT.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.SleepStagingDREAMT +======================================= + +.. autoclass:: pyhealth.tasks.sleep_staging_dreamt.SleepStagingDREAMT + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/dreamt_sleep_wake_lstm.py b/examples/dreamt_sleep_wake_lstm.py new file mode 100644 index 000000000..a0b63fd30 --- /dev/null +++ b/examples/dreamt_sleep_wake_lstm.py @@ -0,0 +1,735 @@ +"""DREAMT Sleep Staging LSTM — Ablation Study. + +This script runs two ablation experiments on the DREAMT +sleep staging task using ``SleepStagingDREAMT``: + +1. **Signal-subset ablation** (binary wake/sleep): + ACC-only vs BVP/HRV-only vs EDA+TEMP-only vs ALL signals. + +2. **Label-granularity ablation** (ALL signals): + 2-class (wake/sleep) vs 5-class (W/N1/N2/N3/R). + +A single-layer unidirectional LSTM is trained with 5-fold +participant-level cross-validation (no subject leakage). + +Each 30-second epoch's raw multi-channel signal is reduced to +per-channel statistics (mean, std, min, max) to form a compact +feature vector suitable for the LSTM. + +Usage — full DREAMT run:: + + python dreamt_sleep_wake_lstm.py --root /path/to/dreamt + +Usage — synthetic demo (no dataset required):: + + python dreamt_sleep_wake_lstm.py --demo + +Metrics: F1, AUROC, Accuracy, Cohen's Kappa. + +Results / Findings +------------------ + +**Demo mode** (synthetic data, 3 patients, 2 training epochs): + +Results are non-meaningful and serve only to verify that the +full pipeline (epoching -> feature extraction -> LSTM training +-> evaluation) runs end-to-end without error. Expected output +is near-random performance (F1 ~ 0.2-0.5, Kappa ~ 0). + +Reference: + Wang et al. "Addressing wearable sleep tracking inequity: + a new dataset and novel methods for a population with sleep + disorders." CHIL 2024, PMLR 248:380-396. +""" + +import argparse +import os +import tempfile +import warnings +from collections import defaultdict +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from sklearn.metrics import ( + accuracy_score, + cohen_kappa_score, + f1_score, + roc_auc_score, +) +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from pyhealth.tasks.sleep_staging_dreamt import ( + ALL_SIGNAL_COLUMNS, + SleepStagingDREAMT, +) + +warnings.filterwarnings("ignore") + +EPOCH_LEN: int = 30 * 64 # 1920 samples per 30-s epoch at 64 Hz + +SIGNAL_SUBSETS: Dict[str, List[str]] = { + "ACC": ["ACC_X", "ACC_Y", "ACC_Z"], + "BVP_HRV": ["BVP", "HR", "IBI"], + "EDA_TEMP": ["EDA", "TEMP"], + "ALL": list(ALL_SIGNAL_COLUMNS), +} + + +def _epoch_features(signal: np.ndarray) -> np.ndarray: + """Convert a raw epoch signal to a compact feature vector. + + Computes mean, std, min, and max per channel. + + Args: + signal: Array of shape ``(n_channels, epoch_len)``. + + Returns: + 1-D feature vector of length ``4 * n_channels``. + """ + if isinstance(signal, torch.Tensor): + signal = signal.numpy() + feats: List[float] = [] + for ch in range(signal.shape[0]): + s = signal[ch].astype(np.float64) + feats.extend([ + float(np.mean(s)), + float(np.std(s)), + float(np.min(s)), + float(np.max(s)), + ]) + return np.array(feats, dtype=np.float32) + + +# ----------------------------------------------------------- +# LSTM model +# ----------------------------------------------------------- + + +class SleepLSTM(nn.Module): + """Single-layer unidirectional LSTM for sleep tasks.""" + + def __init__( + self, + input_dim: int, + hidden_dim: int = 64, + num_classes: int = 2, + ) -> None: + super().__init__() + self.lstm = nn.LSTM( + input_dim, + hidden_dim, + num_layers=1, + batch_first=True, + ) + self.fc = nn.Linear(hidden_dim, num_classes) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """Forward pass. + + Args: + x: Input tensor ``(batch, seq_len, features)``. + + Returns: + Logits ``(batch, seq_len, num_classes)``. + """ + out, _ = self.lstm(x) + return self.fc(out) + + +# ----------------------------------------------------------- +# Dataset wrapper +# ----------------------------------------------------------- + + +class SequenceDataset(Dataset): + """Groups epoch samples by patient into sequences.""" + + def __init__( + self, + samples: List[Dict[str, Any]], + ) -> None: + patient_map: Dict[str, list] = defaultdict(list) + for s in samples: + patient_map[s["patient_id"]].append(s) + + self.sequences: List[torch.Tensor] = [] + self.labels_list: List[torch.Tensor] = [] + for pid in sorted(patient_map): + epochs = sorted( + patient_map[pid], + key=lambda e: e["epoch_index"], + ) + signals = np.stack( + [_epoch_features(e["signal"]) for e in epochs], + axis=0, + ) + labels = np.array( + [e["label"] for e in epochs], + ) + self.sequences.append( + torch.tensor(signals, dtype=torch.float32) + ) + self.labels_list.append( + torch.tensor(labels, dtype=torch.long) + ) + + def __len__(self) -> int: + return len(self.sequences) + + def __getitem__( + self, + idx: int, + ) -> tuple: + return self.sequences[idx], self.labels_list[idx] + + +def collate_fn( + batch: list, +) -> tuple: + """Pad variable-length sequences in a batch.""" + seqs, labels = zip(*batch) + max_len = max(s.shape[0] for s in seqs) + feat_dim = seqs[0].shape[1] + padded_seqs = torch.zeros( + len(seqs), + max_len, + feat_dim, + ) + padded_labels = torch.full( + (len(seqs), max_len), + -1, + dtype=torch.long, + ) + masks = torch.zeros( + len(seqs), + max_len, + dtype=torch.bool, + ) + for i, (s, lbl) in enumerate(zip(seqs, labels)): + length = s.shape[0] + padded_seqs[i, :length] = s + padded_labels[i, :length] = lbl + masks[i, :length] = True + return padded_seqs, padded_labels, masks + + +# ----------------------------------------------------------- +# Training / evaluation +# ----------------------------------------------------------- + + +def train_and_evaluate( + train_samples: List[Dict[str, Any]], + test_samples: List[Dict[str, Any]], + num_classes: int = 2, + epochs: int = 30, + lr: float = 1e-3, + hidden_dim: int = 64, + device: str = "cpu", +) -> Dict[str, float]: + """Train LSTM and return test metrics.""" + if not train_samples or not test_samples: + return {} + + n_channels = train_samples[0]["signal"].shape[0] + feat_dim = 4 * n_channels + + train_ds = SequenceDataset(train_samples) + test_ds = SequenceDataset(test_samples) + + train_loader = DataLoader( + train_ds, + batch_size=8, + shuffle=True, + collate_fn=collate_fn, + ) + test_loader = DataLoader( + test_ds, + batch_size=8, + shuffle=False, + collate_fn=collate_fn, + ) + + model = SleepLSTM( + feat_dim, + hidden_dim, + num_classes, + ).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + criterion = nn.CrossEntropyLoss(ignore_index=-1) + + model.train() + for _ in tqdm(range(epochs), desc=" Training", unit="epoch", leave=False): + for seqs, labels, masks in train_loader: + seqs = seqs.to(device) + labels = labels.to(device) + optimizer.zero_grad() + logits = model(seqs) + loss = criterion( + logits.reshape(-1, num_classes), + labels.reshape(-1), + ) + loss.backward() + optimizer.step() + + model.eval() + all_preds: List[int] = [] + all_labels: List[int] = [] + all_probs: list = [] + with torch.no_grad(): + for seqs, labels, masks in test_loader: + seqs = seqs.to(device) + logits = model(seqs) + probs = torch.softmax(logits, dim=-1).cpu() + preds = logits.argmax(dim=-1).cpu() + for i in range(seqs.shape[0]): + valid = masks[i] + all_preds.extend(preds[i][valid].numpy().tolist()) + all_labels.extend(labels[i][valid].numpy().tolist()) + if num_classes == 2: + all_probs.extend( + probs[i][valid][:, 1].numpy().tolist() + ) + else: + all_probs.extend( + probs[i][valid].numpy().tolist() + ) + + y_true = np.array(all_labels) + y_pred = np.array(all_preds) + avg = "binary" if num_classes == 2 else "macro" + + results: Dict[str, float] = { + "f1": f1_score( + y_true, + y_pred, + average=avg, + zero_division=0, + ), + "accuracy": accuracy_score(y_true, y_pred), + "kappa": cohen_kappa_score(y_true, y_pred), + } + + try: + if num_classes == 2: + results["auroc"] = roc_auc_score( + y_true, + np.array(all_probs), + ) + else: + results["auroc"] = roc_auc_score( + y_true, + np.array(all_probs), + multi_class="ovr", + average="macro", + ) + except ValueError: + results["auroc"] = float("nan") + + return results + + +def participant_cv( + samples: List[Dict[str, Any]], + n_folds: int = 5, + num_classes: int = 2, + **kwargs: Any, +) -> Dict[str, str]: + """5-fold participant-level cross-validation.""" + patient_ids = sorted(set(s["patient_id"] for s in samples)) + np.random.seed(42) + np.random.shuffle(patient_ids) + + fold_size = max(1, len(patient_ids) // n_folds) + fold_results: List[Dict[str, float]] = [] + + for fold in tqdm(range(n_folds), desc=" CV folds", unit="fold"): + start = fold * fold_size + end = ( + start + fold_size + if fold < n_folds - 1 + else len(patient_ids) + ) + test_ids = set(patient_ids[start:end]) + train_ids = set(patient_ids) - test_ids + + if not train_ids or not test_ids: + continue + + train_s = [ + s for s in samples if s["patient_id"] in train_ids + ] + test_s = [ + s for s in samples if s["patient_id"] in test_ids + ] + + res = train_and_evaluate( + train_s, + test_s, + num_classes=num_classes, + **kwargs, + ) + if res: + fold_results.append(res) + tqdm.write( + f" Fold {fold + 1}: " + f"F1={res['f1']:.3f} " + f"AUROC={res['auroc']:.3f} " + f"Acc={res['accuracy']:.3f} " + f"Kappa={res['kappa']:.3f}" + ) + + if not fold_results: + return {} + + avg: Dict[str, str] = {} + for key in fold_results[0]: + vals = [r[key] for r in fold_results] + avg[key] = f"{np.mean(vals):.3f} +/- {np.std(vals):.3f}" + return avg + + +# ----------------------------------------------------------- +# Demo data generation +# ----------------------------------------------------------- + + +def _generate_demo_csv( + tmpdir: str, + patient_id: str, + n_epochs: int, + rng: np.random.RandomState, +) -> str: + """Create one synthetic 64 Hz CSV file. + + Args: + tmpdir: Directory to write the CSV. + patient_id: Used in the filename. + n_epochs: Number of 30-s epochs to generate. + rng: Random state for reproducibility. + + Returns: + Path to the written CSV file. + """ + import pandas as pd + + stages_pool = ["W", "N1", "N2", "N3", "R"] + rows = n_epochs * EPOCH_LEN + data = { + "TIMESTAMP": np.arange(rows) / 64.0, + "BVP": rng.randn(rows) * 50, + "IBI": np.clip(rng.rand(rows) * 0.2 + 0.7, 0, 2), + "EDA": rng.rand(rows) * 5 + 0.1, + "TEMP": rng.rand(rows) * 4 + 33, + "ACC_X": rng.randn(rows) * 10, + "ACC_Y": rng.randn(rows) * 10, + "ACC_Z": rng.randn(rows) * 10, + "HR": rng.rand(rows) * 30 + 60, + } + stage_col = [] + for i in range(n_epochs): + st = stages_pool[i % len(stages_pool)] + stage_col.extend([st] * EPOCH_LEN) + data["Sleep_Stage"] = stage_col + + csv_path = os.path.join(tmpdir, f"{patient_id}_whole_df.csv") + pd.DataFrame(data).to_csv(csv_path, index=False) + return csv_path + + +def _generate_demo_samples( + n_classes: int = 2, + signal_columns: Optional[List[str]] = None, + n_patients: int = 3, + epochs_per_patient: int = 20, + seed: int = 123, +) -> List[Dict[str, Any]]: + """Create synthetic samples for demo mode. + + Args: + n_classes: Number of label classes (2, 3, or 5). + signal_columns: Which signal columns to include. + n_patients: Number of synthetic patients. + epochs_per_patient: 30-s epochs per patient. + seed: Random seed for reproducibility. + + Returns: + List of sample dicts ready for training. + """ + from types import SimpleNamespace + + rng = np.random.RandomState(seed) + all_samples: List[Dict[str, Any]] = [] + + with tempfile.TemporaryDirectory() as tmpdir: + for p in range(n_patients): + pid = f"DEMO_{p:03d}" + csv_path = _generate_demo_csv( + tmpdir, pid, epochs_per_patient, rng, + ) + + evt = SimpleNamespace(file_64hz=csv_path) + patient = SimpleNamespace( + patient_id=pid, + get_events=lambda et=None, e=evt: [e], + ) + + task = SleepStagingDREAMT( + n_classes=n_classes, + signal_columns=signal_columns, + apply_filters=False, + ) + all_samples.extend(task(patient)) + + return all_samples + + +# ----------------------------------------------------------- +# Main +# ----------------------------------------------------------- + +DEFAULT_ROOT = os.path.expanduser("~/.pyhealth/dreamt") + + +def _resolve_root( + root_arg: Optional[str], +) -> str: + """Find a valid DREAMT root, or exit with guidance. + + Args: + root_arg: User-supplied ``--root`` value, or None. + + Returns: + Absolute path to the DREAMT version directory. + + Raises: + SystemExit: If no valid directory is found. + """ + candidates = ( + [root_arg] + if root_arg + else [ + DEFAULT_ROOT, + os.path.expanduser("~/data/dreamt"), + os.path.expanduser("~/dreamt"), + ] + ) + for path in candidates: + if path and os.path.isdir(path): + info = os.path.join(path, "participant_info.csv") + if os.path.isfile(info): + return path + for sub in sorted(os.listdir(path)): + subpath = os.path.join(path, sub) + if os.path.isdir(subpath) and os.path.isfile( + os.path.join(subpath, "participant_info.csv") + ): + return subpath + print( + "ERROR: Could not find the DREAMT dataset.\n" + "\n" + "Download from PhysioNet (credentialed access):\n" + " https://physionet.org/content/dreamt/\n" + "\n" + "Then either:\n" + f" - Extract to {DEFAULT_ROOT}/\n" + " - Or pass --root /path/to/dreamt/version/\n" + "\n" + "The directory must contain participant_info.csv\n" + "and a data_64Hz/ folder with per-participant CSVs." + ) + raise SystemExit(1) + + +def _run_ablations_real(args: argparse.Namespace) -> None: + """Run ablations on the real DREAMT dataset. + + Args: + args: Parsed command-line arguments. + """ + from pyhealth.datasets import DREAMTDataset + + root = _resolve_root(args.root) + print(f"Loading DREAMT dataset from {root} ...") + dataset = DREAMTDataset(root=root) + + print("\n" + "=" * 60) + print("ABLATION 1: Signal Subset (2-class wake/sleep)") + print("=" * 60) + + for subset_name, columns in tqdm(SIGNAL_SUBSETS.items(), desc="Signal subsets", unit="subset"): + tqdm.write(f"\n--- Signal subset: {subset_name} ---") + task = SleepStagingDREAMT( + n_classes=2, + signal_columns=columns, + ) + sample_ds = dataset.set_task(task) + samples = [sample_ds[i] for i in tqdm(range(len(sample_ds)), desc=" Loading samples", leave=False)] + tqdm.write(f" Total samples: {len(samples)}") + avg = participant_cv( + samples, + num_classes=2, + epochs=args.epochs, + hidden_dim=args.hidden_dim, + device=args.device, + ) + tqdm.write(f" Average: {avg}") + + print("\n" + "=" * 60) + print("ABLATION 2: Label Granularity (ALL signals)") + print("=" * 60) + + print("\n--- 2-class (wake vs sleep) ---") + task_2 = SleepStagingDREAMT(n_classes=2) + sd_2 = dataset.set_task(task_2) + samps_2 = [sd_2[i] for i in tqdm(range(len(sd_2)), desc=" Loading samples", leave=False)] + avg_2 = participant_cv( + samps_2, + num_classes=2, + epochs=args.epochs, + hidden_dim=args.hidden_dim, + device=args.device, + ) + print(f" Average: {avg_2}") + + print("\n--- 5-class (W/N1/N2/N3/R) ---") + task_5 = SleepStagingDREAMT(n_classes=5) + sd_5 = dataset.set_task(task_5) + samps_5 = [sd_5[i] for i in tqdm(range(len(sd_5)), desc=" Loading samples", leave=False)] + avg_5 = participant_cv( + samps_5, + num_classes=5, + epochs=args.epochs, + hidden_dim=args.hidden_dim, + device=args.device, + ) + print(f" Average: {avg_5}") + + +def _run_ablations_demo(args: argparse.Namespace) -> None: + """Run ablations on synthetic demo data. + + Args: + args: Parsed command-line arguments. + """ + print("=== DEMO MODE (synthetic data) ===\n") + print("Generating 3 synthetic patients (20 epochs each) ...") + + demo_epochs = min(args.epochs, 2) + + print("\n" + "=" * 60) + print("ABLATION 1: Signal Subset (2-class, demo)") + print("=" * 60) + + for subset_name, columns in tqdm(SIGNAL_SUBSETS.items(), desc="Signal subsets", unit="subset"): + tqdm.write(f"\n--- Signal subset: {subset_name} ---") + seed = abs(hash(subset_name)) % (2**31) + sub_samples = _generate_demo_samples( + n_classes=2, + signal_columns=columns, + n_patients=3, + seed=seed, + ) + tqdm.write(f" Total samples: {len(sub_samples)}") + n_pids = len(set(s["patient_id"] for s in sub_samples)) + avg = participant_cv( + sub_samples, + n_folds=min(3, n_pids), + num_classes=2, + epochs=demo_epochs, + hidden_dim=args.hidden_dim, + device=args.device, + ) + tqdm.write(f" Average: {avg}") + + print("\n" + "=" * 60) + print("ABLATION 2: Label Granularity (demo)") + print("=" * 60) + + print("\n--- 2-class (wake vs sleep) ---") + samples_2 = _generate_demo_samples( + n_classes=2, n_patients=3, seed=123, + ) + n_pids_2 = len(set(s["patient_id"] for s in samples_2)) + avg_2 = participant_cv( + samples_2, + n_folds=min(3, n_pids_2), + num_classes=2, + epochs=demo_epochs, + hidden_dim=args.hidden_dim, + device=args.device, + ) + print(f" Average: {avg_2}") + + print("\n--- 5-class (W/N1/N2/N3/R) ---") + samples_5 = _generate_demo_samples( + n_classes=5, n_patients=3, seed=123, + ) + n_pids_5 = len(set(s["patient_id"] for s in samples_5)) + avg_5 = participant_cv( + samples_5, + n_folds=min(3, n_pids_5), + num_classes=5, + epochs=demo_epochs, + hidden_dim=args.hidden_dim, + device=args.device, + ) + print(f" Average: {avg_5}") + + print("\nDemo complete.") + + +def main() -> None: + """Entry point for the DREAMT LSTM ablation study.""" + parser = argparse.ArgumentParser( + description="DREAMT LSTM ablation study", + ) + parser.add_argument( + "--root", + default=None, + help=( + "Path to DREAMT dataset. " + f"Default: {DEFAULT_ROOT}" + ), + ) + parser.add_argument( + "--demo", + action="store_true", + help=( + "Run with synthetic data instead of real " + "DREAMT (no dataset download required)." + ), + ) + parser.add_argument( + "--epochs", + type=int, + default=30, + help="Training epochs per fold", + ) + parser.add_argument( + "--hidden_dim", + type=int, + default=64, + help="LSTM hidden dimension", + ) + parser.add_argument( + "--device", + default="cpu", + help="Device (cpu or cuda)", + ) + args = parser.parse_args() + + if args.demo: + _run_ablations_demo(args) + else: + _run_ablations_real(args) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..eee3f8fef 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .sleep_staging_dreamt import SleepStagingDREAMT diff --git a/pyhealth/tasks/sleep_staging_dreamt.py b/pyhealth/tasks/sleep_staging_dreamt.py new file mode 100644 index 000000000..330674903 --- /dev/null +++ b/pyhealth/tasks/sleep_staging_dreamt.py @@ -0,0 +1,379 @@ +"""Sleep staging task for the DREAMT dataset. + +This task processes overnight Empatica E4 wearable recordings from the +DREAMT dataset into per-epoch raw signal windows with sleep stage labels, +suitable for deep learning models. + +The preprocessing follows the methodology described in: + + Wang et al. "Addressing wearable sleep tracking inequity: a new + dataset and novel methods for a population with sleep disorders." + CHIL 2024, PMLR 248:380-396. + +Signal-specific preprocessing: + - **ACC (ACC_X, ACC_Y, ACC_Z)**: 5th-order Butterworth bandpass + filter, 3-11 Hz (Altini & Kinnunen 2021). + - **BVP**: Chebyshev Type II bandpass filter, 0.5-20 Hz. + - **TEMP**: Winsorized (clipped) to [31, 40] degrees C. + - **EDA, HR, IBI**: No additional filtering. +""" + +import logging +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +from scipy.signal import butter, cheby2, filtfilt + +from pyhealth.tasks.base_task import BaseTask + +logger = logging.getLogger(__name__) + +# Label mappings per classification granularity +LABEL_MAP_5CLASS: Dict[str, int] = { + "W": 0, + "N1": 1, + "N2": 2, + "N3": 3, + "R": 4, +} + +LABEL_MAP_3CLASS: Dict[str, int] = { + "W": 0, + "N1": 1, + "N2": 1, + "N3": 1, + "R": 2, +} + +LABEL_MAP_2CLASS: Dict[str, int] = { + "W": 0, + "N1": 1, + "N2": 1, + "N3": 1, + "R": 1, +} + +LABEL_MAPS: Dict[int, Dict[str, int]] = { + 5: LABEL_MAP_5CLASS, + 3: LABEL_MAP_3CLASS, + 2: LABEL_MAP_2CLASS, +} + +ALL_SIGNAL_COLUMNS: List[str] = [ + "BVP", + "ACC_X", + "ACC_Y", + "ACC_Z", + "EDA", + "TEMP", + "HR", + "IBI", +] + +# Excluded sleep stages +EXCLUDED_STAGES = {"P", "Missing"} + + +def _apply_butterworth_bandpass( + signal: np.ndarray, + lowcut: float, + highcut: float, + fs: int, + order: int = 5, +) -> np.ndarray: + """Apply a Butterworth bandpass filter. + + Args: + signal: 1-D array of signal values. + lowcut: Lower cutoff frequency in Hz. + highcut: Upper cutoff frequency in Hz. + fs: Sampling rate in Hz. + order: Filter order. + + Returns: + Filtered signal array. + """ + nyq = 0.5 * fs + b, a = butter(order, [lowcut / nyq, highcut / nyq], btype="band") + try: + return filtfilt(b, a, signal) + except ValueError: + logger.warning( + "Butterworth filtfilt failed (likely too few samples); " + "returning raw signal." + ) + return signal + + +def _apply_chebyshev_bandpass( + signal: np.ndarray, + lowcut: float, + highcut: float, + fs: int, + order: int = 4, + rs: float = 40.0, +) -> np.ndarray: + """Apply a Chebyshev Type II bandpass filter. + + Args: + signal: 1-D array of signal values. + lowcut: Lower cutoff frequency in Hz. + highcut: Upper cutoff frequency in Hz. + fs: Sampling rate in Hz. + order: Filter order. + rs: Minimum stopband attenuation in dB. + + Returns: + Filtered signal array. + """ + nyq = 0.5 * fs + b, a = cheby2(order, rs, [lowcut / nyq, highcut / nyq], btype="band") + try: + return filtfilt(b, a, signal) + except ValueError: + logger.warning( + "Chebyshev filtfilt failed (likely too few samples); " + "returning raw signal." + ) + return signal + + +def _apply_filters( + epoch: np.ndarray, + columns: List[str], + fs: int, +) -> np.ndarray: + """Apply signal-specific filters to one epoch. + + Args: + epoch: Array of shape ``(n_channels, epoch_len)``. + columns: Column names corresponding to each channel. + fs: Sampling rate in Hz. + + Returns: + Filtered epoch array of the same shape. + """ + filtered = epoch.copy() + for i, col in enumerate(columns): + if col in ("ACC_X", "ACC_Y", "ACC_Z"): + filtered[i] = _apply_butterworth_bandpass( + filtered[i], lowcut=3.0, highcut=11.0, fs=fs + ) + elif col == "BVP": + filtered[i] = _apply_chebyshev_bandpass( + filtered[i], lowcut=0.5, highcut=20.0, fs=fs + ) + elif col == "TEMP": + filtered[i] = np.clip(filtered[i], 31.0, 40.0) + # EDA, HR, IBI: no additional filtering + return filtered + + +class SleepStagingDREAMT(BaseTask): + """Sleep staging task for the DREAMT dataset. + + Transforms each participant's overnight Empatica E4 recording into + non-overlapping 30-second epochs of raw multi-channel signal data + with integer sleep stage labels. Supports 5-class, 3-class, and + 2-class classification granularities. + + Signal-specific preprocessing (from Wang et al. CHIL 2024): + + - **ACC**: 5th-order Butterworth bandpass, 3-11 Hz + - **BVP**: Chebyshev Type II bandpass, 0.5-20 Hz + - **TEMP**: Winsorized to [31, 40] C + + Epochs labeled ``"P"`` (preparation) or ``"Missing"`` are excluded. + + Attributes: + task_name: ``"SleepStagingDREAMT"`` + input_schema: ``{"signal": "tensor"}`` + output_schema: ``{"label": "multiclass"}`` + + Args: + n_classes: Number of classification classes. Must be one of + ``{2, 3, 5}``. Default ``5``. + + - **5-class**: W=0, N1=1, N2=2, N3=3, R=4 + - **3-class**: W=0, NREM(N1/N2/N3)=1, REM=2 + - **2-class**: W=0, Sleep(N1/N2/N3/R)=1 + + signal_columns: List of signal column names to include. + Default includes all 8 channels: ``["BVP", "ACC_X", + "ACC_Y", "ACC_Z", "EDA", "TEMP", "HR", "IBI"]``. + epoch_seconds: Duration of each epoch in seconds. + Default ``30.0``. + sampling_rate: Sampling rate in Hz. Default ``64``. + apply_filters: Whether to apply signal-specific filters. + Default ``True``. + + Examples: + >>> from pyhealth.datasets import DREAMTDataset + >>> ds = DREAMTDataset(root="/path/to/dreamt/2.1.0") + >>> task = SleepStagingDREAMT(n_classes=5) + >>> sample_ds = ds.set_task(task) + >>> sample = sample_ds.samples[0] + >>> sample.keys() + dict_keys(['patient_id', 'signal', 'label', 'epoch_index']) + >>> sample["signal"].shape # (8, 1920) for 8 channels, 30s * 64Hz + (8, 1920) + + >>> # 2-class (wake vs sleep) with ACC channels only + >>> task_binary = SleepStagingDREAMT( + ... n_classes=2, + ... signal_columns=["ACC_X", "ACC_Y", "ACC_Z"], + ... ) + >>> sample_ds = ds.set_task(task_binary) + >>> sample_ds.samples[0]["signal"].shape + (3, 1920) + """ + + task_name: str = "SleepStagingDREAMT" + input_schema: Dict[str, str] = {"signal": "tensor"} + output_schema: Dict[str, str] = {"label": "multiclass"} + + def __init__( + self, + n_classes: int = 5, + signal_columns: Optional[List[str]] = None, + epoch_seconds: float = 30.0, + sampling_rate: int = 64, + apply_filters: bool = True, + ) -> None: + if n_classes not in {2, 3, 5}: + raise ValueError( + f"n_classes must be one of {{2, 3, 5}}, got {n_classes}" + ) + self.n_classes = n_classes + self.signal_columns = ( + list(signal_columns) if signal_columns is not None + else list(ALL_SIGNAL_COLUMNS) + ) + self.epoch_seconds = epoch_seconds + self.sampling_rate = sampling_rate + self.apply_filters = apply_filters + self.epoch_len = int(epoch_seconds * sampling_rate) + self.label_map = LABEL_MAPS[n_classes] + super().__init__() + + def __call__( + self, + patient: Any, + ) -> List[Dict[str, Any]]: + """Process one DREAMT patient into epoch samples. + + Args: + patient: A ``Patient`` object from ``DREAMTDataset``. + + Returns: + List of sample dicts, each containing: + + - ``patient_id`` (str): Patient identifier. + - ``signal`` (np.ndarray): Shape ``(n_channels, epoch_len)``, + dtype ``float32``. + - ``label`` (int): Integer sleep stage label. + - ``epoch_index`` (int): Sequential epoch position within + the recording. + """ + pid: str = patient.patient_id + + try: + events = patient.get_events(event_type="dreamt_sleep") + except Exception: + events = patient.get_events() + + if not events: + return [] + + event = events[0] + file_path = getattr(event, "file_64hz", None) + if file_path is None or ( + isinstance(file_path, str) and file_path.lower() == "none" + ): + return [] + + try: + df = pd.read_csv(str(file_path)) + except Exception as exc: + logger.warning("Could not read %s: %s", file_path, exc) + return [] + + # Build case-insensitive column mapping + col_map: Dict[str, str] = { + col.lower(): col for col in df.columns + } + + # Check required columns exist + required = set(c.lower() for c in self.signal_columns) | { + "sleep_stage" + } + missing = [c for c in required if c not in col_map] + if missing: + logger.warning( + "Patient %s missing columns: %s", pid, missing + ) + return [] + + # Resolve actual column names + stage_col = col_map["sleep_stage"] + signal_col_names = [ + col_map[c.lower()] for c in self.signal_columns + ] + + # Drop excluded stages + mask = ~df[stage_col].isin(EXCLUDED_STAGES) + df = df.loc[mask].reset_index(drop=True) + + if len(df) < self.epoch_len: + return [] + + n_epochs = len(df) // self.epoch_len + samples: List[Dict[str, Any]] = [] + epoch_counter = 0 + + for i in range(n_epochs): + start = i * self.epoch_len + end = start + self.epoch_len + epoch_df = df.iloc[start:end] + + # Label from the middle of the epoch + mid = start + self.epoch_len // 2 + stage = str(df[stage_col].iloc[mid]).strip() + + if stage not in self.label_map: + continue + + label = self.label_map[stage] + + # Extract signal channels as (n_channels, epoch_len) + signal = np.stack( + [ + epoch_df[col].values.astype(np.float64) + for col in signal_col_names + ], + axis=0, + ) + + # Replace NaN with 0 + np.nan_to_num(signal, nan=0.0, copy=False) + + # Apply signal-specific filters + if self.apply_filters: + signal = _apply_filters( + signal, self.signal_columns, self.sampling_rate + ) + + signal = signal.astype(np.float32) + + samples.append( + { + "patient_id": pid, + "signal": signal, + "label": label, + "epoch_index": epoch_counter, + } + ) + epoch_counter += 1 + + return samples diff --git a/tests/core/test_sleep_staging_dreamt.py b/tests/core/test_sleep_staging_dreamt.py new file mode 100644 index 000000000..80812f532 --- /dev/null +++ b/tests/core/test_sleep_staging_dreamt.py @@ -0,0 +1,440 @@ +"""Tests for SleepStagingDREAMT task. + +All tests use in-memory fake patients with small temporary CSV files. +No real DREAMT data is required. Tests complete in milliseconds. +""" + +import os +from types import SimpleNamespace +from typing import List, Optional + +import numpy as np +import pandas as pd +import pytest + +from pyhealth.tasks.sleep_staging_dreamt import ( + ALL_SIGNAL_COLUMNS, + LABEL_MAP_2CLASS, + LABEL_MAP_3CLASS, + LABEL_MAP_5CLASS, + SleepStagingDREAMT, +) + +EPOCH_LEN = 30 * 64 # 1920 samples + + +# ----------------------------------------------------------- +# Helpers +# ----------------------------------------------------------- + + +def _make_csv( + n_epochs: int, + stages: List[str], + tmpdir: str, + patient_id: str = "S001", +) -> str: + """Create a synthetic 64 Hz CSV with ``n_epochs`` epochs. + + Args: + n_epochs: Number of 30-second epochs to generate. + stages: Sleep stage labels, one per epoch (cycled). + tmpdir: Directory to write the CSV into. + patient_id: Used in the filename. + + Returns: + Absolute path to the written CSV file. + """ + rng = np.random.RandomState(42) + rows = n_epochs * EPOCH_LEN + data = { + "TIMESTAMP": np.arange(rows) / 64.0, + "BVP": rng.randn(rows) * 50, + "IBI": np.clip(rng.rand(rows) * 0.2 + 0.7, 0, 2), + "EDA": rng.rand(rows) * 5 + 0.1, + "TEMP": rng.rand(rows) * 15 + 28, # range includes <31 and >40 + "ACC_X": rng.randn(rows) * 10, + "ACC_Y": rng.randn(rows) * 10, + "ACC_Z": rng.randn(rows) * 10, + "HR": rng.rand(rows) * 30 + 60, + } + + stage_col = [] + for i in range(n_epochs): + stage = stages[i % len(stages)] + stage_col.extend([stage] * EPOCH_LEN) + data["Sleep_Stage"] = stage_col + + df = pd.DataFrame(data) + path = os.path.join(tmpdir, f"{patient_id}_whole_df.csv") + df.to_csv(path, index=False) + return path + + +def _make_patient( + file_path: Optional[str], + patient_id: str = "S001", +) -> SimpleNamespace: + """Build a mock Patient mimicking DREAMTDataset. + + Args: + file_path: Path to the CSV, or None for an empty patient. + patient_id: Identifier for the mock patient. + + Returns: + SimpleNamespace matching the DREAMT patient contract. + """ + event = SimpleNamespace(file_64hz=file_path) + patient = SimpleNamespace( + patient_id=patient_id, + get_events=lambda event_type=None: [event], + ) + return patient + + +# ----------------------------------------------------------- +# Tests — Initialization +# ----------------------------------------------------------- + + +class TestInit: + """Task initialization tests.""" + + def test_default_params(self): + """Default init uses 5 classes and all 8 channels.""" + task = SleepStagingDREAMT() + assert task.n_classes == 5 + assert task.signal_columns == list(ALL_SIGNAL_COLUMNS) + assert task.epoch_seconds == 30.0 + assert task.sampling_rate == 64 + assert task.apply_filters is True + assert task.epoch_len == 1920 + + def test_custom_params(self): + """Custom init parameters are stored correctly.""" + task = SleepStagingDREAMT( + n_classes=2, + signal_columns=["ACC_X", "ACC_Y"], + epoch_seconds=15.0, + sampling_rate=32, + apply_filters=False, + ) + assert task.n_classes == 2 + assert task.signal_columns == ["ACC_X", "ACC_Y"] + assert task.epoch_seconds == 15.0 + assert task.sampling_rate == 32 + assert task.apply_filters is False + assert task.epoch_len == 480 + + def test_invalid_n_classes_raises(self): + """n_classes not in {2, 3, 5} raises ValueError.""" + with pytest.raises(ValueError, match="n_classes"): + SleepStagingDREAMT(n_classes=4) + + def test_invalid_n_classes_other(self): + """n_classes=1 also raises ValueError.""" + with pytest.raises(ValueError): + SleepStagingDREAMT(n_classes=1) + + def test_class_attributes(self): + """Task has correct class-level attributes.""" + task = SleepStagingDREAMT() + assert task.task_name == "SleepStagingDREAMT" + assert task.input_schema == {"signal": "tensor"} + assert task.output_schema == {"label": "multiclass"} + + +# ----------------------------------------------------------- +# Tests — 5-class +# ----------------------------------------------------------- + + +class TestFiveClass: + """5-class sleep staging tests.""" + + def test_sample_count(self, tmp_path): + """Correct number of valid epochs returned.""" + stages = ["W", "N1", "N2", "N3", "R"] + csv = _make_csv(5, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT(n_classes=5, apply_filters=False) + samples = task(patient) + assert len(samples) == 5 + + def test_label_mapping(self, tmp_path): + """5-class maps W=0, N1=1, N2=2, N3=3, R=4.""" + stages = ["W", "N1", "N2", "N3", "R"] + csv = _make_csv(5, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT(n_classes=5, apply_filters=False) + samples = task(patient) + labels = [s["label"] for s in samples] + assert labels == [0, 1, 2, 3, 4] + + def test_signal_shape(self, tmp_path): + """Signal shape is (n_channels, 1920).""" + stages = ["W", "N2"] + csv = _make_csv(2, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT(n_classes=5, apply_filters=False) + samples = task(patient) + assert samples[0]["signal"].shape == (8, 1920) + assert samples[0]["signal"].dtype == np.float32 + + def test_patient_id(self, tmp_path): + """Samples carry the correct patient_id.""" + csv = _make_csv( + 2, ["W", "N1"], str(tmp_path), patient_id="S042" + ) + patient = _make_patient(csv, patient_id="S042") + task = SleepStagingDREAMT(apply_filters=False) + samples = task(patient) + assert all(s["patient_id"] == "S042" for s in samples) + + def test_epoch_indices_sequential(self, tmp_path): + """epoch_index is sequential starting from 0.""" + stages = ["W", "N1", "N2", "N3", "R"] * 3 + csv = _make_csv(15, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT(apply_filters=False) + samples = task(patient) + indices = [s["epoch_index"] for s in samples] + assert indices == list(range(len(samples))) + + +# ----------------------------------------------------------- +# Tests — 3-class +# ----------------------------------------------------------- + + +class TestThreeClass: + """3-class sleep staging tests.""" + + def test_label_mapping(self, tmp_path): + """3-class maps W=0, NREM=1, REM=2.""" + stages = ["W", "N1", "N2", "N3", "R"] + csv = _make_csv(5, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT(n_classes=3, apply_filters=False) + samples = task(patient) + labels = [s["label"] for s in samples] + assert labels == [0, 1, 1, 1, 2] + + +# ----------------------------------------------------------- +# Tests — 2-class +# ----------------------------------------------------------- + + +class TestTwoClass: + """2-class (wake vs sleep) tests.""" + + def test_label_mapping(self, tmp_path): + """2-class maps W=0, all sleep=1.""" + stages = ["W", "N1", "N2", "N3", "R"] + csv = _make_csv(5, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT(n_classes=2, apply_filters=False) + samples = task(patient) + labels = [s["label"] for s in samples] + assert labels == [0, 1, 1, 1, 1] + + +# ----------------------------------------------------------- +# Tests — Stage exclusion +# ----------------------------------------------------------- + + +class TestStageExclusion: + """P and Missing stage exclusion tests.""" + + def test_p_stage_excluded(self, tmp_path): + """Epochs with P (preparation) stage are dropped.""" + stages = ["P", "P", "W", "N1"] + csv = _make_csv(4, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT(apply_filters=False) + samples = task(patient) + labels = [s["label"] for s in samples] + # After dropping P rows, remaining data may yield fewer epochs + # All returned labels should be valid (no P) + assert all(lbl in {0, 1, 2, 3, 4} for lbl in labels) + + def test_missing_stage_excluded(self, tmp_path): + """Epochs with Missing stage are dropped.""" + stages = ["W", "Missing", "N2"] + csv = _make_csv(3, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT(apply_filters=False) + samples = task(patient) + # Missing rows are dropped, so we may get fewer epochs + for s in samples: + assert s["label"] in {0, 1, 2, 3, 4} + + def test_all_p_returns_empty(self, tmp_path): + """Patient with only P stages returns empty list.""" + stages = ["P", "P", "P"] + csv = _make_csv(3, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT(apply_filters=False) + samples = task(patient) + assert samples == [] + + def test_all_missing_returns_empty(self, tmp_path): + """Patient with only Missing stages returns empty list.""" + stages = ["Missing", "Missing"] + csv = _make_csv(2, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT(apply_filters=False) + samples = task(patient) + assert samples == [] + + +# ----------------------------------------------------------- +# Tests — Signal column subsetting +# ----------------------------------------------------------- + + +class TestSignalSubset: + """Signal column selection tests.""" + + def test_acc_only(self, tmp_path): + """Selecting ACC channels gives shape (3, 1920).""" + stages = ["W", "N2"] + csv = _make_csv(2, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT( + signal_columns=["ACC_X", "ACC_Y", "ACC_Z"], + apply_filters=False, + ) + samples = task(patient) + assert samples[0]["signal"].shape == (3, 1920) + + def test_single_channel(self, tmp_path): + """Single channel gives shape (1, 1920).""" + stages = ["W"] + csv = _make_csv(1, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT( + signal_columns=["HR"], + apply_filters=False, + ) + samples = task(patient) + assert samples[0]["signal"].shape == (1, 1920) + + def test_bvp_temp(self, tmp_path): + """Custom subset of BVP + TEMP gives shape (2, 1920).""" + stages = ["N3"] + csv = _make_csv(1, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT( + signal_columns=["BVP", "TEMP"], + apply_filters=False, + ) + samples = task(patient) + assert samples[0]["signal"].shape == (2, 1920) + + +# ----------------------------------------------------------- +# Tests — Filtering +# ----------------------------------------------------------- + + +class TestFiltering: + """Signal filtering tests.""" + + def test_filters_run_without_error(self, tmp_path): + """Filters execute without raising exceptions.""" + stages = ["W", "N2", "R"] + csv = _make_csv(3, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT(apply_filters=True) + samples = task(patient) + assert len(samples) > 0 + assert samples[0]["signal"].shape == (8, 1920) + + def test_temp_winsorization(self, tmp_path): + """TEMP values are clipped to [31, 40] after filtering.""" + stages = ["W"] + csv = _make_csv(1, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT( + signal_columns=["TEMP"], + apply_filters=True, + ) + samples = task(patient) + temp_signal = samples[0]["signal"][0] + assert np.all(temp_signal >= 31.0) + assert np.all(temp_signal <= 40.0) + + def test_filters_disabled(self, tmp_path): + """With apply_filters=False, TEMP is not clipped.""" + # Generate TEMP data that will have values outside [31, 40] + stages = ["W"] + csv = _make_csv(1, stages, str(tmp_path)) + patient = _make_patient(csv) + task = SleepStagingDREAMT( + signal_columns=["TEMP"], + apply_filters=False, + ) + samples = task(patient) + temp_signal = samples[0]["signal"][0] + # Synthetic data has range ~[28, 43], so some should be outside + has_below = np.any(temp_signal < 31.0) + has_above = np.any(temp_signal > 40.0) + assert has_below or has_above + + +# ----------------------------------------------------------- +# Tests — Edge cases +# ----------------------------------------------------------- + + +class TestEdgeCases: + """Edge case handling tests.""" + + def test_empty_patient_no_file(self): + """Patient with no file returns empty list.""" + patient = _make_patient(None, patient_id="S_EMPTY") + task = SleepStagingDREAMT() + samples = task(patient) + assert samples == [] + + def test_empty_patient_none_string(self, tmp_path): + """Patient with file_64hz='None' returns empty list.""" + patient = _make_patient("None", patient_id="S_NONE") + task = SleepStagingDREAMT() + samples = task(patient) + assert samples == [] + + def test_multi_patient_isolation(self, tmp_path): + """Each patient's samples reference only its own id.""" + csv_a = _make_csv( + 3, ["W", "N1", "N2"], str(tmp_path), patient_id="P_A" + ) + csv_b = _make_csv( + 2, ["N3", "R"], str(tmp_path), patient_id="P_B" + ) + patient_a = _make_patient(csv_a, patient_id="P_A") + patient_b = _make_patient(csv_b, patient_id="P_B") + + task = SleepStagingDREAMT(apply_filters=False) + samples_a = task(patient_a) + samples_b = task(patient_b) + + assert all(s["patient_id"] == "P_A" for s in samples_a) + assert all(s["patient_id"] == "P_B" for s in samples_b) + + # epoch_index restarts at 0 for each patient + if samples_a: + assert samples_a[0]["epoch_index"] == 0 + if samples_b: + assert samples_b[0]["epoch_index"] == 0 + + def test_nonexistent_file(self, tmp_path): + """Nonexistent CSV path returns empty list.""" + fake_path = os.path.join(str(tmp_path), "does_not_exist.csv") + patient = _make_patient(fake_path) + task = SleepStagingDREAMT() + samples = task(patient) + assert samples == [] From 4bc51dc31c867760ca685ef074e620f2190a3501 Mon Sep 17 00:00:00 2001 From: "Erie M. Adames" Date: Thu, 9 Apr 2026 16:38:04 -0400 Subject: [PATCH 2/3] update example with rnn --- examples/dreamt_sleep_staging_rnn.py | 518 +++++++++++++++++ examples/dreamt_sleep_wake_lstm.py | 735 ------------------------ pyhealth/tasks/sleep_staging_dreamt.py | 4 +- tests/core/test_sleep_staging_dreamt.py | 3 - 4 files changed, 520 insertions(+), 740 deletions(-) create mode 100644 examples/dreamt_sleep_staging_rnn.py delete mode 100644 examples/dreamt_sleep_wake_lstm.py diff --git a/examples/dreamt_sleep_staging_rnn.py b/examples/dreamt_sleep_staging_rnn.py new file mode 100644 index 000000000..c3127076b --- /dev/null +++ b/examples/dreamt_sleep_staging_rnn.py @@ -0,0 +1,518 @@ +"""DREAMT Sleep Staging — Ablation Study with PyHealth RNN. + +This script runs two ablation experiments on the DREAMT +sleep staging task using ``SleepStagingDREAMT``: + +1. **Signal-subset ablation** (binary wake/sleep): + ACC-only vs BVP/HRV-only vs EDA+TEMP-only vs ALL signals. + +2. **Label-granularity ablation** (ALL signals): + 2-class (wake/sleep) vs 5-class (W/N1/N2/N3/R). + +The model is PyHealth's built-in ``RNN`` (LSTM variant), trained +using PyHealth's ``Trainer`` with patient-level data splits. + +Each 30-second epoch's raw multi-channel signal is reduced to +per-channel statistics (mean, std, min, max) to form a compact +feature vector fed to the model. + +Usage — full DREAMT run:: + + python dreamt_sleep_staging_rnn.py --root /path/to/dreamt + +Usage — synthetic demo (no dataset required):: + + python dreamt_sleep_staging_rnn.py --demo + +Metrics: F1 (macro), Accuracy, Cohen's Kappa. + +Results / Findings +------------------ + +**Demo mode** (synthetic data, 6 patients, 2 training epochs): + +Results are non-meaningful and serve only to verify that the +full pipeline (epoching -> feature extraction -> PyHealth RNN +training -> evaluation) runs end-to-end without error. Expected +output is near-random performance. + +**Paper reference** (Wang et al. CHIL 2024, Table 2): + +The original paper reports wake/sleep (2-class) detection on +80 participants (after artifact QC) using LightGBM / GPBoost +with hand-crafted features from all 8 E4 channels and 5-fold +participant-level CV. Key results from Table 2: + +- Baseline LightGBM: F1 = 0.777, Acc = 0.816, Kappa = 0.605 +- Best (GPBoost + Apnea Severity RE + LSTM post-processing): + F1 = 0.823, Acc = 0.857, Kappa = 0.694 + +The paper does not report per-signal-subset ablations; those +are original to this script. This ablation also uses a simpler +feature set (4 summary stats per channel) and a neural model +(LSTM via PyHealth RNN) with a 70/10/20 patient-level split, +so results are expected to differ from the paper. + +Reference: + Wang et al. "Addressing wearable sleep tracking inequity: + a new dataset and novel methods for a population with sleep + disorders." CHIL 2024, PMLR 248:380-396. +""" + +import argparse +import os +import tempfile +import warnings +from typing import Any, Dict, List, Optional + +import numpy as np + +from pyhealth.datasets import ( + create_sample_dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import RNN +from pyhealth.tasks.sleep_staging_dreamt import ( + ALL_SIGNAL_COLUMNS, + SleepStagingDREAMT, +) +from pyhealth.trainer import Trainer + +warnings.filterwarnings("ignore", category=FutureWarning) + +EPOCH_LEN: int = 30 * 64 # 1920 samples per 30-s epoch at 64 Hz + +SIGNAL_SUBSETS: Dict[str, List[str]] = { + "ACC": ["ACC_X", "ACC_Y", "ACC_Z"], + "BVP_HRV": ["BVP", "HR", "IBI"], + "EDA_TEMP": ["EDA", "TEMP"], + "ALL": list(ALL_SIGNAL_COLUMNS), +} + + +def _epoch_features(signal: np.ndarray) -> List[float]: + """Convert a raw epoch signal to a compact feature vector. + + Computes mean, std, min, and max per channel. + + Args: + signal: Array of shape ``(n_channels, epoch_len)``. + + Returns: + Flat list of length ``4 * n_channels``. + """ + feats: List[float] = [] + for ch in range(signal.shape[0]): + s = signal[ch].astype(np.float64) + feats.extend([ + float(np.mean(s)), + float(np.std(s)), + float(np.min(s)), + float(np.max(s)), + ]) + return feats + + +def _prepare_samples( + raw_samples: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Transform task-output samples into feature-vector samples. + + Each raw sample has ``signal`` (n_channels, epoch_len). + This replaces it with a flat feature vector suitable for + PyHealth's ``TensorProcessor``. + + Args: + raw_samples: Output of ``SleepStagingDREAMT(patient)``. + + Returns: + List of dicts with ``patient_id``, ``features``, ``label``. + """ + return [ + { + "patient_id": s["patient_id"], + "features": _epoch_features(s["signal"]), + "label": s["label"], + } + for s in raw_samples + ] + + +def run_config( + raw_samples: List[Dict[str, Any]], + num_classes: int, + device: str = "cpu", + epochs: int = 5, + hidden_dim: int = 64, + split_ratios: Optional[List[float]] = None, +) -> Dict[str, float]: + """Run one ablation configuration with PyHealth RNN and Trainer. + + Args: + raw_samples: Epoch samples from ``SleepStagingDREAMT``. + num_classes: Number of classification classes (2 or 5). + device: Torch device string. + epochs: Training epochs. + hidden_dim: LSTM hidden dimension. + split_ratios: Patient-level train/val/test ratios. + + Returns: + Dictionary of evaluation metric scores. + """ + if split_ratios is None: + split_ratios = [0.7, 0.1, 0.2] + + samples = _prepare_samples(raw_samples) + + dataset = create_sample_dataset( + samples=samples, + input_schema={"features": "tensor"}, + output_schema={"label": "multiclass"}, + dataset_name="dreamt", + task_name="sleep_staging", + ) + + train_ds, val_ds, test_ds = split_by_patient( + dataset, split_ratios, seed=42, + ) + + train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False) + + model = RNN( + dataset=dataset, + embedding_dim=hidden_dim, + hidden_dim=hidden_dim, + rnn_type="LSTM", + num_layers=1, + dropout=0.0, + ) + + trainer = Trainer( + model=model, + metrics=["accuracy", "f1_macro", "cohen_kappa"], + device=device, + enable_logging=False, + ) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + ) + + results = trainer.evaluate(test_loader) + return results + + +# ----------------------------------------------------------- +# Demo data generation +# ----------------------------------------------------------- + + +def _generate_demo_csv( + tmpdir: str, + patient_id: str, + n_epochs: int, + rng: np.random.RandomState, +) -> str: + """Create one synthetic 64 Hz CSV file. + + Args: + tmpdir: Directory to write the CSV. + patient_id: Used in the filename. + n_epochs: Number of 30-s epochs to generate. + rng: Random state for reproducibility. + + Returns: + Path to the written CSV file. + """ + import pandas as pd + + stages_pool = ["W", "N1", "N2", "N3", "R"] + rows = n_epochs * EPOCH_LEN + data = { + "TIMESTAMP": np.arange(rows) / 64.0, + "BVP": rng.randn(rows) * 50, + "IBI": np.clip(rng.rand(rows) * 0.2 + 0.7, 0, 2), + "EDA": rng.rand(rows) * 5 + 0.1, + "TEMP": rng.rand(rows) * 4 + 33, + "ACC_X": rng.randn(rows) * 10, + "ACC_Y": rng.randn(rows) * 10, + "ACC_Z": rng.randn(rows) * 10, + "HR": rng.rand(rows) * 30 + 60, + } + stage_col = [] + for i in range(n_epochs): + st = stages_pool[i % len(stages_pool)] + stage_col.extend([st] * EPOCH_LEN) + data["Sleep_Stage"] = stage_col + + csv_path = os.path.join(tmpdir, f"{patient_id}_whole_df.csv") + pd.DataFrame(data).to_csv(csv_path, index=False) + return csv_path + + +def _generate_demo_samples( + n_classes: int = 2, + signal_columns: Optional[List[str]] = None, + n_patients: int = 6, + epochs_per_patient: int = 15, + seed: int = 123, +) -> List[Dict[str, Any]]: + """Create synthetic samples for demo mode. + + Args: + n_classes: Number of label classes (2, 3, or 5). + signal_columns: Which signal columns to include. + n_patients: Number of synthetic patients. + epochs_per_patient: 30-s epochs per patient. + seed: Random seed for reproducibility. + + Returns: + List of sample dicts ready for training. + """ + from types import SimpleNamespace + + rng = np.random.RandomState(seed) + all_samples: List[Dict[str, Any]] = [] + + with tempfile.TemporaryDirectory() as tmpdir: + for p in range(n_patients): + pid = f"DEMO_{p:03d}" + csv_path = _generate_demo_csv( + tmpdir, pid, epochs_per_patient, rng, + ) + + evt = SimpleNamespace(file_64hz=csv_path) + patient = SimpleNamespace( + patient_id=pid, + get_events=lambda et=None, e=evt: [e], + ) + + task = SleepStagingDREAMT( + n_classes=n_classes, + signal_columns=signal_columns, + apply_filters=False, + ) + all_samples.extend(task(patient)) + + return all_samples + + +# ----------------------------------------------------------- +# Ablation runners +# ----------------------------------------------------------- + +DEFAULT_ROOT = os.path.expanduser("~/.pyhealth/dreamt") + + +def _resolve_root(root_arg: Optional[str]) -> str: + """Find a valid DREAMT root, or exit with guidance. + + Args: + root_arg: User-supplied ``--root`` value, or None. + + Returns: + Absolute path to the DREAMT version directory. + + Raises: + SystemExit: If no valid directory is found. + """ + candidates = ( + [root_arg] + if root_arg + else [ + DEFAULT_ROOT, + os.path.expanduser("~/data/dreamt"), + os.path.expanduser("~/dreamt"), + ] + ) + for path in candidates: + if path and os.path.isdir(path): + info = os.path.join(path, "participant_info.csv") + if os.path.isfile(info): + return path + for sub in sorted(os.listdir(path)): + subpath = os.path.join(path, sub) + if os.path.isdir(subpath) and os.path.isfile( + os.path.join(subpath, "participant_info.csv") + ): + return subpath + print( + "ERROR: Could not find the DREAMT dataset.\n" + "\n" + "Download from PhysioNet (credentialed access):\n" + " https://physionet.org/content/dreamt/\n" + "\n" + "Then either:\n" + f" - Extract to {DEFAULT_ROOT}/\n" + " - Or pass --root /path/to/dreamt/version/\n" + "\n" + "The directory must contain participant_info.csv\n" + "and a data_64Hz/ folder with per-participant CSVs." + ) + raise SystemExit(1) + + +def _run_ablations_real(args: argparse.Namespace) -> None: + """Run ablations on the real DREAMT dataset. + + Args: + args: Parsed command-line arguments. + """ + from pyhealth.datasets import DREAMTDataset + + root = _resolve_root(args.root) + print(f"Loading DREAMT dataset from {root} ...") + dataset = DREAMTDataset(root=root) + + print("\n" + "=" * 60) + print("ABLATION 1: Signal Subset (2-class wake/sleep)") + print("=" * 60) + + for subset_name, columns in SIGNAL_SUBSETS.items(): + print(f"\n--- Signal subset: {subset_name} ---") + task = SleepStagingDREAMT( + n_classes=2, + signal_columns=columns, + ) + sample_ds = dataset.set_task(task) + raw = [sample_ds[i] for i in range(len(sample_ds))] + print(f" Total samples: {len(raw)}") + results = run_config( + raw, + num_classes=2, + epochs=args.epochs, + hidden_dim=args.hidden_dim, + device=args.device, + ) + print(f" Results: {results}") + + print("\n" + "=" * 60) + print("ABLATION 2: Label Granularity (ALL signals)") + print("=" * 60) + + for nc in [2, 5]: + print(f"\n--- {nc}-class ---") + task = SleepStagingDREAMT(n_classes=nc) + sd = dataset.set_task(task) + raw = [sd[i] for i in range(len(sd))] + print(f" Total samples: {len(raw)}") + results = run_config( + raw, + num_classes=nc, + epochs=args.epochs, + hidden_dim=args.hidden_dim, + device=args.device, + ) + print(f" Results: {results}") + + +def _run_ablations_demo(args: argparse.Namespace) -> None: + """Run ablations on synthetic demo data. + + Args: + args: Parsed command-line arguments. + """ + print("=== DEMO MODE (synthetic data) ===\n") + print("Generating 6 synthetic patients (15 epochs each) ...") + + demo_epochs = min(args.epochs, 2) + + print("\n" + "=" * 60) + print("ABLATION 1: Signal Subset (2-class, demo)") + print("=" * 60) + + for subset_name, columns in SIGNAL_SUBSETS.items(): + print(f"\n--- Signal subset: {subset_name} ---") + seed = abs(hash(subset_name)) % (2**31) + raw = _generate_demo_samples( + n_classes=2, + signal_columns=columns, + n_patients=6, + seed=seed, + ) + print(f" Total samples: {len(raw)}") + results = run_config( + raw, + num_classes=2, + epochs=demo_epochs, + hidden_dim=args.hidden_dim, + device=args.device, + split_ratios=[0.5, 0.17, 0.33], + ) + print(f" Results: {results}") + + print("\n" + "=" * 60) + print("ABLATION 2: Label Granularity (demo)") + print("=" * 60) + + for nc in [2, 5]: + print(f"\n--- {nc}-class ---") + raw = _generate_demo_samples( + n_classes=nc, n_patients=6, seed=123, + ) + print(f" Total samples: {len(raw)}") + results = run_config( + raw, + num_classes=nc, + epochs=demo_epochs, + hidden_dim=args.hidden_dim, + device=args.device, + split_ratios=[0.5, 0.17, 0.33], + ) + print(f" Results: {results}") + + print("\nDemo complete.") + + +def main() -> None: + """Entry point for the DREAMT sleep staging ablation study.""" + parser = argparse.ArgumentParser( + description="DREAMT sleep staging ablation (PyHealth RNN)", + ) + parser.add_argument( + "--root", + default=None, + help=( + "Path to DREAMT dataset. " + f"Default: {DEFAULT_ROOT}" + ), + ) + parser.add_argument( + "--demo", + action="store_true", + help=( + "Run with synthetic data instead of real " + "DREAMT (no dataset download required)." + ), + ) + parser.add_argument( + "--epochs", + type=int, + default=30, + help="Training epochs (default: 30)", + ) + parser.add_argument( + "--hidden_dim", + type=int, + default=64, + help="RNN hidden dimension (default: 64)", + ) + parser.add_argument( + "--device", + default="cpu", + help="Device: cpu or cuda (default: cpu)", + ) + args = parser.parse_args() + + if args.demo: + _run_ablations_demo(args) + else: + _run_ablations_real(args) + + +if __name__ == "__main__": + main() diff --git a/examples/dreamt_sleep_wake_lstm.py b/examples/dreamt_sleep_wake_lstm.py deleted file mode 100644 index a0b63fd30..000000000 --- a/examples/dreamt_sleep_wake_lstm.py +++ /dev/null @@ -1,735 +0,0 @@ -"""DREAMT Sleep Staging LSTM — Ablation Study. - -This script runs two ablation experiments on the DREAMT -sleep staging task using ``SleepStagingDREAMT``: - -1. **Signal-subset ablation** (binary wake/sleep): - ACC-only vs BVP/HRV-only vs EDA+TEMP-only vs ALL signals. - -2. **Label-granularity ablation** (ALL signals): - 2-class (wake/sleep) vs 5-class (W/N1/N2/N3/R). - -A single-layer unidirectional LSTM is trained with 5-fold -participant-level cross-validation (no subject leakage). - -Each 30-second epoch's raw multi-channel signal is reduced to -per-channel statistics (mean, std, min, max) to form a compact -feature vector suitable for the LSTM. - -Usage — full DREAMT run:: - - python dreamt_sleep_wake_lstm.py --root /path/to/dreamt - -Usage — synthetic demo (no dataset required):: - - python dreamt_sleep_wake_lstm.py --demo - -Metrics: F1, AUROC, Accuracy, Cohen's Kappa. - -Results / Findings ------------------- - -**Demo mode** (synthetic data, 3 patients, 2 training epochs): - -Results are non-meaningful and serve only to verify that the -full pipeline (epoching -> feature extraction -> LSTM training --> evaluation) runs end-to-end without error. Expected output -is near-random performance (F1 ~ 0.2-0.5, Kappa ~ 0). - -Reference: - Wang et al. "Addressing wearable sleep tracking inequity: - a new dataset and novel methods for a population with sleep - disorders." CHIL 2024, PMLR 248:380-396. -""" - -import argparse -import os -import tempfile -import warnings -from collections import defaultdict -from typing import Any, Dict, List, Optional - -import numpy as np -import torch -import torch.nn as nn -from sklearn.metrics import ( - accuracy_score, - cohen_kappa_score, - f1_score, - roc_auc_score, -) -from torch.utils.data import DataLoader, Dataset -from tqdm import tqdm - -from pyhealth.tasks.sleep_staging_dreamt import ( - ALL_SIGNAL_COLUMNS, - SleepStagingDREAMT, -) - -warnings.filterwarnings("ignore") - -EPOCH_LEN: int = 30 * 64 # 1920 samples per 30-s epoch at 64 Hz - -SIGNAL_SUBSETS: Dict[str, List[str]] = { - "ACC": ["ACC_X", "ACC_Y", "ACC_Z"], - "BVP_HRV": ["BVP", "HR", "IBI"], - "EDA_TEMP": ["EDA", "TEMP"], - "ALL": list(ALL_SIGNAL_COLUMNS), -} - - -def _epoch_features(signal: np.ndarray) -> np.ndarray: - """Convert a raw epoch signal to a compact feature vector. - - Computes mean, std, min, and max per channel. - - Args: - signal: Array of shape ``(n_channels, epoch_len)``. - - Returns: - 1-D feature vector of length ``4 * n_channels``. - """ - if isinstance(signal, torch.Tensor): - signal = signal.numpy() - feats: List[float] = [] - for ch in range(signal.shape[0]): - s = signal[ch].astype(np.float64) - feats.extend([ - float(np.mean(s)), - float(np.std(s)), - float(np.min(s)), - float(np.max(s)), - ]) - return np.array(feats, dtype=np.float32) - - -# ----------------------------------------------------------- -# LSTM model -# ----------------------------------------------------------- - - -class SleepLSTM(nn.Module): - """Single-layer unidirectional LSTM for sleep tasks.""" - - def __init__( - self, - input_dim: int, - hidden_dim: int = 64, - num_classes: int = 2, - ) -> None: - super().__init__() - self.lstm = nn.LSTM( - input_dim, - hidden_dim, - num_layers=1, - batch_first=True, - ) - self.fc = nn.Linear(hidden_dim, num_classes) - - def forward( - self, - x: torch.Tensor, - ) -> torch.Tensor: - """Forward pass. - - Args: - x: Input tensor ``(batch, seq_len, features)``. - - Returns: - Logits ``(batch, seq_len, num_classes)``. - """ - out, _ = self.lstm(x) - return self.fc(out) - - -# ----------------------------------------------------------- -# Dataset wrapper -# ----------------------------------------------------------- - - -class SequenceDataset(Dataset): - """Groups epoch samples by patient into sequences.""" - - def __init__( - self, - samples: List[Dict[str, Any]], - ) -> None: - patient_map: Dict[str, list] = defaultdict(list) - for s in samples: - patient_map[s["patient_id"]].append(s) - - self.sequences: List[torch.Tensor] = [] - self.labels_list: List[torch.Tensor] = [] - for pid in sorted(patient_map): - epochs = sorted( - patient_map[pid], - key=lambda e: e["epoch_index"], - ) - signals = np.stack( - [_epoch_features(e["signal"]) for e in epochs], - axis=0, - ) - labels = np.array( - [e["label"] for e in epochs], - ) - self.sequences.append( - torch.tensor(signals, dtype=torch.float32) - ) - self.labels_list.append( - torch.tensor(labels, dtype=torch.long) - ) - - def __len__(self) -> int: - return len(self.sequences) - - def __getitem__( - self, - idx: int, - ) -> tuple: - return self.sequences[idx], self.labels_list[idx] - - -def collate_fn( - batch: list, -) -> tuple: - """Pad variable-length sequences in a batch.""" - seqs, labels = zip(*batch) - max_len = max(s.shape[0] for s in seqs) - feat_dim = seqs[0].shape[1] - padded_seqs = torch.zeros( - len(seqs), - max_len, - feat_dim, - ) - padded_labels = torch.full( - (len(seqs), max_len), - -1, - dtype=torch.long, - ) - masks = torch.zeros( - len(seqs), - max_len, - dtype=torch.bool, - ) - for i, (s, lbl) in enumerate(zip(seqs, labels)): - length = s.shape[0] - padded_seqs[i, :length] = s - padded_labels[i, :length] = lbl - masks[i, :length] = True - return padded_seqs, padded_labels, masks - - -# ----------------------------------------------------------- -# Training / evaluation -# ----------------------------------------------------------- - - -def train_and_evaluate( - train_samples: List[Dict[str, Any]], - test_samples: List[Dict[str, Any]], - num_classes: int = 2, - epochs: int = 30, - lr: float = 1e-3, - hidden_dim: int = 64, - device: str = "cpu", -) -> Dict[str, float]: - """Train LSTM and return test metrics.""" - if not train_samples or not test_samples: - return {} - - n_channels = train_samples[0]["signal"].shape[0] - feat_dim = 4 * n_channels - - train_ds = SequenceDataset(train_samples) - test_ds = SequenceDataset(test_samples) - - train_loader = DataLoader( - train_ds, - batch_size=8, - shuffle=True, - collate_fn=collate_fn, - ) - test_loader = DataLoader( - test_ds, - batch_size=8, - shuffle=False, - collate_fn=collate_fn, - ) - - model = SleepLSTM( - feat_dim, - hidden_dim, - num_classes, - ).to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=lr) - criterion = nn.CrossEntropyLoss(ignore_index=-1) - - model.train() - for _ in tqdm(range(epochs), desc=" Training", unit="epoch", leave=False): - for seqs, labels, masks in train_loader: - seqs = seqs.to(device) - labels = labels.to(device) - optimizer.zero_grad() - logits = model(seqs) - loss = criterion( - logits.reshape(-1, num_classes), - labels.reshape(-1), - ) - loss.backward() - optimizer.step() - - model.eval() - all_preds: List[int] = [] - all_labels: List[int] = [] - all_probs: list = [] - with torch.no_grad(): - for seqs, labels, masks in test_loader: - seqs = seqs.to(device) - logits = model(seqs) - probs = torch.softmax(logits, dim=-1).cpu() - preds = logits.argmax(dim=-1).cpu() - for i in range(seqs.shape[0]): - valid = masks[i] - all_preds.extend(preds[i][valid].numpy().tolist()) - all_labels.extend(labels[i][valid].numpy().tolist()) - if num_classes == 2: - all_probs.extend( - probs[i][valid][:, 1].numpy().tolist() - ) - else: - all_probs.extend( - probs[i][valid].numpy().tolist() - ) - - y_true = np.array(all_labels) - y_pred = np.array(all_preds) - avg = "binary" if num_classes == 2 else "macro" - - results: Dict[str, float] = { - "f1": f1_score( - y_true, - y_pred, - average=avg, - zero_division=0, - ), - "accuracy": accuracy_score(y_true, y_pred), - "kappa": cohen_kappa_score(y_true, y_pred), - } - - try: - if num_classes == 2: - results["auroc"] = roc_auc_score( - y_true, - np.array(all_probs), - ) - else: - results["auroc"] = roc_auc_score( - y_true, - np.array(all_probs), - multi_class="ovr", - average="macro", - ) - except ValueError: - results["auroc"] = float("nan") - - return results - - -def participant_cv( - samples: List[Dict[str, Any]], - n_folds: int = 5, - num_classes: int = 2, - **kwargs: Any, -) -> Dict[str, str]: - """5-fold participant-level cross-validation.""" - patient_ids = sorted(set(s["patient_id"] for s in samples)) - np.random.seed(42) - np.random.shuffle(patient_ids) - - fold_size = max(1, len(patient_ids) // n_folds) - fold_results: List[Dict[str, float]] = [] - - for fold in tqdm(range(n_folds), desc=" CV folds", unit="fold"): - start = fold * fold_size - end = ( - start + fold_size - if fold < n_folds - 1 - else len(patient_ids) - ) - test_ids = set(patient_ids[start:end]) - train_ids = set(patient_ids) - test_ids - - if not train_ids or not test_ids: - continue - - train_s = [ - s for s in samples if s["patient_id"] in train_ids - ] - test_s = [ - s for s in samples if s["patient_id"] in test_ids - ] - - res = train_and_evaluate( - train_s, - test_s, - num_classes=num_classes, - **kwargs, - ) - if res: - fold_results.append(res) - tqdm.write( - f" Fold {fold + 1}: " - f"F1={res['f1']:.3f} " - f"AUROC={res['auroc']:.3f} " - f"Acc={res['accuracy']:.3f} " - f"Kappa={res['kappa']:.3f}" - ) - - if not fold_results: - return {} - - avg: Dict[str, str] = {} - for key in fold_results[0]: - vals = [r[key] for r in fold_results] - avg[key] = f"{np.mean(vals):.3f} +/- {np.std(vals):.3f}" - return avg - - -# ----------------------------------------------------------- -# Demo data generation -# ----------------------------------------------------------- - - -def _generate_demo_csv( - tmpdir: str, - patient_id: str, - n_epochs: int, - rng: np.random.RandomState, -) -> str: - """Create one synthetic 64 Hz CSV file. - - Args: - tmpdir: Directory to write the CSV. - patient_id: Used in the filename. - n_epochs: Number of 30-s epochs to generate. - rng: Random state for reproducibility. - - Returns: - Path to the written CSV file. - """ - import pandas as pd - - stages_pool = ["W", "N1", "N2", "N3", "R"] - rows = n_epochs * EPOCH_LEN - data = { - "TIMESTAMP": np.arange(rows) / 64.0, - "BVP": rng.randn(rows) * 50, - "IBI": np.clip(rng.rand(rows) * 0.2 + 0.7, 0, 2), - "EDA": rng.rand(rows) * 5 + 0.1, - "TEMP": rng.rand(rows) * 4 + 33, - "ACC_X": rng.randn(rows) * 10, - "ACC_Y": rng.randn(rows) * 10, - "ACC_Z": rng.randn(rows) * 10, - "HR": rng.rand(rows) * 30 + 60, - } - stage_col = [] - for i in range(n_epochs): - st = stages_pool[i % len(stages_pool)] - stage_col.extend([st] * EPOCH_LEN) - data["Sleep_Stage"] = stage_col - - csv_path = os.path.join(tmpdir, f"{patient_id}_whole_df.csv") - pd.DataFrame(data).to_csv(csv_path, index=False) - return csv_path - - -def _generate_demo_samples( - n_classes: int = 2, - signal_columns: Optional[List[str]] = None, - n_patients: int = 3, - epochs_per_patient: int = 20, - seed: int = 123, -) -> List[Dict[str, Any]]: - """Create synthetic samples for demo mode. - - Args: - n_classes: Number of label classes (2, 3, or 5). - signal_columns: Which signal columns to include. - n_patients: Number of synthetic patients. - epochs_per_patient: 30-s epochs per patient. - seed: Random seed for reproducibility. - - Returns: - List of sample dicts ready for training. - """ - from types import SimpleNamespace - - rng = np.random.RandomState(seed) - all_samples: List[Dict[str, Any]] = [] - - with tempfile.TemporaryDirectory() as tmpdir: - for p in range(n_patients): - pid = f"DEMO_{p:03d}" - csv_path = _generate_demo_csv( - tmpdir, pid, epochs_per_patient, rng, - ) - - evt = SimpleNamespace(file_64hz=csv_path) - patient = SimpleNamespace( - patient_id=pid, - get_events=lambda et=None, e=evt: [e], - ) - - task = SleepStagingDREAMT( - n_classes=n_classes, - signal_columns=signal_columns, - apply_filters=False, - ) - all_samples.extend(task(patient)) - - return all_samples - - -# ----------------------------------------------------------- -# Main -# ----------------------------------------------------------- - -DEFAULT_ROOT = os.path.expanduser("~/.pyhealth/dreamt") - - -def _resolve_root( - root_arg: Optional[str], -) -> str: - """Find a valid DREAMT root, or exit with guidance. - - Args: - root_arg: User-supplied ``--root`` value, or None. - - Returns: - Absolute path to the DREAMT version directory. - - Raises: - SystemExit: If no valid directory is found. - """ - candidates = ( - [root_arg] - if root_arg - else [ - DEFAULT_ROOT, - os.path.expanduser("~/data/dreamt"), - os.path.expanduser("~/dreamt"), - ] - ) - for path in candidates: - if path and os.path.isdir(path): - info = os.path.join(path, "participant_info.csv") - if os.path.isfile(info): - return path - for sub in sorted(os.listdir(path)): - subpath = os.path.join(path, sub) - if os.path.isdir(subpath) and os.path.isfile( - os.path.join(subpath, "participant_info.csv") - ): - return subpath - print( - "ERROR: Could not find the DREAMT dataset.\n" - "\n" - "Download from PhysioNet (credentialed access):\n" - " https://physionet.org/content/dreamt/\n" - "\n" - "Then either:\n" - f" - Extract to {DEFAULT_ROOT}/\n" - " - Or pass --root /path/to/dreamt/version/\n" - "\n" - "The directory must contain participant_info.csv\n" - "and a data_64Hz/ folder with per-participant CSVs." - ) - raise SystemExit(1) - - -def _run_ablations_real(args: argparse.Namespace) -> None: - """Run ablations on the real DREAMT dataset. - - Args: - args: Parsed command-line arguments. - """ - from pyhealth.datasets import DREAMTDataset - - root = _resolve_root(args.root) - print(f"Loading DREAMT dataset from {root} ...") - dataset = DREAMTDataset(root=root) - - print("\n" + "=" * 60) - print("ABLATION 1: Signal Subset (2-class wake/sleep)") - print("=" * 60) - - for subset_name, columns in tqdm(SIGNAL_SUBSETS.items(), desc="Signal subsets", unit="subset"): - tqdm.write(f"\n--- Signal subset: {subset_name} ---") - task = SleepStagingDREAMT( - n_classes=2, - signal_columns=columns, - ) - sample_ds = dataset.set_task(task) - samples = [sample_ds[i] for i in tqdm(range(len(sample_ds)), desc=" Loading samples", leave=False)] - tqdm.write(f" Total samples: {len(samples)}") - avg = participant_cv( - samples, - num_classes=2, - epochs=args.epochs, - hidden_dim=args.hidden_dim, - device=args.device, - ) - tqdm.write(f" Average: {avg}") - - print("\n" + "=" * 60) - print("ABLATION 2: Label Granularity (ALL signals)") - print("=" * 60) - - print("\n--- 2-class (wake vs sleep) ---") - task_2 = SleepStagingDREAMT(n_classes=2) - sd_2 = dataset.set_task(task_2) - samps_2 = [sd_2[i] for i in tqdm(range(len(sd_2)), desc=" Loading samples", leave=False)] - avg_2 = participant_cv( - samps_2, - num_classes=2, - epochs=args.epochs, - hidden_dim=args.hidden_dim, - device=args.device, - ) - print(f" Average: {avg_2}") - - print("\n--- 5-class (W/N1/N2/N3/R) ---") - task_5 = SleepStagingDREAMT(n_classes=5) - sd_5 = dataset.set_task(task_5) - samps_5 = [sd_5[i] for i in tqdm(range(len(sd_5)), desc=" Loading samples", leave=False)] - avg_5 = participant_cv( - samps_5, - num_classes=5, - epochs=args.epochs, - hidden_dim=args.hidden_dim, - device=args.device, - ) - print(f" Average: {avg_5}") - - -def _run_ablations_demo(args: argparse.Namespace) -> None: - """Run ablations on synthetic demo data. - - Args: - args: Parsed command-line arguments. - """ - print("=== DEMO MODE (synthetic data) ===\n") - print("Generating 3 synthetic patients (20 epochs each) ...") - - demo_epochs = min(args.epochs, 2) - - print("\n" + "=" * 60) - print("ABLATION 1: Signal Subset (2-class, demo)") - print("=" * 60) - - for subset_name, columns in tqdm(SIGNAL_SUBSETS.items(), desc="Signal subsets", unit="subset"): - tqdm.write(f"\n--- Signal subset: {subset_name} ---") - seed = abs(hash(subset_name)) % (2**31) - sub_samples = _generate_demo_samples( - n_classes=2, - signal_columns=columns, - n_patients=3, - seed=seed, - ) - tqdm.write(f" Total samples: {len(sub_samples)}") - n_pids = len(set(s["patient_id"] for s in sub_samples)) - avg = participant_cv( - sub_samples, - n_folds=min(3, n_pids), - num_classes=2, - epochs=demo_epochs, - hidden_dim=args.hidden_dim, - device=args.device, - ) - tqdm.write(f" Average: {avg}") - - print("\n" + "=" * 60) - print("ABLATION 2: Label Granularity (demo)") - print("=" * 60) - - print("\n--- 2-class (wake vs sleep) ---") - samples_2 = _generate_demo_samples( - n_classes=2, n_patients=3, seed=123, - ) - n_pids_2 = len(set(s["patient_id"] for s in samples_2)) - avg_2 = participant_cv( - samples_2, - n_folds=min(3, n_pids_2), - num_classes=2, - epochs=demo_epochs, - hidden_dim=args.hidden_dim, - device=args.device, - ) - print(f" Average: {avg_2}") - - print("\n--- 5-class (W/N1/N2/N3/R) ---") - samples_5 = _generate_demo_samples( - n_classes=5, n_patients=3, seed=123, - ) - n_pids_5 = len(set(s["patient_id"] for s in samples_5)) - avg_5 = participant_cv( - samples_5, - n_folds=min(3, n_pids_5), - num_classes=5, - epochs=demo_epochs, - hidden_dim=args.hidden_dim, - device=args.device, - ) - print(f" Average: {avg_5}") - - print("\nDemo complete.") - - -def main() -> None: - """Entry point for the DREAMT LSTM ablation study.""" - parser = argparse.ArgumentParser( - description="DREAMT LSTM ablation study", - ) - parser.add_argument( - "--root", - default=None, - help=( - "Path to DREAMT dataset. " - f"Default: {DEFAULT_ROOT}" - ), - ) - parser.add_argument( - "--demo", - action="store_true", - help=( - "Run with synthetic data instead of real " - "DREAMT (no dataset download required)." - ), - ) - parser.add_argument( - "--epochs", - type=int, - default=30, - help="Training epochs per fold", - ) - parser.add_argument( - "--hidden_dim", - type=int, - default=64, - help="LSTM hidden dimension", - ) - parser.add_argument( - "--device", - default="cpu", - help="Device (cpu or cuda)", - ) - args = parser.parse_args() - - if args.demo: - _run_ablations_demo(args) - else: - _run_ablations_real(args) - - -if __name__ == "__main__": - main() diff --git a/pyhealth/tasks/sleep_staging_dreamt.py b/pyhealth/tasks/sleep_staging_dreamt.py index 330674903..7774ea8e8 100644 --- a/pyhealth/tasks/sleep_staging_dreamt.py +++ b/pyhealth/tasks/sleep_staging_dreamt.py @@ -280,7 +280,7 @@ def __call__( try: events = patient.get_events(event_type="dreamt_sleep") - except Exception: + except (TypeError, KeyError): events = patient.get_events() if not events: @@ -295,7 +295,7 @@ def __call__( try: df = pd.read_csv(str(file_path)) - except Exception as exc: + except (FileNotFoundError, pd.errors.EmptyDataError, OSError) as exc: logger.warning("Could not read %s: %s", file_path, exc) return [] diff --git a/tests/core/test_sleep_staging_dreamt.py b/tests/core/test_sleep_staging_dreamt.py index 80812f532..c2c8671ab 100644 --- a/tests/core/test_sleep_staging_dreamt.py +++ b/tests/core/test_sleep_staging_dreamt.py @@ -14,9 +14,6 @@ from pyhealth.tasks.sleep_staging_dreamt import ( ALL_SIGNAL_COLUMNS, - LABEL_MAP_2CLASS, - LABEL_MAP_3CLASS, - LABEL_MAP_5CLASS, SleepStagingDREAMT, ) From 35951987ed509c50a653e61453d07219886f3812 Mon Sep 17 00:00:00 2001 From: "Erie M. Adames" Date: Thu, 9 Apr 2026 16:44:09 -0400 Subject: [PATCH 3/3] minor fix --- pyhealth/tasks/sleep_staging_dreamt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/tasks/sleep_staging_dreamt.py b/pyhealth/tasks/sleep_staging_dreamt.py index 7774ea8e8..21477e8cd 100644 --- a/pyhealth/tasks/sleep_staging_dreamt.py +++ b/pyhealth/tasks/sleep_staging_dreamt.py @@ -184,7 +184,7 @@ class SleepStagingDREAMT(BaseTask): - **BVP**: Chebyshev Type II bandpass, 0.5-20 Hz - **TEMP**: Winsorized to [31, 40] C - Epochs labeled ``"P"`` (preparation) or ``"Missing"`` are excluded. + Epochs labeled ``"P"`` or ``"Missing"`` are excluded. Attributes: task_name: ``"SleepStagingDREAMT"``