
Add CNN-LSTM model for ICU Mortality Prediction #943

Draft
nikhita2 wants to merge 1 commit into sunlabuiuc:master from nikhita2:nikhita2/add-cnn-lstm-model

nikhita2 commented Apr 5, 2026

Contributor: Nikhita Shanker (nikhita2)

Contribution Type: Model

Paper: "Robust Mortality Prediction in the Intensive Care Unit using Temporal Difference Learning" (Frost et al.)
Paper Link: https://arxiv.org/pdf/2411.04285
Repository: https://github.com/tdgfrost/td-icu-mortality

Description: This PR adds the CNNLSTMPredictor model to PyHealth, implementing the CNN-LSTM hybrid architecture from the td-icu-mortality paper. The model processes discrete medical codes through:

  1. Embedding layer — maps discrete codes (e.g., ICD codes) to dense vectors via PyHealth's EmbeddingModel
  2. CNN encoder — Conv1d + BatchNorm + ReLU + MaxPool layers extract local patterns from embedded sequences
  3. LSTM encoder — multi-layer LSTM captures sequential dependencies across the CNN-encoded features
  4. Dense decoder — BatchNorm → Linear → ReLU → Dropout → BatchNorm → Linear produces the final prediction
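
For reviewers, the four stages above can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not the code in cnn_lstm.py: the class name `CNNLSTMSketch`, all layer sizes, the kernel and pooling sizes, the use of `nn.Embedding` in place of PyHealth's EmbeddingModel, and the choice to decode from the last LSTM layer's final hidden state are hypothetical and chosen only to make the data flow visible.

```python
import torch
import torch.nn as nn


class CNNLSTMSketch(nn.Module):
    """Illustrative embedding -> CNN -> LSTM -> dense stack (hypothetical sizes)."""

    def __init__(self, vocab_size=100, embed_dim=32, conv_channels=64,
                 hidden_dim=64, num_layers=2, dropout=0.5, num_classes=1):
        super().__init__()
        # 1. Embedding: code index -> dense vector (index 0 assumed to be padding).
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # 2. CNN encoder: Conv1d + BatchNorm + ReLU + MaxPool over the time axis.
        self.cnn = nn.Sequential(
            nn.Conv1d(embed_dim, conv_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(conv_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
        )
        # 3. LSTM encoder: multi-layer LSTM over the CNN-encoded sequence.
        self.lstm = nn.LSTM(conv_channels, hidden_dim,
                            num_layers=num_layers, batch_first=True)
        # 4. Dense decoder: BatchNorm -> Linear -> ReLU -> Dropout -> BatchNorm -> Linear.
        self.decoder = nn.Sequential(
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, codes):
        x = self.embedding(codes)        # (batch, seq_len, embed_dim)
        x = self.cnn(x.transpose(1, 2))  # Conv1d wants (batch, channels, seq_len)
        x = x.transpose(1, 2)            # back to (batch, pooled_len, channels)
        _, (h_n, _) = self.lstm(x)       # h_n: (num_layers, batch, hidden_dim)
        # Assumption: decode from the final hidden state of the top LSTM layer.
        return self.decoder(h_n[-1])     # (batch, num_classes)


if __name__ == "__main__":
    model = CNNLSTMSketch()
    model.eval()  # use BatchNorm running stats and disable dropout
    codes = torch.randint(1, 100, (4, 16))  # 4 synthetic sequences of 16 codes
    with torch.no_grad():
        print(model(codes).shape)
```

Note that the max-pooling step halves the sequence length before the LSTM sees it, so in this sketch the LSTM runs over 8 steps for a 16-code input.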

The model inherits from pyhealth.models.BaseModel, implements the required forward() method, and is compatible with any PyHealth SampleDataset with sequence input features.

Relation to Paper Replication: This PR implements the CNN-LSTM model from the paper we replicated (Frost et al.). The architecture in cnn_lstm.py follows the paper's CNN-LSTM design. The ablation study goes beyond the paper by exploring sensitivity to four hyperparameters (learning rate, hidden dimension, dropout, batch size), for which the original paper uses fixed default values.
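
As a rough picture of what a sensitivity sweep over these four hyperparameters can look like, the sketch below generates one-factor-at-a-time configurations: each run changes a single hyperparameter and holds the others at a default. The names `DEFAULTS`, `SWEEPS`, and `sweep_configs`, every numeric value, and the one-at-a-time design itself are assumptions for illustration; they are not taken from the example script or the paper.

```python
# Hypothetical defaults and sweep values; placeholders only.
DEFAULTS = {"lr": 1e-3, "hidden_dim": 128, "dropout": 0.5, "batch_size": 32}
SWEEPS = {
    "lr": [1e-4, 1e-3, 1e-2],
    "hidden_dim": [64, 128, 256],
    "dropout": [0.1, 0.3, 0.5],
    "batch_size": [16, 32, 64],
}


def sweep_configs(defaults, sweeps):
    """Return (swept_name, config) pairs, varying one hyperparameter per config."""
    configs = []
    for name, values in sweeps.items():
        for value in values:
            cfg = dict(defaults)  # copy so the defaults are never mutated
            cfg[name] = value     # override only the swept hyperparameter
            configs.append((name, cfg))
    return configs


if __name__ == "__main__":
    for name, cfg in sweep_configs(DEFAULTS, SWEEPS):
        print(f"sweep={name:<10} config={cfg}")
```

A one-at-a-time design keeps the number of training runs linear in the number of values tried, at the cost of not capturing interactions between hyperparameters; a full grid would cover those interactions but multiply the run count.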

Files to Review:

  • pyhealth/models/cnn_lstm.py — Core model implementation (CNNLSTMPredictor)
  • tests/core/test_cnn_lstm.py — Synthetic data unit tests (7 tests, runs in seconds)
  • examples/mimic4_icu_mortality_cnn_lstm.py — Ablation study script with learning rate, hidden dim, dropout, and batch size sweeps
  • docs/api/models/pyhealth.models.cnn_lstm.rst — API documentation RST file
  • docs/api/models.rst — Updated index (added cnn_lstm to models toctree)

