diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..676bfe009 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -204,3 +204,4 @@ API Reference models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs + models/pyhealth.models.td_lstm_mortality \ No newline at end of file diff --git a/docs/api/models/pyhealth.models.td_lstm_mortality.rst b/docs/api/models/pyhealth.models.td_lstm_mortality.rst new file mode 100644 index 000000000..59a907f91 --- /dev/null +++ b/docs/api/models/pyhealth.models.td_lstm_mortality.rst @@ -0,0 +1,155 @@ +TDLSTMMortality +================ + +.. currentmodule:: pyhealth.models.td_lstm_mortality + +.. autoclass:: TDLSTMMortality + :members: + :undoc-members: + :show-inheritance: + +Overview +-------- + +``TDLSTMMortality`` is a PyHealth-style reproduction model for ICU mortality +prediction inspired by the paper: + +Frost, Li, and Harris. *Robust Real-Time Mortality Prediction in the ICU using +Temporal Difference Learning* (ML4H 2024). + +This implementation provides a lightweight and contribution-friendly version of +the original idea using: + +- an LSTM encoder over fixed-length time-series features +- binary mortality prediction +- supervised training with terminal binary cross-entropy loss +- temporal-difference (TD) training with bootstrapped future targets + +Compared with the original paper, this implementation intentionally simplifies +the architecture to improve reproducibility and compatibility with the PyHealth +model contribution workflow. 
+ +Key Features +------------ + +- Supports ``training_mode="supervised"`` for standard binary mortality prediction +- Supports ``training_mode="td"`` for temporal-difference learning +- Uses a PyHealth-compatible ``BaseModel`` interface +- Accepts schema-based sample datasets created with ``create_sample_dataset`` +- Returns standard PyHealth outputs including ``loss``, ``y_prob``, + ``y_true``, and ``logit`` + +Constructor +----------- + +.. code-block:: python + + from pyhealth.models.td_lstm_mortality import TDLSTMMortality + + model = TDLSTMMortality( + dataset=dataset, + feature_key="x", + label_key="label", + mode="binary", + hidden_dim=32, + gamma=0.95, + alpha_terminal=0.10, + n_step=1, + training_mode="td", + ) + +Parameters +---------- + +- ``dataset``: + PyHealth sample dataset used to infer input/output structure. +- ``feature_key``: + Key for the time-series feature tensor. +- ``label_key``: + Key for the binary mortality label. +- ``mode``: + Currently only ``"binary"`` is supported. +- ``hidden_dim``: + Hidden size of the LSTM encoder. +- ``num_layers``: + Number of LSTM layers. +- ``dropout``: + Dropout rate used when ``num_layers > 1``. +- ``gamma``: + Discount factor for TD target construction. +- ``alpha_terminal``: + Weight for the terminal supervised anchor loss in TD mode. +- ``n_step``: + Number of future steps used in TD target bootstrapping. +- ``lengths_key``: + Optional key for sequence lengths when variable-length sequences are used. +- ``embedding_dim``: + Reserved embedding dimension argument for compatibility/future extension. +- ``training_mode``: + Either ``"supervised"`` or ``"td"``. + +Input Format +------------ + +The model expects batched time-series input with shape ``[B, T, F]`` after +PyHealth collation. + +For schema-based synthetic/sample datasets, the raw per-sample format can be: + +.. 
code-block:: python + + { + "patient_id": "p1", + "visit_id": "v1", + "x": [timestamps, values], + "label": 1, + } + +where: + +- ``timestamps`` is a list of Python ``datetime`` objects +- ``values`` is a list of length-``T`` feature vectors + +Output +------ + +The forward pass returns a dictionary with keys such as: + +- ``loss``: scalar training loss +- ``y_prob``: final mortality probability +- ``y_true``: binary ground-truth label +- ``logit``: final prediction logit +- ``logits_seq``: per-time-step logits +- ``probs_seq``: per-time-step probabilities + +In TD mode, the output also includes: + +- ``td_loss``: temporal-difference loss term +- ``terminal_loss``: supervised terminal BCE anchor + +Example +------- + +See the runnable example script: + +``examples/mimic4_mortality_td_lstm.py`` + +This example demonstrates: + +- synthetic ICU-style time-series sample generation +- train/validation/test split by patient +- supervised LSTM benchmark training +- TD-learning ablation sweep across discount factors +- final metric reporting with AUROC, F1, recall, and balanced accuracy + +Notes +----- + +This reproduction is aligned with a course project focused on implementing a +PyHealth model contribution based on a published healthcare ML paper. 
def make_hourly_timestamps(
    time_steps: int,
    start: datetime | None = None,
) -> List[datetime]:
    """Return ``time_steps`` timestamps spaced exactly one hour apart.

    Args:
        time_steps: Number of timestamps to generate.
        start: First timestamp; defaults to 2024-01-01 00:00:00.
    """
    anchor = datetime(2024, 1, 1, 0, 0, 0) if start is None else start
    one_hour = timedelta(hours=1)
    stamps: List[datetime] = []
    for step in range(time_steps):
        stamps.append(anchor + step * one_hour)
    return stamps
def tune_threshold_from_probs(
    y_true: np.ndarray,
    probs: np.ndarray,
    thresholds: np.ndarray | None = None,
) -> float:
    """Pick the decision threshold that maximizes F1 on the given labels.

    Scans ``thresholds`` (default: 0.10 to 0.90 in steps of 0.05); ties are
    broken in favor of the earliest threshold scanned.
    """
    grid = np.arange(0.10, 0.91, 0.05) if thresholds is None else thresholds

    best = (-1.0, 0.50)  # (f1, threshold)
    for candidate in grid:
        hard_preds = (probs >= candidate).astype(int)
        candidate_f1 = f1_score(y_true, hard_preds, zero_division=0)
        if candidate_f1 > best[0]:
            best = (candidate_f1, float(candidate))

    return best[1]
def get_supervised_probs(
    model: TDLSTMMortality,
    loader,
) -> Tuple[np.ndarray, np.ndarray]:
    """Run the model over ``loader`` and gather (labels, probabilities).

    The model is switched to eval mode and no gradients are tracked; the
    per-batch ``y_true`` / ``y_prob`` outputs are concatenated into flat
    numpy arrays.
    """
    model.eval()
    label_chunks: List[np.ndarray] = []
    prob_chunks: List[np.ndarray] = []

    with torch.no_grad():
        for batch in loader:
            output = model(**batch)
            label_chunks.append(output["y_true"].detach().cpu().numpy())
            prob_chunks.append(output["y_prob"].detach().cpu().numpy())

    return np.concatenate(label_chunks), np.concatenate(prob_chunks)
