Skip to content

Add TDLSTMMortality model for ICU mortality prediction via temporal-difference learning#955

Open
MrPrabhatShreeprakashSingh wants to merge 1 commit intosunlabuiuc:masterfrom
MrPrabhatShreeprakashSingh:td-lstm-mortality
Open

Add TDLSTMMortality model for ICU mortality prediction via temporal-difference learning#955
MrPrabhatShreeprakashSingh wants to merge 1 commit intosunlabuiuc:masterfrom
MrPrabhatShreeprakashSingh:td-lstm-mortality

Conversation

@MrPrabhatShreeprakashSingh
Copy link
Copy Markdown

This pull request adds a new PyHealth model contribution:

  • Model: TDLSTMMortality
  • Contribution type: Model
  • Paper: Frost, Li, and Harris, Robust Real-Time Mortality Prediction in the Intensive Care Unit using Temporal Difference Learning (ML4H 2024)

This implementation is a PyHealth-compatible reproduction of the paper’s core temporal-difference learning idea for ICU mortality prediction. The contributed model provides:

  • an LSTM-based sequential encoder for time-series mortality prediction
  • a supervised training mode using terminal BCE loss
  • a TD-learning mode using bootstrapped future predictions plus a terminal BCE anchor
  • synthetic tests and a runnable example ablation script

This contribution is intentionally a simplified PyHealth-native reproduction of the paper’s central idea rather than a full reimplementation of the original CNN+LSTM/state-marker pipeline.

Contributors

  • Udit Sharma (udits2@illinois.edu)
  • Joe Haenel (jhaenel2@illinois.edu)
  • Prabhat Singh (pssingh2@illinois.edu)

Original paper

Robust Real-Time Mortality Prediction in the Intensive Care Unit using Temporal Difference Learning

Why this contribution

The project reproduces the paper’s central idea in a lightweight, PyHealth-native form suitable for educational reproducibility, testing, and future extension. Compared with the original paper, the architecture is simplified to an LSTM-only version to improve compatibility with PyHealth’s model interface and contribution workflow.

Files to review

Core implementation

  • pyhealth/models/td_lstm_mortality.py

Tests

  • tests/test_td_lstm_mortality.py

Example / ablation

  • examples/mimic4_mortality_td_lstm.py

Documentation

  • docs/api/models/pyhealth.models.td_lstm_mortality.rst
  • docs/api/models.rst

Implementation notes

  • Inherits from BaseModel
  • Supports training_mode="supervised" and training_mode="td"
  • Infers input_dim from dataset samples for compatibility with local PyHealth APIs
  • Uses [timestamps, values] synthetic timeseries sample format in tests/examples
  • TD evaluation in the example uses model(target_model=model, **batch) to satisfy the TD forward API

Example results

The included example script runs end to end on synthetic data and reports:

  • supervised validation/test metrics
  • a TD ablation sweep over gamma
  • a compact comparison table across supervised and TD settings
  • a project-aligned interpretation that supervised LSTM remains the strongest
    overall benchmark, while tuned 1-step TD is the main TD result

Checklist

  • Model implementation added
  • Synthetic tests added
  • Runnable example added
  • Documentation RST added
  • Index RST updated
  • Original paper linked
  • Rebased onto latest main
  • Allow edits by maintainers enabled

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant