diff --git a/tests/core/test_rnn.py b/tests/core/test_rnn.py new file mode 100644 index 000000000..a1bd52905 --- /dev/null +++ b/tests/core/test_rnn.py @@ -0,0 +1,275 @@ +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import RNN +from pyhealth.models.rnn import MultimodalRNN + + +class TestRNN(unittest.TestCase): + """Test cases for the RNN model.""" + + def setUp(self): + """Set up test data and model.""" + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86", "cond-80", "cond-12"], + "procedures": ["proc-12", "proc-45", "proc-23"], + "label": 0, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": ["cond-33", "cond-86", "cond-80"], + "procedures": ["proc-12"], + "label": 1, + }, + ] + + self.input_schema = { + "conditions": "sequence", + "procedures": "sequence", + } + self.output_schema = {"label": "binary"} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + + self.model = RNN(dataset=self.dataset) + + def test_model_initialization(self): + """Test that the RNN model initializes correctly.""" + self.assertIsInstance(self.model, RNN) + self.assertEqual(self.model.embedding_dim, 128) + self.assertEqual(self.model.hidden_dim, 128) + self.assertEqual(len(self.model.feature_keys), 2) + self.assertIn("conditions", self.model.feature_keys) + self.assertIn("procedures", self.model.feature_keys) + self.assertEqual(self.model.label_key, "label") + + def test_model_forward(self): + """Test that the RNN model forward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["y_true"].shape[0], 2) + self.assertEqual(ret["logit"].shape[0], 2) + self.assertEqual(ret["loss"].dim(), 0) + + def test_model_backward(self): + """Test that the RNN model backward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + ret = self.model(**data_batch) + ret["loss"].backward() + + has_gradient = False + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + has_gradient = True + break + self.assertTrue( + has_gradient, "No parameters have gradients after backward pass" + ) + + def test_model_with_embedding(self): + """Test that the RNN model returns embeddings when requested.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + data_batch["embed"] = True + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape[0], 2) + expected_embed_dim = len(self.model.feature_keys) * self.model.hidden_dim + self.assertEqual(ret["embed"].shape[1], expected_embed_dim) + + def test_custom_hyperparameters(self): + """Test RNN model with custom hyperparameters.""" + model = RNN( + dataset=self.dataset, + embedding_dim=64, + hidden_dim=32, + ) + + self.assertEqual(model.embedding_dim, 64) + self.assertEqual(model.hidden_dim, 32) + + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + + def test_rnn_type_lstm(self): + """Test RNN model with LSTM cell type.""" + model = RNN( + dataset=self.dataset, + rnn_type="LSTM", + ) + + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertEqual(ret["y_prob"].shape[0], 2) + + def test_rnn_type_vanilla(self): + """Test RNN model with vanilla RNN cell type.""" + model = RNN( + dataset=self.dataset, + rnn_type="RNN", + ) + + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertEqual(ret["y_prob"].shape[0], 2) + + def test_bidirectional(self): + """Test RNN model with bidirectional layers.""" + model = RNN( + dataset=self.dataset, + bidirectional=True, + ) + + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertEqual(ret["y_prob"].shape[0], 2) + + +class TestMultimodalRNN(unittest.TestCase): + """Test cases for the MultimodalRNN model with mixed input modalities.""" + + def setUp(self): + """Set up test data with both sequential and non-sequential features.""" + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86", "cond-80"], + "demographics": ["asian", "male"], + "vitals": [120.0, 80.0, 98.6], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": ["cond-12", "cond-52"], + "demographics": ["white", "female"], + "vitals": [110.0, 75.0, 98.2], + "label": 0, + }, + ] + + self.input_schema = { + "conditions": "sequence", + "demographics": "multi_hot", + "vitals": "tensor", + } + self.output_schema = {"label": "binary"} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + + self.model = MultimodalRNN(dataset=self.dataset) + + def test_model_initialization(self): + """Test that the MultimodalRNN model initializes correctly.""" + self.assertIsInstance(self.model, MultimodalRNN) + self.assertEqual(len(self.model.feature_keys), 3) + self.assertIn("conditions", self.model.sequential_features) + self.assertIn("demographics", self.model.non_sequential_features) + self.assertIn("vitals", self.model.non_sequential_features) + + def test_model_forward(self): + """Test that the MultimodalRNN forward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["loss"].dim(), 0) + + def test_model_backward(self): + """Test that the MultimodalRNN backward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + ret = self.model(**data_batch) + ret["loss"].backward() + + has_gradient = False + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + has_gradient = True + break + self.assertTrue( + has_gradient, "No parameters have gradients after backward pass" + ) + + def test_model_with_embedding(self): + """Test that the MultimodalRNN returns embeddings when requested.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + data_batch["embed"] = True + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape[0], 2) + expected_embed_dim = ( + len(self.model.sequential_features) * self.model.hidden_dim + + len(self.model.non_sequential_features) * self.model.embedding_dim + ) + self.assertEqual(ret["embed"].shape[1], expected_embed_dim) + + +if __name__ == "__main__": + unittest.main()