def train_supervised(
    train_dataset,
    val_dataset,
    hidden_dim: int = 32,
    lr: float = 1e-3,
    num_epochs: int = 15,
    batch_size: int = 16,
) -> Dict[str, object]:
    """Trains a supervised benchmark model.

    Trains ``TDLSTMMortality`` in ``supervised`` mode with Adam, keeps the
    state dict with the lowest mean validation loss across epochs, restores
    it, then tunes the decision threshold for best validation F1.

    Args:
        train_dataset: PyHealth sample dataset used for training.
        val_dataset: PyHealth sample dataset used for model selection.
        hidden_dim: LSTM hidden size.
        lr: Adam learning rate.
        num_epochs: Number of training epochs.
        batch_size: Mini-batch size for both loaders.

    Returns:
        Dict with keys ``model`` (best checkpoint loaded), ``best_threshold``
        (validation-tuned F1 threshold), and ``val_metrics``.
    """
    train_loader = get_dataloader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = get_dataloader(val_dataset, batch_size=batch_size, shuffle=False)

    model = TDLSTMMortality(
        dataset=train_dataset,
        feature_key="x",
        label_key="label",
        mode="binary",
        hidden_dim=hidden_dim,
        training_mode="supervised",
    ).to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Track the best checkpoint by mean validation loss (early-stopping style).
    best_state = None
    best_val_loss = float("inf")

    for _ in range(num_epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(**batch)
            loss = out["loss"]
            loss.backward()
            optimizer.step()

        # Evaluate validation loss once per epoch.
        model.eval()
        running = 0.0
        steps = 0
        with torch.no_grad():
            for batch in val_loader:
                out = model(**batch)
                running += float(out["loss"].item())
                steps += 1

        val_loss = running / max(steps, 1)  # guard against an empty loader
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # deepcopy so later epochs do not mutate the snapshot in place
            best_state = copy.deepcopy(model.state_dict())

    if best_state is not None:
        model.load_state_dict(best_state)

    # Tune the decision threshold on validation F1, then report metrics at it.
    val_labels, val_probs = get_supervised_probs(model, val_loader)
    best_threshold = tune_threshold_from_probs(val_labels, val_probs)
    val_metrics = evaluate_binary(val_labels, val_probs, best_threshold)

    return {
        "model": model,
        "best_threshold": best_threshold,
        "val_metrics": val_metrics,
    }
def print_results_table(rows: List[Dict[str, object]]) -> None:
    """Print rows as a fixed-width, pipe-separated comparison table.

    Missing cells are rendered as ``-`` and floats with four decimals; each
    column is padded to the widest value it contains (header included).
    """
    columns = [
        "method",
        "gamma",
        "alpha_terminal",
        "n_step",
        "val_auroc",
        "test_auroc",
        "test_f1",
        "test_recall",
        "test_balanced_accuracy",
        "threshold",
    ]

    def render(cell) -> str:
        # Missing values print as "-"; floats are shown with 4 decimals.
        if cell is None:
            return "-"
        if isinstance(cell, float):
            return f"{cell:.4f}"
        return str(cell)

    widths = {
        col: max([len(col)] + [len(render(row.get(col))) for row in rows])
        for col in columns
    }

    print(" | ".join(col.ljust(widths[col]) for col in columns))
    print("-+-".join("-" * widths[col] for col in columns))
    for row in rows:
        print(" | ".join(render(row.get(col)).ljust(widths[col]) for col in columns))
+ + train_dataset, val_dataset, test_dataset = split_by_patient( + dataset, + [0.6, 0.2, 0.2], + seed=42, + ) + + print("Dataset sizes:") + print("Train:", len(train_dataset)) + print("Val:", len(val_dataset)) + print("Test:", len(test_dataset)) + + test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False) + + print("\n=== Supervised benchmark ===") + supervised_result = train_supervised( + train_dataset=train_dataset, + val_dataset=val_dataset, + hidden_dim=32, + lr=1e-3, + num_epochs=15, + batch_size=16, + ) + supervised_model = supervised_result["model"] + supervised_threshold = float(supervised_result["best_threshold"]) + + test_labels_sup, test_probs_sup = get_supervised_probs(supervised_model, test_loader) + supervised_test_metrics = evaluate_binary( + test_labels_sup, + test_probs_sup, + supervised_threshold, + ) + print("Supervised validation metrics:", supervised_result["val_metrics"]) + print("Supervised test metrics:", supervised_test_metrics) + + print("\n=== TD ablation sweep ===") + td_configs = [ + {"gamma": 0.90, "alpha_terminal": 0.10, "n_step": 1}, + {"gamma": 0.95, "alpha_terminal": 0.10, "n_step": 1}, + {"gamma": 0.99, "alpha_terminal": 0.10, "n_step": 1}, + ] + + summary_rows: List[Dict[str, object]] = [ + { + "method": "supervised", + "gamma": None, + "alpha_terminal": None, + "n_step": None, + "val_auroc": supervised_result["val_metrics"]["auroc"], + "test_auroc": supervised_test_metrics["auroc"], + "test_f1": supervised_test_metrics["f1"], + "test_recall": supervised_test_metrics["recall"], + "test_balanced_accuracy": supervised_test_metrics["balanced_accuracy"], + "threshold": supervised_threshold, + } + ] + + best_td_result = None + best_td_test_metrics = None + best_td_threshold = None + + for cfg in td_configs: + result = train_td( + train_dataset=train_dataset, + val_dataset=val_dataset, + hidden_dim=32, + lr=1e-3, + num_epochs=15, + batch_size=16, + gamma=cfg["gamma"], + alpha_terminal=cfg["alpha_terminal"], + 
n_step=cfg["n_step"], + target_update_every=2, + ) + + model = result["model"] + threshold = float(result["best_threshold"]) + test_labels_td, test_probs_td = get_td_probs(model, test_loader, eval_step=-1) + td_test_metrics = evaluate_binary(test_labels_td, test_probs_td, threshold) + + row = { + "method": "td", + "gamma": cfg["gamma"], + "alpha_terminal": cfg["alpha_terminal"], + "n_step": cfg["n_step"], + "val_auroc": result["val_metrics"]["auroc"], + "test_auroc": td_test_metrics["auroc"], + "test_f1": td_test_metrics["f1"], + "test_recall": td_test_metrics["recall"], + "test_balanced_accuracy": td_test_metrics["balanced_accuracy"], + "threshold": threshold, + } + summary_rows.append(row) + print("TD result:", row) + + if ( + best_td_test_metrics is None + or td_test_metrics["auroc"] > best_td_test_metrics["auroc"] + ): + best_td_result = result + best_td_test_metrics = td_test_metrics + best_td_threshold = threshold + + print("\n=== Final comparison table ===") + print_results_table(summary_rows) + + print("\n=== Main result statement ===") + print( + f"Best supervised test AUROC: {supervised_test_metrics['auroc']:.4f}\n" + f"Best TD test AUROC: {best_td_test_metrics['auroc']:.4f}\n" + "Interpretation: the supervised LSTM remains the strongest overall " + "benchmark, while the best 1-step TD configuration is the main TD " + "reproduction result." 
+ ) + + print("\n=== Best TD hour-wise analysis ===") + hour_rows = evaluate_td_over_time( + model=best_td_result["model"], + loader=test_loader, + hours=[0, 5, 11, 23], + threshold=best_td_threshold, + ) + for row in hour_rows: + print(row) + + print( + "\nExpected project-aligned interpretation:\n" + "- supervised LSTM remains the strongest overall benchmark\n" + "- tuned 1-step TD is the main TD result\n" + "- later TD variants are exploratory\n" + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/models/td_lstm_mortality.py b/pyhealth/models/td_lstm_mortality.py new file mode 100644 index 000000000..d82b16f06 --- /dev/null +++ b/pyhealth/models/td_lstm_mortality.py @@ -0,0 +1,392 @@ +from __future__ import annotations + +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pyhealth.models import BaseModel + + +class TDLSTMMortality(BaseModel): + """Temporal-Difference LSTM model for ICU mortality prediction. + + This model is a simplified PyHealth-native reproduction of the temporal + difference learning idea described in: + + Frost, Li, and Harris. "Robust Real-Time Mortality Prediction in the ICU + using Temporal Difference Learning." ML4H 2024. 
+ + Compared with the original paper, this implementation intentionally keeps + the architecture lightweight for reproducibility and contribution readiness: + - LSTM encoder over fixed-length time-series features + - binary mortality prediction + - supervised mode using terminal BCE loss + - TD mode using bootstrapped future predictions plus terminal BCE anchor + + The model follows the standard PyHealth BaseModel pattern: + - takes a SampleDataset in the constructor + - uses feature_key and label_key selected from dataset schemas + - returns a dictionary containing at least: + loss, y_prob, y_true, logit + """ + + VALID_TRAINING_MODES = {"td", "supervised"} + + def __init__( + self, + dataset, + feature_key: str, + label_key: str, + mode: str = "binary", + hidden_dim: int = 64, + num_layers: int = 1, + dropout: float = 0.0, + gamma: float = 0.95, + alpha_terminal: float = 0.10, + n_step: int = 1, + lengths_key: Optional[str] = None, + embedding_dim: int = 128, + **kwargs, + ) -> None: + super().__init__(dataset=dataset) + + if hidden_dim <= 0: + raise ValueError("hidden_dim must be positive.") + if num_layers <= 0: + raise ValueError("num_layers must be positive.") + if not 0.0 <= dropout < 1.0: + raise ValueError("dropout must be in [0, 1).") + if not 0.0 <= gamma <= 1.0: + raise ValueError("gamma must be in [0, 1].") + if alpha_terminal < 0.0: + raise ValueError("alpha_terminal must be non-negative.") + if n_step <= 0: + raise ValueError("n_step must be positive.") + if mode != "binary": + raise ValueError( + "TDLSTMMortality currently only supports binary classification." 
+ ) + + if feature_key not in self.feature_keys: + raise ValueError( + f"feature_key '{feature_key}' not found in dataset feature_keys: " + f"{self.feature_keys}" + ) + if label_key not in self.label_keys: + raise ValueError( + f"label_key '{label_key}' not found in dataset label_keys: " + f"{self.label_keys}" + ) + + self.feature_key = feature_key + self.label_key = label_key + self.lengths_key = lengths_key + self.task_mode = mode + + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.dropout = dropout + self.gamma = gamma + self.alpha_terminal = alpha_terminal + self.n_step = n_step + self.embedding_dim = embedding_dim + + self.training_mode = kwargs.pop("training_mode", "td") + if self.training_mode not in self.VALID_TRAINING_MODES: + raise ValueError( + f"training_mode must be one of {self.VALID_TRAINING_MODES}." + ) + + self.input_dim = self._infer_input_dim_from_dataset() + + self.lstm = nn.LSTM( + input_size=self.input_dim, + hidden_size=self.hidden_dim, + num_layers=self.num_layers, + batch_first=True, + dropout=self.dropout if self.num_layers > 1 else 0.0, + ) + self.output_layer = nn.Linear(self.hidden_dim, self.get_output_size()) + + def _infer_input_dim_from_dataset(self) -> int: + """Infers feature dimension from dataset samples.""" + try: + sample = self.dataset[0] + if self.feature_key in sample: + x = sample[self.feature_key] + if torch.is_tensor(x): + if x.ndim == 2: + return int(x.shape[-1]) + raise ValueError( + f"Processed feature '{self.feature_key}' must be 2D [T, F], " + f"got tensor shape {tuple(x.shape)}." 
+ ) + except Exception: + pass + + raw_samples = getattr(self.dataset, "samples", None) + if raw_samples is not None and len(raw_samples) > 0: + raw_x = raw_samples[0][self.feature_key] + + if ( + isinstance(raw_x, (list, tuple)) + and len(raw_x) == 2 + and isinstance(raw_x[1], (list, tuple)) + and len(raw_x[1]) > 0 + ): + first_row = raw_x[1][0] + if isinstance(first_row, (list, tuple)): + return len(first_row) + + if isinstance(raw_x, (list, tuple)) and len(raw_x) > 0: + first_row = raw_x[0] + if isinstance(first_row, (list, tuple)): + return len(first_row) + + raise ValueError( + f"Unable to infer input_dim for feature_key '{self.feature_key}' " + "from dataset." + ) + + def _get_lengths_from_kwargs(self, kwargs: Dict) -> Optional[torch.Tensor]: + """Gets sequence lengths from kwargs if a lengths_key is configured.""" + if self.lengths_key is None: + return None + lengths = kwargs.get(self.lengths_key) + if lengths is None: + return None + if not torch.is_tensor(lengths): + lengths = torch.tensor(lengths, dtype=torch.long, device=self.device) + else: + lengths = lengths.to(self.device) + return lengths + + def _prepare_input_tensor(self, x) -> torch.Tensor: + """Converts batch feature input into a float tensor on model device.""" + if torch.is_tensor(x): + x_tensor = x.float().to(self.device) + else: + x_tensor = torch.tensor(x, dtype=torch.float32, device=self.device) + + if x_tensor.ndim != 3: + raise ValueError( + f"Expected batched input of shape [B, T, F], got {tuple(x_tensor.shape)}." 
+ ) + return x_tensor + + def _prepare_labels(self, y) -> torch.Tensor: + """Converts labels into float tensor on model device.""" + if torch.is_tensor(y): + y_tensor = y.float().to(self.device) + else: + y_tensor = torch.tensor(y, dtype=torch.float32, device=self.device) + + if y_tensor.ndim > 1: + y_tensor = y_tensor.view(-1) + return y_tensor + + def encode( + self, + x: torch.Tensor, + lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Encodes time-series input with an LSTM.""" + if lengths is None: + encoded, _ = self.lstm(x) + return encoded + + packed = nn.utils.rnn.pack_padded_sequence( + x, + lengths=lengths.cpu(), + batch_first=True, + enforce_sorted=False, + ) + packed_out, _ = self.lstm(packed) + encoded, _ = nn.utils.rnn.pad_packed_sequence( + packed_out, + batch_first=True, + total_length=x.shape[1], + ) + return encoded + + def forward_logits( + self, + x: torch.Tensor, + lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Computes per-time-step logits.""" + encoded = self.encode(x, lengths=lengths) + logits = self.output_layer(encoded).squeeze(-1) + return logits + + @staticmethod + def _gather_last_valid_step( + sequence_tensor: torch.Tensor, + lengths: Optional[torch.Tensor], + ) -> torch.Tensor: + """Selects the final valid time step from a [B, T] tensor.""" + if lengths is None: + return sequence_tensor[:, -1] + + idx = (lengths - 1).clamp(min=0) + batch_idx = torch.arange(sequence_tensor.size(0), device=sequence_tensor.device) + return sequence_tensor[batch_idx, idx] + + def build_n_step_targets( + self, + target_probs: torch.Tensor, + y_true: torch.Tensor, + lengths: Optional[torch.Tensor] = None, + gamma: Optional[float] = None, + n_step: Optional[int] = None, + ) -> torch.Tensor: + """Builds n-step TD targets from target-network probabilities.""" + gamma = self.gamma if gamma is None else gamma + n_step = self.n_step if n_step is None else n_step + + if y_true.ndim > 1: + y_true = y_true.view(-1) + + batch_size, 
    def compute_td_loss(
        self,
        logits_seq: torch.Tensor,
        x: torch.Tensor,
        y_true: torch.Tensor,
        target_model: "TDLSTMMortality",
        lengths: Optional[torch.Tensor] = None,
        gamma: Optional[float] = None,
        alpha_terminal: Optional[float] = None,
        n_step: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes TD loss plus terminal BCE anchor.

        The TD term is an MSE between the online network's per-step
        probabilities and n-step bootstrapped targets built from the
        gradient-free ``target_model``; the anchor term is the supervised
        loss on the final valid step, weighted by ``alpha_terminal``.

        Args:
            logits_seq: Online-network per-step logits, shape ``[B, T]``.
            x: Input features the logits were computed from.
            y_true: Binary labels, shape ``[B]``.
            target_model: Network used (without gradients) for TD targets.
            lengths: Optional valid lengths per sequence.
            gamma: Optional override of ``self.gamma``.
            alpha_terminal: Optional override of ``self.alpha_terminal``.
            n_step: Optional override of ``self.n_step``.

        Returns:
            Tuple ``(total_loss, td_mse, terminal_loss)`` of scalar tensors.
        """
        alpha_terminal = (
            self.alpha_terminal if alpha_terminal is None else alpha_terminal
        )

        probs_seq = torch.sigmoid(logits_seq)

        # Target probabilities are computed without a graph: only the online
        # network receives gradients from the TD term.
        with torch.no_grad():
            target_logits_seq = target_model.forward_logits(x, lengths=lengths)
            target_probs_seq = torch.sigmoid(target_logits_seq)

        td_targets = self.build_n_step_targets(
            target_probs=target_probs_seq,
            y_true=y_true,
            lengths=lengths,
            gamma=gamma,
            n_step=n_step,
        )

        # NOTE(review): mse_loss averages over all [B, T] positions, so when
        # `lengths` is provided, padded time steps still contribute to the TD
        # term — confirm this is intended for variable-length batches.
        td_mse = F.mse_loss(probs_seq, td_targets)
        terminal_loss = self.compute_supervised_loss(
            logits_seq=logits_seq,
            y_true=y_true,
            lengths=lengths,
        )

        total_loss = td_mse + alpha_terminal * terminal_loss
        return total_loss, td_mse, terminal_loss
target_model: Optional["TDLSTMMortality"] = None, + embed: bool = False, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """Forward pass in standard PyHealth format.""" + if self.feature_key not in kwargs: + raise KeyError(f"Missing feature_key '{self.feature_key}' in batch.") + + if self.label_key not in kwargs: + raise KeyError(f"Missing label_key '{self.label_key}' in batch.") + + x = self._prepare_input_tensor(kwargs[self.feature_key]) + y_true = self._prepare_labels(kwargs[self.label_key]) + lengths = self._get_lengths_from_kwargs(kwargs) + + logits_seq = self.forward_logits(x, lengths=lengths) + probs_seq = torch.sigmoid(logits_seq) + + logit = self._gather_last_valid_step(logits_seq, lengths) + y_prob = self.prepare_y_prob(logit.view(-1, 1)).view(-1) + + ret: Dict[str, torch.Tensor] = { + "y_true": y_true, + "logit": logit, + "y_prob": y_prob, + "logits_seq": logits_seq, + "probs_seq": probs_seq, + } + + if embed: + encoded = self.encode(x, lengths=lengths) + if lengths is None: + embedding = encoded[:, -1, :] + else: + idx = (lengths - 1).clamp(min=0) + batch_idx = torch.arange(encoded.size(0), device=encoded.device) + embedding = encoded[batch_idx, idx, :] + ret["embedding"] = embedding + + if self.training_mode == "supervised": + loss = self.compute_supervised_loss( + logits_seq=logits_seq, + y_true=y_true, + lengths=lengths, + ) + ret["loss"] = loss + return ret + + if self.training_mode == "td": + if target_model is None: + raise ValueError( + "target_model must be provided when training_mode='td'." + ) + loss, td_loss, terminal_loss = self.compute_td_loss( + logits_seq=logits_seq, + x=x, + y_true=y_true, + target_model=target_model, + lengths=lengths, + ) + ret["loss"] = loss + ret["td_loss"] = td_loss + ret["terminal_loss"] = terminal_loss + return ret + + raise ValueError( + f"Unsupported training_mode '{self.training_mode}'. " + f"Expected one of {self.VALID_TRAINING_MODES}." 
"""Tests for ``pyhealth.models.td_lstm_mortality.TDLSTMMortality``.

Covers model construction, forward-pass output shapes, probability ranges,
supervised and temporal-difference (TD) losses, gradient flow, embedding
extraction, error handling, and patient-level dataset splitting.
"""

from datetime import datetime, timedelta

import torch

from pyhealth.datasets import create_sample_dataset, get_dataloader, split_by_patient
from pyhealth.models.td_lstm_mortality import TDLSTMMortality


def _make_hourly_timestamps(num_steps: int, start: datetime | None = None):
    """Return ``num_steps`` consecutive hourly timestamps beginning at ``start``."""
    if start is None:
        start = datetime(2024, 1, 1, 0, 0, 0)
    return [start + i * timedelta(hours=1) for i in range(num_steps)]


def _make_samples():
    """Build three single-visit samples, each with 3 time steps of 4 features."""
    return [
        {
            "patient_id": "p1",
            "visit_id": "v1",
            "x": [
                _make_hourly_timestamps(3, datetime(2024, 1, 1, 0, 0, 0)),
                [
                    [0.1, 0.2, 0.3, 0.4],
                    [0.2, 0.1, 0.0, 0.5],
                    [0.4, 0.3, 0.2, 0.1],
                ],
            ],
            "label": 0,
        },
        {
            "patient_id": "p2",
            "visit_id": "v2",
            "x": [
                _make_hourly_timestamps(3, datetime(2024, 1, 2, 0, 0, 0)),
                [
                    [0.5, 0.4, 0.3, 0.2],
                    [0.6, 0.5, 0.4, 0.3],
                    [0.7, 0.6, 0.5, 0.4],
                ],
            ],
            "label": 1,
        },
        {
            "patient_id": "p3",
            "visit_id": "v3",
            "x": [
                _make_hourly_timestamps(3, datetime(2024, 1, 3, 0, 0, 0)),
                [
                    [0.9, 0.8, 0.7, 0.6],
                    [0.8, 0.7, 0.6, 0.5],
                    [0.7, 0.6, 0.5, 0.4],
                ],
            ],
            "label": 0,
        },
    ]


def _make_dataset():
    """Create a schema-based PyHealth sample dataset from the synthetic samples."""
    return create_sample_dataset(
        samples=_make_samples(),
        input_schema={"x": "timeseries"},
        output_schema={"label": "binary"},
        dataset_name="test_td_lstm_mortality",
    )


def _make_batch():
    """Return ``(dataset, batch)`` where ``batch`` collates all three samples."""
    dataset = _make_dataset()
    loader = get_dataloader(dataset, batch_size=3, shuffle=False)
    return dataset, next(iter(loader))


def _make_model(dataset, training_mode, **overrides):
    """Construct a ``TDLSTMMortality`` with the defaults shared by most tests.

    Any keyword in ``overrides`` replaces or extends the default constructor
    arguments (e.g. ``gamma``, ``n_step``, ``num_layers``).
    """
    kwargs = dict(
        dataset=dataset,
        feature_key="x",
        label_key="label",
        mode="binary",
        hidden_dim=16,
        training_mode=training_mode,
    )
    kwargs.update(overrides)
    return TDLSTMMortality(**kwargs)


def _make_td_pair(dataset):
    """Return ``(online, target)`` TD-mode models with identical initial weights."""
    online = _make_model(dataset, "td")
    target = _make_model(dataset, "td")
    target.load_state_dict(online.state_dict())
    return online, target


def test_model_instantiation_supervised():
    model = _make_model(
        _make_dataset(),
        "supervised",
        num_layers=1,
        dropout=0.0,
        gamma=0.95,
        alpha_terminal=0.10,
        n_step=1,
    )

    assert model.feature_key == "x"
    assert model.label_key == "label"
    assert model.hidden_dim == 16
    assert model.training_mode == "supervised"


def test_model_instantiation_td():
    model = _make_model(
        _make_dataset(),
        "td",
        gamma=0.95,
        alpha_terminal=0.10,
        n_step=1,
    )

    assert model.feature_key == "x"
    assert model.label_key == "label"
    assert model.training_mode == "td"


def test_forward_output_shapes_supervised():
    dataset, batch = _make_batch()
    model = _make_model(dataset, "supervised")

    out = model(**batch)

    for key in ("loss", "y_prob", "y_true", "logit", "logits_seq", "probs_seq"):
        assert key in out

    assert out["logits_seq"].shape[0] == 3
    assert out["probs_seq"].shape[0] == 3
    assert out["logit"].shape == (3,)
    assert out["y_prob"].shape == (3,)
    assert out["y_true"].shape == (3,)
    assert out["loss"].ndim == 0


def test_forward_output_shapes_td():
    dataset, batch = _make_batch()
    model, target_model = _make_td_pair(dataset)

    out = model(target_model=target_model, **batch)

    for key in (
        "loss",
        "y_prob",
        "y_true",
        "logit",
        "logits_seq",
        "probs_seq",
        "td_loss",
        "terminal_loss",
    ):
        assert key in out

    assert out["logits_seq"].shape[0] == 3
    assert out["probs_seq"].shape[0] == 3
    assert out["logit"].shape == (3,)
    assert out["y_prob"].shape == (3,)
    assert out["y_true"].shape == (3,)
    assert out["loss"].ndim == 0


def test_probability_range():
    dataset, batch = _make_batch()
    model = _make_model(dataset, "supervised")

    out = model(**batch)

    # Sigmoid outputs must lie in [0, 1] both per-step and at the final step.
    assert torch.all(out["probs_seq"] >= 0.0)
    assert torch.all(out["probs_seq"] <= 1.0)
    assert torch.all(out["y_prob"] >= 0.0)
    assert torch.all(out["y_prob"] <= 1.0)


def test_build_n_step_targets_shape():
    dataset, batch = _make_batch()
    model = _make_model(dataset, "td", gamma=0.95, n_step=2)

    logits_seq = model.forward_logits(batch["x"])
    probs_seq = torch.sigmoid(logits_seq)

    td_targets = model.build_n_step_targets(
        target_probs=probs_seq,
        y_true=batch["label"],
    )

    # Bootstrapped targets are defined at every time step of every sequence.
    assert td_targets.shape == probs_seq.shape


def test_td_loss_backward_runs():
    dataset, batch = _make_batch()
    model, target_model = _make_td_pair(dataset)

    out = model(target_model=target_model, **batch)
    out["loss"].backward()

    grads = [p.grad for p in model.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)


def test_supervised_loss_backward_runs():
    dataset, batch = _make_batch()
    model = _make_model(dataset, "supervised")

    out = model(**batch)
    out["loss"].backward()

    grads = [p.grad for p in model.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)


def test_embed_output_shape():
    dataset, batch = _make_batch()
    model = _make_model(dataset, "supervised")

    out = model(embed=True, **batch)

    # One hidden_dim-sized embedding per sample in the batch.
    assert "embedding" in out
    assert out["embedding"].shape == (3, 16)


def test_final_prediction_selection_shape():
    dataset, batch = _make_batch()
    model = _make_model(dataset, "supervised")

    logits_seq = model.forward_logits(batch["x"])
    gathered = model._gather_last_valid_step(logits_seq, None)

    assert gathered.shape == (3,)


def test_invalid_training_mode_raises():
    dataset = _make_dataset()

    try:
        _make_model(dataset, "bad_mode")
    except ValueError:
        pass  # expected
    else:
        raise AssertionError("Expected ValueError for invalid training_mode")


def test_td_mode_requires_target_model():
    dataset, batch = _make_batch()
    model = _make_model(dataset, "td")

    try:
        model(**batch)
    except ValueError:
        pass  # expected
    else:
        raise AssertionError(
            "Expected ValueError when target_model is missing in TD mode"
        )


def test_split_by_patient_runs():
    dataset = _make_dataset()
    train_dataset, val_dataset, test_dataset = split_by_patient(
        dataset,
        [0.6, 0.2, 0.2],
        seed=42,
    )

    # Every sample lands in exactly one split.
    assert len(train_dataset) + len(val_dataset) + len(test_dataset) == len(dataset)