Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 275 additions & 0 deletions tests/core/test_rnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
import unittest

import torch

from pyhealth.datasets import create_sample_dataset, get_dataloader
from pyhealth.models import RNN
from pyhealth.models.rnn import MultimodalRNN


class TestRNN(unittest.TestCase):
"""Test cases for the RNN model."""

def setUp(self):
"""Set up test data and model."""
self.samples = [
{
"patient_id": "patient-0",
"visit_id": "visit-0",
"conditions": ["cond-33", "cond-86", "cond-80", "cond-12"],
"procedures": ["proc-12", "proc-45", "proc-23"],
"label": 0,
},
{
"patient_id": "patient-1",
"visit_id": "visit-1",
"conditions": ["cond-33", "cond-86", "cond-80"],
"procedures": ["proc-12"],
"label": 1,
},
]

self.input_schema = {
"conditions": "sequence",
"procedures": "sequence",
}
self.output_schema = {"label": "binary"}

self.dataset = create_sample_dataset(
samples=self.samples,
input_schema=self.input_schema,
output_schema=self.output_schema,
dataset_name="test",
)

self.model = RNN(dataset=self.dataset)

def test_model_initialization(self):
"""Test that the RNN model initializes correctly."""
self.assertIsInstance(self.model, RNN)
self.assertEqual(self.model.embedding_dim, 128)
self.assertEqual(self.model.hidden_dim, 128)
self.assertEqual(len(self.model.feature_keys), 2)
self.assertIn("conditions", self.model.feature_keys)
self.assertIn("procedures", self.model.feature_keys)
self.assertEqual(self.model.label_key, "label")

def test_model_forward(self):
"""Test that the RNN model forward pass works correctly."""
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
data_batch = next(iter(train_loader))

with torch.no_grad():
ret = self.model(**data_batch)

self.assertIn("loss", ret)
self.assertIn("y_prob", ret)
self.assertIn("y_true", ret)
self.assertIn("logit", ret)

self.assertEqual(ret["y_prob"].shape[0], 2)
self.assertEqual(ret["y_true"].shape[0], 2)
self.assertEqual(ret["logit"].shape[0], 2)
self.assertEqual(ret["loss"].dim(), 0)

def test_model_backward(self):
"""Test that the RNN model backward pass works correctly."""
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
data_batch = next(iter(train_loader))

ret = self.model(**data_batch)
ret["loss"].backward()

has_gradient = False
for param in self.model.parameters():
if param.requires_grad and param.grad is not None:
has_gradient = True
break
self.assertTrue(
has_gradient, "No parameters have gradients after backward pass"
)

def test_model_with_embedding(self):
"""Test that the RNN model returns embeddings when requested."""
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
data_batch = next(iter(train_loader))
data_batch["embed"] = True

with torch.no_grad():
ret = self.model(**data_batch)

self.assertIn("embed", ret)
self.assertEqual(ret["embed"].shape[0], 2)
expected_embed_dim = len(self.model.feature_keys) * self.model.hidden_dim
self.assertEqual(ret["embed"].shape[1], expected_embed_dim)

def test_custom_hyperparameters(self):
"""Test RNN model with custom hyperparameters."""
model = RNN(
dataset=self.dataset,
embedding_dim=64,
hidden_dim=32,
)

self.assertEqual(model.embedding_dim, 64)
self.assertEqual(model.hidden_dim, 32)

train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
data_batch = next(iter(train_loader))

with torch.no_grad():
ret = model(**data_batch)

self.assertIn("loss", ret)
self.assertIn("y_prob", ret)

def test_rnn_type_lstm(self):
"""Test RNN model with LSTM cell type."""
model = RNN(
dataset=self.dataset,
rnn_type="LSTM",
)

train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
data_batch = next(iter(train_loader))

with torch.no_grad():
ret = model(**data_batch)

self.assertIn("loss", ret)
self.assertEqual(ret["y_prob"].shape[0], 2)

def test_rnn_type_vanilla(self):
"""Test RNN model with vanilla RNN cell type."""
model = RNN(
dataset=self.dataset,
rnn_type="RNN",
)

train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
data_batch = next(iter(train_loader))

with torch.no_grad():
ret = model(**data_batch)

self.assertIn("loss", ret)
self.assertEqual(ret["y_prob"].shape[0], 2)

def test_bidirectional(self):
"""Test RNN model with bidirectional layers."""
model = RNN(
dataset=self.dataset,
bidirectional=True,
)

train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
data_batch = next(iter(train_loader))

with torch.no_grad():
ret = model(**data_batch)

self.assertIn("loss", ret)
self.assertEqual(ret["y_prob"].shape[0], 2)


class TestMultimodalRNN(unittest.TestCase):
"""Test cases for the MultimodalRNN model with mixed input modalities."""

def setUp(self):
"""Set up test data with both sequential and non-sequential features."""
self.samples = [
{
"patient_id": "patient-0",
"visit_id": "visit-0",
"conditions": ["cond-33", "cond-86", "cond-80"],
"demographics": ["asian", "male"],
"vitals": [120.0, 80.0, 98.6],
"label": 1,
},
{
"patient_id": "patient-1",
"visit_id": "visit-1",
"conditions": ["cond-12", "cond-52"],
"demographics": ["white", "female"],
"vitals": [110.0, 75.0, 98.2],
"label": 0,
},
]

self.input_schema = {
"conditions": "sequence",
"demographics": "multi_hot",
"vitals": "tensor",
}
self.output_schema = {"label": "binary"}

self.dataset = create_sample_dataset(
samples=self.samples,
input_schema=self.input_schema,
output_schema=self.output_schema,
dataset_name="test",
)

self.model = MultimodalRNN(dataset=self.dataset)

def test_model_initialization(self):
"""Test that the MultimodalRNN model initializes correctly."""
self.assertIsInstance(self.model, MultimodalRNN)
self.assertEqual(len(self.model.feature_keys), 3)
self.assertIn("conditions", self.model.sequential_features)
self.assertIn("demographics", self.model.non_sequential_features)
self.assertIn("vitals", self.model.non_sequential_features)

def test_model_forward(self):
"""Test that the MultimodalRNN forward pass works correctly."""
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
data_batch = next(iter(train_loader))

with torch.no_grad():
ret = self.model(**data_batch)

self.assertIn("loss", ret)
self.assertIn("y_prob", ret)
self.assertIn("y_true", ret)
self.assertIn("logit", ret)

self.assertEqual(ret["y_prob"].shape[0], 2)
self.assertEqual(ret["loss"].dim(), 0)

def test_model_backward(self):
"""Test that the MultimodalRNN backward pass works correctly."""
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
data_batch = next(iter(train_loader))

ret = self.model(**data_batch)
ret["loss"].backward()

has_gradient = False
for param in self.model.parameters():
if param.requires_grad and param.grad is not None:
has_gradient = True
break
self.assertTrue(
has_gradient, "No parameters have gradients after backward pass"
)

def test_model_with_embedding(self):
"""Test that the MultimodalRNN returns embeddings when requested."""
train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
data_batch = next(iter(train_loader))
data_batch["embed"] = True

with torch.no_grad():
ret = self.model(**data_batch)

self.assertIn("embed", ret)
self.assertEqual(ret["embed"].shape[0], 2)
expected_embed_dim = (
len(self.model.sequential_features) * self.model.hidden_dim
+ len(self.model.non_sequential_features) * self.model.embedding_dim
)
self.assertEqual(ret["embed"].shape[1], expected_embed_dim)


if __name__ == "__main__":
unittest.main()
Loading