diff --git a/LABRADOR_TEST_SUMMARY.md b/LABRADOR_TEST_SUMMARY.md new file mode 100644 index 000000000..02c58310b --- /dev/null +++ b/LABRADOR_TEST_SUMMARY.md @@ -0,0 +1,90 @@ +# LabradorModel Test + +**Test File:** `tests/core/test_labrador.py` + +## Overview + +Minimal end-to-end smoke test for `LabradorModel` that verifies: +- Dataset creation with aligned lab codes and values +- Model initialization with correct hyperparameters +- Forward pass with expected output structure +- Backward pass with gradient computation +- Embedding extraction when requested + +## Test Methods + +| Test | Purpose | +|------|---------| +| `test_model_initialization` | Verifies model initializes with correct hyperparameters (embed_dim, num_heads, num_layers, feature keys, label key) | +| `test_model_forward` | Checks forward pass returns `loss`, `y_prob`, `y_true`, `logit` with correct shapes and batch sizes | +| `test_model_backward` | Confirms backward pass computes gradients on parameters | +| `test_model_with_embedding` | Validates embedding extraction returns correct shape `[batch_size, embed_dim]` | + +--- + +## Key Implementation Details + +## Test Data Structure + +**Minimal Synthetic Dataset (2 samples, 4 labs each):** +```python +samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "lab_codes": ["lab-1", "lab-2", "lab-3", "lab-4"], # Categorical + "lab_values": [1.0, 2.5, 3.0, 4.5], # Continuous + "label": 0, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "lab_codes": ["lab-1", "lab-2", "lab-3", "lab-4"], + "lab_values": [2.1, 1.8, 2.9, 3.5], + "label": 1, + }, +] +``` + +**Input Schema:** +```python +input_schema = { + "lab_codes": "sequence", # → Categorical tokens [batch, 4] + "lab_values": "tensor", # → Float values [batch, 4] +} +output_schema = {"label": "binary"} +``` + +**Model Configuration:** +```python +LabradorModel( + dataset=dataset, + code_feature_key="lab_codes", + value_feature_key="lab_values", + embed_dim=32, # 
Lightweight for testing
+    num_heads=2,
+    num_layers=1,
+)
+```
+
+---
+
+## Running the Test
+
+**Option 1: Using pixi (recommended)**
+```bash
+cd /path/to/PyHealth
+make init # First time only
+make test
+```
+
+**Option 2: Direct unittest**
+```bash
+cd /path/to/PyHealth
+python3 -m unittest tests.core.test_labrador.TestLabradorModel -v
+```
+
+**Option 3: Single test method**
+```bash
+python3 -m unittest tests.core.test_labrador.TestLabradorModel.test_model_forward -v
+```
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 7368dec94..92c44f7ee 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -197,6 +197,7 @@ API Reference
     models/pyhealth.models.MedLink
     models/pyhealth.models.TCN
     models/pyhealth.models.TFMTokenizer
+    models/pyhealth.models.labrador
     models/pyhealth.models.GAN
     models/pyhealth.models.VAE
     models/pyhealth.models.SDOH
diff --git a/docs/api/models/pyhealth.models.labrador.rst b/docs/api/models/pyhealth.models.labrador.rst
new file mode 100644
index 000000000..347439515
--- /dev/null
+++ b/docs/api/models/pyhealth.models.labrador.rst
@@ -0,0 +1,9 @@
+pyhealth.models.labrador
+===================================
+
+LabradorModel for aligned lab code and lab value inputs.
+
+.. autoclass:: pyhealth.models.LabradorModel
+   :members:
+   :undoc-members:
+   :show-inheritance:
\ No newline at end of file
diff --git a/examples/labrador_ablations_quickstart.ipynb b/examples/labrador_ablations_quickstart.ipynb
new file mode 100644
index 000000000..ac25f0904
--- /dev/null
+++ b/examples/labrador_ablations_quickstart.ipynb
@@ -0,0 +1,1037 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "2542e8fb",
+   "metadata": {},
+   "source": [
+    "# Labrador Minimal Downstream Ablation\n",
+    "\n",
+    "This notebook demonstrates a minimal downstream use of `LabradorModel` on synthetic lab data.\n",
+    "\n",
+    "We:\n",
+    "\n",
+    "1. generate structured synthetic lab samples,\n",
+    "2. define a nonlinear binary classification task,\n",
+    "3. 
train the model as a classifier,\n", + "4. compare performance under simple hyperparameter changes,\n", + "5. demonstrate the masked lab pretraining process done in the paper\n", + "\n", + "The goal is to provide a small, runnable example that illustrates model behavior under different configurations.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "639b1279", + "metadata": {}, + "source": [ + "## 1) Setup\n", + "\n", + "This section imports the required libraries, defines a few constants, and sets the random seed so the example is reproducible." + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "676b1a77", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Torch: 2.7.1+cu126\n", + "CUDA available: True\n" + ] + } + ], + "source": [ + "import random\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "\n", + "from pyhealth.datasets import create_sample_dataset, get_dataloader\n", + "from pyhealth.models import LabradorModel\n", + "\n", + "\n", + "NUM_LABS = 10\n", + "LAB_CODES = [f\"lab-{i}\" for i in range(NUM_LABS)]\n", + "SEED = 42\n", + "TRAIN_RATIO = 0.6\n", + "VAL_RATIO = 0.2\n", + "DOWNSTREAM_SAMPLES = 600\n", + "EPOCHS = 5\n", + "BATCH_SIZE = 64\n", + "\n", + "\n", + "def set_seed(seed: int = 42):\n", + " random.seed(seed)\n", + " np.random.seed(seed)\n", + " torch.manual_seed(seed)\n", + " if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(seed)\n", + "\n", + "\n", + "set_seed(SEED)\n", + "print(f\"Torch: {torch.__version__}\")\n", + "print(f\"CUDA available: {torch.cuda.is_available()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "850cc42e", + "metadata": {}, + "source": [ + "## 2) Synthetic Data\n", + "\n", + "We construct synthetic lab data with basic structure using **latent profiles**.\n", + "\n", + "Each sample is generated by first selecting a hidden profile, which controls:\n", + "\n", + "* which lab codes are more 
likely to appear together\n", + "* how their values are shifted\n", + "\n", + "A profile can be thought of as a simple “patient type” that induces patterns in the labs.\n", + "\n", + "### Intuition\n", + "\n", + "For example:\n", + "\n", + "* **Profile A (metabolic-like)**\n", + "\n", + " * more likely to include labs such as `lab_2`, `lab_7`, `lab_8`\n", + " * values tend to be higher on average\n", + "\n", + "* **Profile B (baseline-like)**\n", + "\n", + " * labs are more uniformly distributed\n", + " * values are centered around lower or neutral ranges\n", + "\n", + "A generated sample might look like (metabolic profile):\n", + "\n", + "* Codes: `[2, 7, 1, 9]`\n", + "* Values: `[0.8, 0.7, 0.2, 0.9]`\n", + "\n", + "Another sample from a different profile (baseline):\n", + "\n", + "* Codes: `[1, 3, 5, 0]`\n", + "* Values: `[0.2, 0.3, 0.4, 0.1]`\n", + "\n", + "These profiles create **co-occurrence structure** (which labs appear together) and **value patterns**, making the data more realistic than purely random samples.\n", + "\n", + "---\n", + "\n", + "### Downstream label\n", + "\n", + "The label is defined by a nonlinear rule:\n", + "\n", + "* positive if `(lab_2 × lab_7)` is large and `lab_1` is below a threshold\n", + "\n", + "This requires the model to:\n", + "\n", + "* identify specific lab types\n", + "* reason about their values\n", + "* capture interactions between labs\n", + "\n", + "Overall, the synthetic setup introduces just enough structure to make the task non-trivial while remaining lightweight.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "f2bc5d41", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_structured_fake_batch(\n", + " batch_size: int = 32,\n", + " num_labs: int = NUM_LABS,\n", + "):\n", + " \"\"\"Generate synthetic lab bags with profile-specific code/value structure.\"\"\"\n", + " panel_probs = torch.tensor(\n", + " [\n", + " [0.05, 0.18, 0.07, 0.06, 0.05, 0.05, 0.18, 0.18, 0.10, 
0.08],\n", + " [0.12, 0.18, 0.16, 0.15, 0.06, 0.05, 0.13, 0.05, 0.05, 0.05],\n", + " [0.10, 0.08, 0.05, 0.05, 0.16, 0.16, 0.10, 0.08, 0.12, 0.10],\n", + " ],\n", + " dtype=torch.float32,\n", + " )\n", + "\n", + " anchor_codes = torch.tensor([1, 2, 7], dtype=torch.long)\n", + " panel_ids = torch.randint(0, 3, (batch_size,))\n", + "\n", + " codes = torch.empty(batch_size, num_labs, dtype=torch.long)\n", + " values = torch.empty(batch_size, num_labs, dtype=torch.float32)\n", + "\n", + " for i in range(batch_size):\n", + " panel = panel_ids[i].item()\n", + " extra_codes = torch.multinomial(\n", + " panel_probs[panel],\n", + " num_samples=num_labs - len(anchor_codes),\n", + " replacement=True,\n", + " )\n", + " sample_codes = torch.cat([anchor_codes, extra_codes], dim=0)\n", + " sample_codes = sample_codes[torch.randperm(num_labs)]\n", + "\n", + " sample_values = sample_codes.float() / (num_labs - 1)\n", + "\n", + " if panel == 0:\n", + " sample_values[(sample_codes == 1) | (sample_codes == 6) | (sample_codes == 7)] += 0.18\n", + " sample_values[sample_codes == 0] -= 0.05\n", + " elif panel == 1:\n", + " sample_values[(sample_codes == 2) | (sample_codes == 3)] += 0.15\n", + " sample_values[sample_codes == 0] += 0.25\n", + " sample_values[sample_codes == 1] += 0.05\n", + " sample_values[sample_codes == 6] += 0.05\n", + " else:\n", + " sample_values[\n", + " (sample_codes == 4)\n", + " | (sample_codes == 5)\n", + " | (sample_codes == 8)\n", + " | (sample_codes == 9)\n", + " ] += 0.18\n", + " sample_values[sample_codes == 1] -= 0.06\n", + " sample_values[sample_codes == 0] += 0.08\n", + "\n", + " sample_values += 0.05 * torch.randn_like(sample_values)\n", + " sample_values = torch.clamp(sample_values, 0.0, 1.0)\n", + "\n", + " codes[i] = sample_codes\n", + " values[i] = sample_values\n", + "\n", + " return codes, values\n", + "\n", + "\n", + "def get_mean_value_for_code(\n", + " codes: torch.Tensor,\n", + " values: torch.Tensor,\n", + " target_code: int,\n", + ") -> 
torch.Tensor:\n", + " mask = (codes == target_code).float()\n", + " counts = mask.sum(dim=1)\n", + " summed = (values * mask).sum(dim=1)\n", + " return torch.where(\n", + " counts > 0,\n", + " summed / counts.clamp(min=1.0),\n", + " torch.zeros_like(summed),\n", + " )\n", + "\n", + "\n", + "def generate_downstream_labels(\n", + " codes: torch.Tensor,\n", + " values: torch.Tensor,\n", + ") -> torch.Tensor:\n", + " code_1_value = get_mean_value_for_code(codes, values, target_code=1)\n", + " code_2_value = get_mean_value_for_code(codes, values, target_code=2)\n", + " code_7_value = get_mean_value_for_code(codes, values, target_code=7)\n", + " return ((code_2_value * code_7_value > 0.20) & (code_1_value < 0.50)).long()\n", + "\n", + "\n", + "def make_downstream_samples(\n", + " n_samples: int,\n", + " seed: int,\n", + " num_labs: int = NUM_LABS,\n", + "):\n", + " set_seed(seed)\n", + " codes, values = generate_structured_fake_batch(\n", + " batch_size=n_samples,\n", + " num_labs=num_labs,\n", + " )\n", + " labels = generate_downstream_labels(codes, values)\n", + "\n", + " samples = []\n", + " for i in range(n_samples):\n", + " samples.append(\n", + " {\n", + " \"patient_id\": f\"patient-{i}\",\n", + " \"visit_id\": f\"visit-{i}\",\n", + " \"lab_codes\": [LAB_CODES[int(code)] for code in codes[i].tolist()],\n", + " \"lab_values\": values[i].tolist(),\n", + " \"label\": int(labels[i].item()),\n", + " }\n", + " )\n", + " return samples\n", + "\n", + "\n", + "def build_dataset(samples):\n", + " return create_sample_dataset(\n", + " samples=samples,\n", + " input_schema={\"lab_codes\": \"sequence\", \"lab_values\": \"tensor\"},\n", + " output_schema={\"label\": \"binary\"},\n", + " dataset_name=\"labrador_downstream_demo\",\n", + " )\n", + "\n", + "\n", + "def split_samples(samples):\n", + " n_samples = len(samples)\n", + " train_end = int(n_samples * TRAIN_RATIO)\n", + " val_end = int(n_samples * (TRAIN_RATIO + VAL_RATIO))\n", + " return samples[:train_end], 
samples[train_end:val_end], samples[val_end:]" + ] + }, + { + "cell_type": "markdown", + "id": "df0a7a18", + "metadata": {}, + "source": [ + "## 3) Training and Evaluation Helpers\n", + "\n", + "This section defines the minimal training loop, evaluation function, and one helper that runs a single Labrador configuration on the downstream task." + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "c94ad128", + "metadata": {}, + "outputs": [], + "source": [ + "def train_model(model, dataset, lr: float):\n", + " model.train()\n", + " optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + " loader = get_dataloader(dataset, batch_size=BATCH_SIZE, shuffle=True)\n", + "\n", + " for _ in range(EPOCHS):\n", + " for batch in loader:\n", + " optimizer.zero_grad()\n", + " output = model(**batch)\n", + " output[\"loss\"].backward()\n", + " optimizer.step()\n", + "\n", + "\n", + "def compute_binary_metrics(y_true: torch.Tensor, y_pred: torch.Tensor):\n", + " y_true = y_true.float()\n", + " y_pred = y_pred.float()\n", + "\n", + " tp = ((y_pred == 1) & (y_true == 1)).sum().item()\n", + " tn = ((y_pred == 0) & (y_true == 0)).sum().item()\n", + " fp = ((y_pred == 1) & (y_true == 0)).sum().item()\n", + " fn = ((y_pred == 0) & (y_true == 1)).sum().item()\n", + "\n", + " total = max(tp + tn + fp + fn, 1)\n", + " accuracy = (tp + tn) / total\n", + "\n", + " precision_den = max(tp + fp, 1)\n", + " recall_den = max(tp + fn, 1)\n", + " precision = tp / precision_den\n", + " recall = tp / recall_den\n", + "\n", + " f1_den = max(precision + recall, 1e-12)\n", + " f1 = 2 * precision * recall / f1_den\n", + "\n", + " return {\n", + " \"accuracy\": float(accuracy),\n", + " \"precision\": float(precision),\n", + " \"recall\": float(recall),\n", + " \"f1\": float(f1),\n", + " }\n", + "\n", + "\n", + "def evaluate_model(model, dataset):\n", + " model.eval()\n", + " loader = get_dataloader(dataset, batch_size=BATCH_SIZE, shuffle=False)\n", + " y_true_all = []\n", + " 
y_pred_all = []\n", + "\n", + " with torch.no_grad():\n", + " for batch in loader:\n", + " output = model(**batch)\n", + " y_true = output[\"y_true\"].detach().cpu().view(-1).float()\n", + " y_prob = output[\"y_prob\"].detach().cpu()\n", + "\n", + " if y_prob.dim() == 2 and y_prob.shape[1] > 1:\n", + " y_pred = torch.argmax(y_prob, dim=1).view(-1).float()\n", + " else:\n", + " y_pred = (y_prob.view(-1) > 0.5).float()\n", + "\n", + " y_true_all.append(y_true)\n", + " y_pred_all.append(y_pred)\n", + "\n", + " y_true_all = torch.cat(y_true_all)\n", + " y_pred_all = torch.cat(y_pred_all)\n", + " return compute_binary_metrics(y_true_all, y_pred_all)\n", + "\n", + "\n", + "def run_experiment(train_dataset, test_dataset, embed_dim: int, lr: float):\n", + " set_seed(SEED)\n", + " model = LabradorModel(\n", + " dataset=train_dataset,\n", + " code_feature_key=\"lab_codes\",\n", + " value_feature_key=\"lab_values\",\n", + " embed_dim=embed_dim,\n", + " num_heads=2,\n", + " num_layers=1,\n", + " )\n", + " train_model(model, train_dataset, lr=lr)\n", + " return evaluate_model(model, test_dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "a9dc9e8e", + "metadata": {}, + "source": [ + "## 4) Create the Synthetic Downstream Dataset\n", + "\n", + "This section generates one small synthetic dataset, shuffles it, splits it into train/validation/test subsets, and converts each split into a PyHealth dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "1ae4d1a1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label label vocab: {0: 0, 1: 1}\n", + "Label label vocab: {0: 0, 1: 1}\n", + "Label label vocab: {0: 0, 1: 1}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train/val/test: 360/120/120\n", + "train positive rate: 0.603\n", + "test positive rate: 0.600\n" + ] + } + ], + "source": [ + "samples = make_downstream_samples(DOWNSTREAM_SAMPLES, seed=SEED)\n", + "random.Random(SEED).shuffle(samples)\n", + "train_samples, val_samples, test_samples = split_samples(samples)\n", + "\n", + "train_dataset = build_dataset(train_samples)\n", + "val_dataset = build_dataset(val_samples)\n", + "test_dataset = build_dataset(test_samples)\n", + "\n", + "print(f\"train/val/test: {len(train_dataset)}/{len(val_dataset)}/{len(test_dataset)}\")\n", + "print(f\"train positive rate: {np.mean([sample['label'] for sample in train_samples]):.3f}\")\n", + "print(f\"test positive rate: {np.mean([sample['label'] for sample in test_samples]):.3f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "94fdba18", + "metadata": {}, + "source": [ + "## 5) Embedding Dimension Ablation\n", + "\n", + "We vary model capacity by changing the embedding dimension:\n", + "\n", + "* `embed_dim = 64`\n", + "* `embed_dim = 128`\n", + "* `embed_dim = 256`\n", + "\n", + "The learning rate is fixed at `1e-3`.\n", + "\n", + "Each configuration is trained and evaluated on the same dataset, allowing direct comparison of model capacity.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "2ecd9e53", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
embed_dimaccuracyprecisionrecallf1
0640.4666670.5909090.3611110.448276
11280.5000000.6034480.4861110.538462
22560.4250000.5483870.2361110.330097
\n", + "
" + ], + "text/plain": [ + " embed_dim accuracy precision recall f1\n", + "0 64 0.466667 0.590909 0.361111 0.448276\n", + "1 128 0.500000 0.603448 0.486111 0.538462\n", + "2 256 0.425000 0.548387 0.236111 0.330097" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embed_dim_results = []\n", + "for embed_dim in [64, 128, 256]:\n", + " metrics = run_experiment(\n", + " train_dataset=train_dataset,\n", + " test_dataset=test_dataset,\n", + " embed_dim=embed_dim,\n", + " lr=1e-3,\n", + " )\n", + " embed_dim_results.append({\"embed_dim\": embed_dim, **metrics})\n", + "\n", + "embed_dim_df = pd.DataFrame(embed_dim_results)\n", + "embed_dim_df" + ] + }, + { + "cell_type": "markdown", + "id": "83deadcb", + "metadata": {}, + "source": [ + "## 6) Learning Rate Ablation\n", + "\n", + "We vary the optimizer learning rate while keeping model size fixed (`embed_dim = 128`):\n", + "\n", + "* `lr = 1e-4`\n", + "* `lr = 1e-3`\n", + "* `lr = 5e-3`\n", + "\n", + "This evaluates how sensitive training performance is to optimization settings.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "96cde05b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
learning_rateaccuracyprecisionrecallf1
00.00010.6083330.6106190.9583330.745946
10.00100.5000000.6034480.4861110.538462
20.00500.4250000.5517240.2222220.316832
\n", + "
" + ], + "text/plain": [ + " learning_rate accuracy precision recall f1\n", + "0 0.0001 0.608333 0.610619 0.958333 0.745946\n", + "1 0.0010 0.500000 0.603448 0.486111 0.538462\n", + "2 0.0050 0.425000 0.551724 0.222222 0.316832" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "learning_rate_results = []\n", + "for lr in [1e-4, 1e-3, 5e-3]:\n", + " metrics = run_experiment(\n", + " train_dataset=train_dataset,\n", + " test_dataset=test_dataset,\n", + " embed_dim=128,\n", + " lr=lr,\n", + " )\n", + " learning_rate_results.append({\"learning_rate\": lr, **metrics})\n", + "\n", + "learning_rate_df = pd.DataFrame(learning_rate_results)\n", + "learning_rate_df" + ] + }, + { + "cell_type": "markdown", + "id": "62b53d17", + "metadata": {}, + "source": [ + "## 7) Summary\n", + "\n", + "This section prints the final metric comparisons in a compact format, including accuracy, precision, recall, and F1." + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "df4b9dba", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Embedding Dimension Ablation\n", + "embed_dim=64 -> accuracy=0.4667, precision=0.5909, recall=0.3611, f1=0.4483\n", + "embed_dim=128 -> accuracy=0.5000, precision=0.6034, recall=0.4861, f1=0.5385\n", + "embed_dim=256 -> accuracy=0.4250, precision=0.5484, recall=0.2361, f1=0.3301\n", + "\n", + "Learning Rate Ablation\n", + "lr=1e-04 -> accuracy=0.6083, precision=0.6106, recall=0.9583, f1=0.7459\n", + "lr=1e-03 -> accuracy=0.5000, precision=0.6034, recall=0.4861, f1=0.5385\n", + "lr=5e-03 -> accuracy=0.4250, precision=0.5517, recall=0.2222, f1=0.3168\n", + "\n", + "Embedding Dimension Results\n", + " embed_dim accuracy precision recall f1\n", + " 64 0.466667 0.590909 0.361111 0.448276\n", + " 128 0.500000 0.603448 0.486111 0.538462\n", + " 256 0.425000 0.548387 0.236111 0.330097\n", + "\n", + "Learning Rate Results\n", + " learning_rate 
accuracy precision recall f1\n", + " 0.0001 0.608333 0.610619 0.958333 0.745946\n", + " 0.0010 0.500000 0.603448 0.486111 0.538462\n", + " 0.0050 0.425000 0.551724 0.222222 0.316832\n" + ] + } + ], + "source": [ + "print(\"Embedding Dimension Ablation\")\n", + "for row in embed_dim_results:\n", + " print(\n", + " f\"embed_dim={row['embed_dim']} \"\n", + " f\"-> accuracy={row['accuracy']:.4f}, precision={row['precision']:.4f}, \"\n", + " f\"recall={row['recall']:.4f}, f1={row['f1']:.4f}\"\n", + " )\n", + "print()\n", + "\n", + "print(\"Learning Rate Ablation\")\n", + "for row in learning_rate_results:\n", + " print(\n", + " f\"lr={row['learning_rate']:.0e} \"\n", + " f\"-> accuracy={row['accuracy']:.4f}, precision={row['precision']:.4f}, \"\n", + " f\"recall={row['recall']:.4f}, f1={row['f1']:.4f}\"\n", + " )\n", + "print()\n", + "\n", + "print(\"Embedding Dimension Results\")\n", + "print(embed_dim_df.to_string(index=False))\n", + "print()\n", + "\n", + "print(\"Learning Rate Results\")\n", + "print(learning_rate_df.to_string(index=False))" + ] + }, + { + "cell_type": "markdown", + "id": "9b65316e", + "metadata": {}, + "source": [ + "## 8) Lightweight Masked-Lab Demonstration\n", + "\n", + "This section provides a simplified demonstration of the **masked pretraining objective** used in the Labrador paper.\n", + "\n", + "In Labrador, the model is trained by masking one lab in a patient’s lab set and predicting:\n", + "\n", + "* the **lab identity (code)**, and\n", + "* the **lab value**\n", + "\n", + "This is similar to masked language modeling, where the model learns to infer missing information from context.\n", + "\n", + "---\n", + "\n", + "### Example\n", + "\n", + "Given labs:\n", + "\n", + "* Codes: `[2, 7, 1, 9]`\n", + "* Values: `[0.8, 0.7, 0.2, 0.9]`\n", + "\n", + "Mask one position:\n", + "\n", + "* Codes: `[2, [MASK], 1, 9]`\n", + "* Values: `[0.8, 0.0, 0.2, 0.9]`\n", + "\n", + "The model predicts the missing code (`7`) and value (`0.7`) using the 
remaining labs.\n", + "\n", + "---\n", + "\n", + "### What's implemented here\n", + "\n", + "* Mask one lab per sample\n", + "* Use `LabradorModel` encoder (embedding + transformer)\n", + "* Add small heads to predict:\n", + "\n", + " * masked code (cross-entropy)\n", + " * masked value (MSE)\n", + "* Report evaluation metrics (accuracy, precision, recall, F1) for masked code prediction and MSE for masked value prediction.\n", + "\n", + "---\n", + "\n", + "### Scope note\n", + "\n", + "This is a **lightweight demo only**:\n", + "\n", + "* uses synthetic data\n", + "* small model and short training\n", + "* no transfer to downstream tasks\n", + "\n", + "It is meant to illustrate the **pretraining mechanism**, not reproduce the full paper results." + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "ca0ef6a1", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "def mask_one_lab(codes: torch.Tensor, values: torch.Tensor, mask_token_id: int):\n", + " \"\"\"Mask one random lab position per sample.\"\"\"\n", + " batch_size, seq_len = codes.shape\n", + " mask_idx = torch.randint(0, seq_len, (batch_size,))\n", + "\n", + " row_idx = torch.arange(batch_size)\n", + " target_code = codes[row_idx, mask_idx].clone()\n", + " target_value = values[row_idx, mask_idx].clone()\n", + "\n", + " masked_codes = codes.clone()\n", + " masked_values = values.clone()\n", + "\n", + " masked_codes[row_idx, mask_idx] = mask_token_id\n", + " masked_values[row_idx, mask_idx] = 0.0\n", + "\n", + " return masked_codes, masked_values, mask_idx, target_code, target_value\n", + "\n", + "\n", + "def encode_with_labrador(model: LabradorModel, codes: torch.Tensor, values: torch.Tensor):\n", + " \"\"\"Reuse LabradorModel encoder blocks to produce token representations.\"\"\"\n", + " code_emb = model.code_embedding(codes)\n", + " value_emb = model.value_projection(values.unsqueeze(-1))\n", + "\n", + " x = 
code_emb + value_emb\n", + " x = model.value_fusion(x)\n", + " x = model.value_act(x)\n", + " x = model.value_norm(x)\n", + "\n", + " # Keep all token positions active in this synthetic masked-lab demo.\n", + " token_mask = torch.ones_like(codes, dtype=torch.bool)\n", + " x = model.transformer(x, src_key_padding_mask=~token_mask)\n", + " return x\n", + "\n", + "\n", + "def compute_multiclass_metrics(\n", + " y_true: torch.Tensor,\n", + " y_pred: torch.Tensor,\n", + " num_classes: int,\n", + "):\n", + " \"\"\"Macro precision/recall/F1 for masked code prediction.\"\"\"\n", + " y_true = y_true.long()\n", + " y_pred = y_pred.long()\n", + "\n", + " acc = (y_true == y_pred).float().mean().item()\n", + "\n", + " precisions = []\n", + " recalls = []\n", + " f1s = []\n", + "\n", + " for cls in range(num_classes):\n", + " tp = ((y_pred == cls) & (y_true == cls)).sum().item()\n", + " fp = ((y_pred == cls) & (y_true != cls)).sum().item()\n", + " fn = ((y_pred != cls) & (y_true == cls)).sum().item()\n", + "\n", + " precision = tp / max(tp + fp, 1)\n", + " recall = tp / max(tp + fn, 1)\n", + " f1 = 0.0 if (precision + recall) == 0 else (2 * precision * recall) / (precision + recall)\n", + "\n", + " precisions.append(precision)\n", + " recalls.append(recall)\n", + " f1s.append(f1)\n", + "\n", + " return {\n", + " \"accuracy\": float(acc),\n", + " \"precision\": float(np.mean(precisions)),\n", + " \"recall\": float(np.mean(recalls)),\n", + " \"f1\": float(np.mean(f1s)),\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "6ca044b2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ce_lossmse_losstotal_lossaccuracyprecisionrecallf1
01.9205230.0839622.0044850.31250.0892860.1418760.104206
\n", + "
" + ], + "text/plain": [ + " ce_loss mse_loss total_loss accuracy precision recall f1\n", + "0 1.920523 0.083962 2.004485 0.3125 0.089286 0.141876 0.104206" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "set_seed(SEED)\n", + "\n", + "# Eval synthetic batch for masked prediction demo.\n", + "eval_codes, eval_values = generate_structured_fake_batch(batch_size=128, num_labs=NUM_LABS)\n", + "\n", + "# Reuse the same LabradorModel encoder blocks.\n", + "masked_model = LabradorModel(\n", + " dataset=train_dataset,\n", + " code_feature_key=\"lab_codes\",\n", + " value_feature_key=\"lab_values\",\n", + " embed_dim=64,\n", + " num_heads=2,\n", + " num_layers=1,\n", + ")\n", + "\n", + "# Introduce a dedicated mask token outside the normal code range [0, num_labs-1].\n", + "mask_token_id = masked_model.num_labs\n", + "\n", + "# Expand embedding table by one row so mask token is valid.\n", + "old_embedding = masked_model.code_embedding\n", + "expanded_embedding = nn.Embedding(masked_model.num_labs + 1, masked_model.embed_dim)\n", + "with torch.no_grad():\n", + " expanded_embedding.weight[: masked_model.num_labs].copy_(old_embedding.weight)\n", + " expanded_embedding.weight[mask_token_id].copy_(old_embedding.weight.mean(dim=0))\n", + "masked_model.code_embedding = expanded_embedding\n", + "\n", + "# Lightweight masked heads (not added to core model class).\n", + "mask_code_head = nn.Linear(masked_model.embed_dim, masked_model.num_labs)\n", + "mask_value_head = nn.Sequential(\n", + " nn.Linear(masked_model.embed_dim + masked_model.num_labs, masked_model.embed_dim),\n", + " nn.ReLU(),\n", + " nn.Linear(masked_model.embed_dim, 1),\n", + " nn.Sigmoid(),\n", + ")\n", + "\n", + "optimizer = torch.optim.Adam(\n", + " list(masked_model.parameters()) + list(mask_code_head.parameters()) + list(mask_value_head.parameters()),\n", + " lr=1e-3,\n", + ")\n", + "ce_loss_fn = nn.CrossEntropyLoss()\n", + "mse_loss_fn = 
nn.MSELoss()\n", + "\n", + "MASK_STEPS = 500\n", + "for _ in range(MASK_STEPS):\n", + " train_codes, train_values = generate_structured_fake_batch(batch_size=256, num_labs=NUM_LABS)\n", + " masked_codes, masked_values, mask_idx, target_code, target_value = mask_one_lab(\n", + " train_codes,\n", + " train_values,\n", + " mask_token_id,\n", + " )\n", + "\n", + " x = encode_with_labrador(masked_model, masked_codes, masked_values)\n", + "\n", + " code_logits = mask_code_head(x)\n", + " code_probs = torch.softmax(code_logits, dim=-1)\n", + " value_pred = mask_value_head(torch.cat([x, code_probs], dim=-1)).squeeze(-1)\n", + "\n", + " row_idx = torch.arange(masked_codes.size(0))\n", + " masked_code_logits = code_logits[row_idx, mask_idx]\n", + " masked_value_pred = value_pred[row_idx, mask_idx]\n", + "\n", + " loss_code = ce_loss_fn(masked_code_logits, target_code)\n", + " loss_value = mse_loss_fn(masked_value_pred, target_value)\n", + " loss = loss_code + loss_value\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + "# Evaluation on held-out synthetic batch\n", + "masked_codes, masked_values, mask_idx, target_code, target_value = mask_one_lab(\n", + " eval_codes,\n", + " eval_values,\n", + " mask_token_id,\n", + ")\n", + "\n", + "with torch.no_grad():\n", + " x = encode_with_labrador(masked_model, masked_codes, masked_values)\n", + " code_logits = mask_code_head(x)\n", + " code_probs = torch.softmax(code_logits, dim=-1)\n", + " value_pred = mask_value_head(torch.cat([x, code_probs], dim=-1)).squeeze(-1)\n", + "\n", + "row_idx = torch.arange(masked_codes.size(0))\n", + "masked_code_logits = code_logits[row_idx, mask_idx]\n", + "masked_value_pred = value_pred[row_idx, mask_idx]\n", + "\n", + "loss_code_eval = ce_loss_fn(masked_code_logits, target_code)\n", + "loss_value_eval = mse_loss_fn(masked_value_pred, target_value)\n", + "loss_total_eval = loss_code_eval + loss_value_eval\n", + "\n", + "pred_code = 
from typing import Dict, Optional, Tuple, cast

import torch
import torch.nn as nn

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel


class LabradorModel(BaseModel):
    """Transformer-based model for tabular lab code/value inputs.

    The model consumes two aligned feature streams:
        1) categorical lab codes (token ids)
        2) continuous lab values

    Architecture:
        - code embedding
        - value projection
        - additive fusion -> linear -> ReLU -> LayerNorm
        - Transformer encoder (no positional encoding)
        - mean pooling over lab dimension
        - MLP classifier head

    Args:
        dataset: The dataset used by PyHealth trainers.
        code_feature_key: Input feature key containing lab code tokens.
        value_feature_key: Input feature key containing lab values.
        embed_dim: Hidden size for embeddings and transformer blocks.
        num_heads: Number of attention heads.
        num_layers: Number of transformer encoder layers.
        dropout: Dropout used by the transformer encoder layer.
        ff_hidden_dim: Feed-forward width inside each transformer layer.
        classifier_hidden_dim: Hidden width of classifier MLP head.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> from pyhealth.models import LabradorModel
        >>> samples = [
        ...     {
        ...         "patient_id": "p-0",
        ...         "visit_id": "v-0",
        ...         "lab_codes": ["lab-1", "lab-2"],
        ...         "lab_values": [0.2, 0.8],
        ...         "label": 1,
        ...     }
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"lab_codes": "sequence", "lab_values": "tensor"},
        ...     output_schema={"label": "binary"},
        ...     dataset_name="labrador_demo",
        ... )
        >>> model = LabradorModel(
        ...     dataset=dataset,
        ...     code_feature_key="lab_codes",
        ...     value_feature_key="lab_values",
        ... )
    """

    def __init__(
        self,
        dataset: SampleDataset,
        code_feature_key: str,
        value_feature_key: str,
        embed_dim: int = 128,
        num_heads: int = 2,
        num_layers: int = 2,
        dropout: float = 0.1,
        ff_hidden_dim: Optional[int] = None,
        classifier_hidden_dim: Optional[int] = None,
    ):
        super().__init__(dataset=dataset)

        assert len(self.label_keys) == 1, "Only one label key is supported"
        self.label_key = self.label_keys[0]

        # Both feature streams must be declared in the dataset's input schema.
        for arg_name, key in (
            ("code_feature_key", code_feature_key),
            ("value_feature_key", value_feature_key),
        ):
            if key not in self.feature_keys:
                raise ValueError(
                    f"{arg_name}='{key}' not found in dataset input schema: "
                    f"{self.feature_keys}"
                )

        self.code_feature_key = code_feature_key
        self.value_feature_key = value_feature_key
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers

        # Vocabulary size of the tokenized code feature drives the embedding table.
        vocab_size = self.dataset.input_processors[self.code_feature_key].size()
        if vocab_size is None:
            raise ValueError(
                "LabradorModel requires a tokenized categorical code feature with known vocabulary size. "
                f"Feature '{self.code_feature_key}' returned size=None."
            )
        self.num_labs = int(vocab_size)

        # Fall back to embed_dim when the optional widths are not provided.
        ff_width = ff_hidden_dim or embed_dim
        head_width = classifier_hidden_dim or embed_dim

        self.code_embedding = nn.Embedding(self.num_labs, embed_dim)
        self.value_projection = nn.Linear(1, embed_dim)
        self.value_fusion = nn.Linear(embed_dim, embed_dim)
        self.value_act = nn.ReLU()
        self.value_norm = nn.LayerNorm(embed_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_width,
            dropout=dropout,
            batch_first=True,
            activation="relu",
        )
        # TransformerEncoder deep-copies the layer, so one template is enough.
        self.transformer = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers,
        )

        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, head_width),
            nn.ReLU(),
            nn.Linear(head_width, self.get_output_size()),
        )

    def _extract_value_and_mask(
        self, feature_key: str, feature: torch.Tensor | Tuple[torch.Tensor, ...]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Split a processed feature into its value tensor and optional mask.

        Args:
            feature_key: Name of the dataset feature being decoded.
            feature: Processor output for one feature; either a bare tensor
                or a tuple of tensors (e.g. value plus mask).

        Returns:
            ``(value, mask)`` where ``mask`` is ``None`` when the processor
            does not emit one.

        Raises:
            ValueError: If the processor schema does not expose a ``value``
                field.
        """
        parts: Tuple[torch.Tensor, ...]
        parts = (feature,) if isinstance(feature, torch.Tensor) else feature

        schema = self.dataset.input_processors[feature_key].schema()

        value = parts[schema.index("value")] if "value" in schema else None
        if value is None:
            raise ValueError(
                f"Feature '{feature_key}' must contain 'value' in processor schema."
            )

        mask = parts[schema.index("mask")] if "mask" in schema else None
        # Some processors append an extra trailing mask tensor beyond the
        # declared schema; pick it up if no mask was found above.
        if mask is None and len(parts) == len(schema) + 1:
            mask = parts[-1]

        return value, mask

    @staticmethod
    def _ensure_2d(tensor: torch.Tensor, name: str) -> torch.Tensor:
        """Coerce a lab-stream tensor to shape ``[batch, num_labs]``.

        Args:
            tensor: Codes, values, or mask tensor.
            name: Human-readable tensor name used in error messages.

        Returns:
            A 2D tensor of shape ``[batch, num_labs]``.

        Raises:
            ValueError: If the tensor cannot be viewed as a 2D lab matrix.
        """
        ndim = tensor.dim()
        if ndim == 2:
            return tensor
        # A trailing singleton dim (e.g. [batch, num_labs, 1]) is squeezed away.
        if ndim == 3 and tensor.size(-1) == 1:
            return tensor.squeeze(-1)
        raise ValueError(
            f"Expected {name} to have shape [batch, num_labs] "
            f"(or [..., 1]), got {tuple(tensor.shape)}"
        )

    @staticmethod
    def _mean_pool(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Masked mean over the lab dimension.

        Args:
            x: Token representations, ``[batch, num_labs, embed_dim]``.
            mask: Float validity mask, ``[batch, num_labs]``.

        Returns:
            Pooled embeddings of shape ``[batch, embed_dim]``.
        """
        # Clamp avoids division by zero for (theoretically) all-masked rows.
        valid_counts = mask.sum(dim=1, keepdim=True).clamp(min=1.0)
        summed = (x * mask.unsqueeze(-1)).sum(dim=1)
        return summed / valid_counts

    def forward(
        self,
        **kwargs: torch.Tensor | Tuple[torch.Tensor, ...],
    ) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: Keyword arguments containing the code feature, the
                value feature, optionally the label, and an optional
                ``embed`` flag. The two feature tensors must describe
                aligned lab sequences for the same samples.

        Returns:
            A dictionary that always contains ``logit`` and ``y_prob``;
            additionally ``loss`` and ``y_true`` when the label is present,
            and ``embed`` (pooled patient embedding) when ``embed=True``.

        Raises:
            ValueError: If code and value features do not share a shape.
        """
        raw_codes, code_mask = self._extract_value_and_mask(
            self.code_feature_key, kwargs[self.code_feature_key]
        )
        raw_values, value_mask = self._extract_value_and_mask(
            self.value_feature_key, kwargs[self.value_feature_key]
        )

        codes = self._ensure_2d(raw_codes.to(self.device), "code feature").long()
        values = self._ensure_2d(raw_values.to(self.device), "value feature").float()

        if codes.shape != values.shape:
            raise ValueError(
                f"Code/value feature shapes must match, got codes={tuple(codes.shape)} "
                f"and values={tuple(values.shape)}"
            )

        # Prefer the code mask, then the value mask; otherwise derive one.
        if code_mask is not None:
            mask = self._ensure_2d(code_mask.to(self.device), "code mask").bool()
        elif value_mask is not None:
            mask = self._ensure_2d(value_mask.to(self.device), "value mask").bool()
        else:
            # assumes token id 0 is padding — TODO confirm against the tokenizer
            mask = codes != 0

        # Guarantee at least one valid position per row so the transformer
        # never sees a fully padded sequence.
        empty_rows = ~mask.any(dim=1)
        if empty_rows.any():
            mask[empty_rows, 0] = True

        fused = self.code_embedding(codes) + self.value_projection(values.unsqueeze(-1))
        fused = self.value_norm(self.value_act(self.value_fusion(fused)))

        encoded = self.transformer(fused, src_key_padding_mask=~mask)
        patient_emb = self._mean_pool(encoded, mask.float())
        logits = self.classifier(patient_emb)

        results: Dict[str, torch.Tensor] = {
            "logit": logits,
            "y_prob": self.prepare_y_prob(logits),
        }

        if self.label_key in kwargs:
            y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
            results["loss"] = self.get_loss_function()(logits, y_true)
            results["y_true"] = y_true

        if kwargs.get("embed", False):
            results["embed"] = patient_emb

        return results
import unittest

import torch

from pyhealth.datasets import create_sample_dataset, get_dataloader
from pyhealth.models import LabradorModel


class TestLabradorModel(unittest.TestCase):
    """Minimal smoke test for LabradorModel."""

    def setUp(self):
        """Build a tiny two-sample dataset and a lightweight model."""
        # Two synthetic patients; lab_codes and lab_values describe the
        # same four labs per sample, so the streams are aligned.
        self.samples = [
            {
                "patient_id": "patient-0",
                "visit_id": "visit-0",
                "lab_codes": ["lab-1", "lab-2", "lab-3", "lab-4"],
                "lab_values": [1.0, 2.5, 3.0, 4.5],
                "label": 0,
            },
            {
                "patient_id": "patient-1",
                "visit_id": "visit-1",
                "lab_codes": ["lab-1", "lab-2", "lab-3", "lab-4"],
                "lab_values": [2.1, 1.8, 2.9, 3.5],
                "label": 1,
            },
        ]

        # Categorical codes go through the sequence processor; continuous
        # values through the tensor processor; the task is binary.
        self.input_schema = {
            "lab_codes": "sequence",
            "lab_values": "tensor",
        }
        self.output_schema = {"label": "binary"}

        self.dataset = create_sample_dataset(
            samples=self.samples,
            input_schema=self.input_schema,
            output_schema=self.output_schema,
            dataset_name="test_labrador",
        )

        # Small dimensions keep the smoke test fast.
        self.model = LabradorModel(
            dataset=self.dataset,
            code_feature_key="lab_codes",
            value_feature_key="lab_values",
            embed_dim=32,
            num_heads=2,
            num_layers=1,
        )

    def test_model_initialization(self):
        """Model stores the configured hyperparameters and keys."""
        self.assertIsInstance(self.model, LabradorModel)
        self.assertEqual(self.model.embed_dim, 32)
        self.assertEqual(self.model.num_heads, 2)
        self.assertEqual(self.model.num_layers, 1)
        self.assertEqual(self.model.code_feature_key, "lab_codes")
        self.assertEqual(self.model.value_feature_key, "lab_values")
        self.assertEqual(self.model.label_key, "label")

    def test_model_forward(self):
        """Forward pass returns loss/probabilities/labels/logits with batch dim 2."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
        batch = next(iter(loader))

        with torch.no_grad():
            output = self.model(**batch)

        for key in ("loss", "y_prob", "y_true", "logit"):
            self.assertIn(key, output)

        for key in ("y_prob", "y_true", "logit"):
            self.assertEqual(output[key].shape[0], 2)  # batch size

        # The supervised loss must be a scalar tensor.
        self.assertEqual(output["loss"].dim(), 0)

    def test_model_backward(self):
        """Backward pass populates gradients on model parameters."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
        batch = next(iter(loader))

        output = self.model(**batch)
        output["loss"].backward()

        has_gradient = any(
            p.requires_grad and p.grad is not None for p in self.model.parameters()
        )
        self.assertTrue(has_gradient, "No parameters have gradients after backward pass")

    def test_model_with_embedding(self):
        """Passing embed=True yields pooled embeddings of shape [batch, embed_dim]."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
        batch = next(iter(loader))
        batch["embed"] = True

        with torch.no_grad():
            output = self.model(**batch)

        self.assertIn("embed", output)
        self.assertEqual(output["embed"].shape[0], 2)   # batch size
        self.assertEqual(output["embed"].shape[1], 32)  # embed_dim


if __name__ == "__main__":
    unittest.main()