diff --git a/LABRADOR_TEST_SUMMARY.md b/LABRADOR_TEST_SUMMARY.md
new file mode 100644
index 000000000..02c58310b
--- /dev/null
+++ b/LABRADOR_TEST_SUMMARY.md
@@ -0,0 +1,90 @@
+# LabradorModel Test
+
+**Test File:** `tests/core/test_labrador.py`
+
+## Overview
+
+Minimal end-to-end smoke test for `LabradorModel` that verifies:
+- Dataset creation with aligned lab codes and values
+- Model initialization with correct hyperparameters
+- Forward pass with expected output structure
+- Backward pass with gradient computation
+- Embedding extraction when requested
+
+## Test Methods
+
+| Test | Purpose |
+|------|---------|
+| `test_model_initialization` | Verifies model initializes with correct hyperparameters (embed_dim, num_heads, num_layers, feature keys, label key) |
+| `test_model_forward` | Checks forward pass returns `loss`, `y_prob`, `y_true`, `logit` with correct shapes and batch sizes |
+| `test_model_backward` | Confirms backward pass computes gradients on parameters |
+| `test_model_with_embedding` | Validates embedding extraction returns correct shape `[batch_size, embed_dim]` |
+
+---
+
+## Key Implementation Details
+
+### Test Data Structure
+
+**Minimal Synthetic Dataset (2 samples, 4 labs each):**
+```python
+samples = [
+ {
+ "patient_id": "patient-0",
+ "visit_id": "visit-0",
+ "lab_codes": ["lab-1", "lab-2", "lab-3", "lab-4"], # Categorical
+ "lab_values": [1.0, 2.5, 3.0, 4.5], # Continuous
+ "label": 0,
+ },
+ {
+ "patient_id": "patient-1",
+ "visit_id": "visit-1",
+ "lab_codes": ["lab-1", "lab-2", "lab-3", "lab-4"],
+ "lab_values": [2.1, 1.8, 2.9, 3.5],
+ "label": 1,
+ },
+]
+```
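The alignment the smoke test depends on (one value per code, binary labels) can be sanity-checked in plain Python. This is an illustrative sketch, not part of the test file:

```python
# Hypothetical standalone check mirroring the synthetic samples above.
samples = [
    {"lab_codes": ["lab-1", "lab-2", "lab-3", "lab-4"],
     "lab_values": [1.0, 2.5, 3.0, 4.5], "label": 0},
    {"lab_codes": ["lab-1", "lab-2", "lab-3", "lab-4"],
     "lab_values": [2.1, 1.8, 2.9, 3.5], "label": 1},
]

for sample in samples:
    # Each categorical code must pair with exactly one continuous value.
    assert len(sample["lab_codes"]) == len(sample["lab_values"])
    # The output schema is binary, so labels are 0 or 1.
    assert sample["label"] in (0, 1)
```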
+
+**Input Schema:**
+```python
+input_schema = {
+ "lab_codes": "sequence", # → Categorical tokens [batch, 4]
+ "lab_values": "tensor", # → Float values [batch, 4]
+}
+output_schema = {"label": "binary"}
+```
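After collation, the two streams arrive as aligned `[batch, 4]` tensors. A minimal sketch of the expected shapes and dtypes (the integer ids here are illustrative, not the real tokenizer vocabulary):

```python
import torch

# Illustrative collated batch: integer token ids for codes, floats for values.
lab_codes = torch.tensor([[1, 2, 3, 4],
                          [1, 2, 3, 4]])            # [batch=2, 4], long
lab_values = torch.tensor([[1.0, 2.5, 3.0, 4.5],
                           [2.1, 1.8, 2.9, 3.5]])   # [batch=2, 4], float

# The two streams must stay aligned position by position.
assert lab_codes.shape == lab_values.shape == (2, 4)
assert lab_codes.dtype == torch.long
assert lab_values.dtype == torch.float32
```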
+
+**Model Configuration:**
+```python
+LabradorModel(
+ dataset=dataset,
+ code_feature_key="lab_codes",
+ value_feature_key="lab_values",
+ embed_dim=32, # Lightweight for testing
+ num_heads=2,
+ num_layers=1,
+)
+```
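`test_model_backward` follows the standard gradient smoke-test pattern. It is sketched below with a stand-in `nn.Linear` so the snippet runs without PyHealth; the real test backpropagates through `LabradorModel`'s loss the same way:

```python
import torch
import torch.nn as nn

# Stand-in module; the real test uses LabradorModel's forward-pass loss.
model = nn.Linear(4, 2)
x = torch.randn(8, 4)
loss = model(x).sum()
loss.backward()

# Every trainable parameter should now carry a gradient.
for name, param in model.named_parameters():
    assert param.requires_grad
    assert param.grad is not None, f"no gradient for {name}"
```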
+
+---
+
+## Running the Test
+
+**Option 1: Using pixi (recommended)**
+```bash
+cd /path/to/PyHealth
+make init # First time only
+make test
+```
+
+**Option 2: Direct unittest**
+```bash
+cd /path/to/PyHealth
+python3 -m unittest tests.core.test_labrador.TestLabradorModel -v
+```
+
+**Option 3: Single test method**
+```bash
+python3 -m unittest tests.core.test_labrador.TestLabradorModel.test_model_forward -v
+```
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 7368dec94..92c44f7ee 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -197,6 +197,7 @@ API Reference
models/pyhealth.models.MedLink
models/pyhealth.models.TCN
models/pyhealth.models.TFMTokenizer
+ models/pyhealth.models.labrador
models/pyhealth.models.GAN
models/pyhealth.models.VAE
models/pyhealth.models.SDOH
diff --git a/docs/api/models/pyhealth.models.labrador.rst b/docs/api/models/pyhealth.models.labrador.rst
new file mode 100644
index 000000000..347439515
--- /dev/null
+++ b/docs/api/models/pyhealth.models.labrador.rst
@@ -0,0 +1,9 @@
+pyhealth.models.labrador
+===================================
+
+LabradorModel for aligned lab code and lab value inputs.
+
+.. autoclass:: pyhealth.models.LabradorModel
+ :members:
+ :undoc-members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/examples/labrador_ablations_quickstart.ipynb b/examples/labrador_ablations_quickstart.ipynb
new file mode 100644
index 000000000..ac25f0904
--- /dev/null
+++ b/examples/labrador_ablations_quickstart.ipynb
@@ -0,0 +1,1037 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "2542e8fb",
+ "metadata": {},
+ "source": [
+ "# Labrador Minimal Downstream Ablation\n",
+ "\n",
+ "This notebook demonstrates a minimal downstream use of `LabradorModel` on synthetic lab data.\n",
+ "\n",
+ "We:\n",
+ "\n",
+ "1. generate structured synthetic lab samples,\n",
+ "2. define a nonlinear binary classification task,\n",
+ "3. train the model as a classifier,\n",
+ "4. compare performance under simple hyperparameter changes,\n",
+    "5. demonstrate the masked-lab pretraining objective used in the paper.\n",
+ "\n",
+ "The goal is to provide a small, runnable example that illustrates model behavior under different configurations.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "639b1279",
+ "metadata": {},
+ "source": [
+ "## 1) Setup\n",
+ "\n",
+ "This section imports the required libraries, defines a few constants, and sets the random seed so the example is reproducible."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "id": "676b1a77",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Torch: 2.7.1+cu126\n",
+ "CUDA available: True\n"
+ ]
+ }
+ ],
+ "source": [
+ "import random\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "\n",
+ "from pyhealth.datasets import create_sample_dataset, get_dataloader\n",
+ "from pyhealth.models import LabradorModel\n",
+ "\n",
+ "\n",
+ "NUM_LABS = 10\n",
+ "LAB_CODES = [f\"lab-{i}\" for i in range(NUM_LABS)]\n",
+ "SEED = 42\n",
+ "TRAIN_RATIO = 0.6\n",
+ "VAL_RATIO = 0.2\n",
+ "DOWNSTREAM_SAMPLES = 600\n",
+ "EPOCHS = 5\n",
+ "BATCH_SIZE = 64\n",
+ "\n",
+ "\n",
+ "def set_seed(seed: int = 42):\n",
+ " random.seed(seed)\n",
+ " np.random.seed(seed)\n",
+ " torch.manual_seed(seed)\n",
+ " if torch.cuda.is_available():\n",
+ " torch.cuda.manual_seed_all(seed)\n",
+ "\n",
+ "\n",
+ "set_seed(SEED)\n",
+ "print(f\"Torch: {torch.__version__}\")\n",
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "850cc42e",
+ "metadata": {},
+ "source": [
+ "## 2) Synthetic Data\n",
+ "\n",
+ "We construct synthetic lab data with basic structure using **latent profiles**.\n",
+ "\n",
+ "Each sample is generated by first selecting a hidden profile, which controls:\n",
+ "\n",
+ "* which lab codes are more likely to appear together\n",
+ "* how their values are shifted\n",
+ "\n",
+ "A profile can be thought of as a simple “patient type” that induces patterns in the labs.\n",
+ "\n",
+ "### Intuition\n",
+ "\n",
+ "For example:\n",
+ "\n",
+ "* **Profile A (metabolic-like)**\n",
+ "\n",
+ " * more likely to include labs such as `lab_2`, `lab_7`, `lab_8`\n",
+ " * values tend to be higher on average\n",
+ "\n",
+ "* **Profile B (baseline-like)**\n",
+ "\n",
+ " * labs are more uniformly distributed\n",
+ " * values are centered around lower or neutral ranges\n",
+ "\n",
+ "A generated sample might look like (metabolic profile):\n",
+ "\n",
+ "* Codes: `[2, 7, 1, 9]`\n",
+ "* Values: `[0.8, 0.7, 0.2, 0.9]`\n",
+ "\n",
+ "Another sample from a different profile (baseline):\n",
+ "\n",
+ "* Codes: `[1, 3, 5, 0]`\n",
+ "* Values: `[0.2, 0.3, 0.4, 0.1]`\n",
+ "\n",
+ "These profiles create **co-occurrence structure** (which labs appear together) and **value patterns**, making the data more realistic than purely random samples.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "### Downstream label\n",
+ "\n",
+ "The label is defined by a nonlinear rule:\n",
+ "\n",
+ "* positive if `(lab_2 × lab_7)` is large and `lab_1` is below a threshold\n",
+ "\n",
+ "This requires the model to:\n",
+ "\n",
+ "* identify specific lab types\n",
+ "* reason about their values\n",
+ "* capture interactions between labs\n",
+ "\n",
+ "Overall, the synthetic setup introduces just enough structure to make the task non-trivial while remaining lightweight.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "f2bc5d41",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def generate_structured_fake_batch(\n",
+ " batch_size: int = 32,\n",
+ " num_labs: int = NUM_LABS,\n",
+ "):\n",
+ " \"\"\"Generate synthetic lab bags with profile-specific code/value structure.\"\"\"\n",
+ " panel_probs = torch.tensor(\n",
+ " [\n",
+ " [0.05, 0.18, 0.07, 0.06, 0.05, 0.05, 0.18, 0.18, 0.10, 0.08],\n",
+ " [0.12, 0.18, 0.16, 0.15, 0.06, 0.05, 0.13, 0.05, 0.05, 0.05],\n",
+ " [0.10, 0.08, 0.05, 0.05, 0.16, 0.16, 0.10, 0.08, 0.12, 0.10],\n",
+ " ],\n",
+ " dtype=torch.float32,\n",
+ " )\n",
+ "\n",
+ " anchor_codes = torch.tensor([1, 2, 7], dtype=torch.long)\n",
+ " panel_ids = torch.randint(0, 3, (batch_size,))\n",
+ "\n",
+ " codes = torch.empty(batch_size, num_labs, dtype=torch.long)\n",
+ " values = torch.empty(batch_size, num_labs, dtype=torch.float32)\n",
+ "\n",
+ " for i in range(batch_size):\n",
+ " panel = panel_ids[i].item()\n",
+ " extra_codes = torch.multinomial(\n",
+ " panel_probs[panel],\n",
+ " num_samples=num_labs - len(anchor_codes),\n",
+ " replacement=True,\n",
+ " )\n",
+ " sample_codes = torch.cat([anchor_codes, extra_codes], dim=0)\n",
+ " sample_codes = sample_codes[torch.randperm(num_labs)]\n",
+ "\n",
+ " sample_values = sample_codes.float() / (num_labs - 1)\n",
+ "\n",
+ " if panel == 0:\n",
+ " sample_values[(sample_codes == 1) | (sample_codes == 6) | (sample_codes == 7)] += 0.18\n",
+ " sample_values[sample_codes == 0] -= 0.05\n",
+ " elif panel == 1:\n",
+ " sample_values[(sample_codes == 2) | (sample_codes == 3)] += 0.15\n",
+ " sample_values[sample_codes == 0] += 0.25\n",
+ " sample_values[sample_codes == 1] += 0.05\n",
+ " sample_values[sample_codes == 6] += 0.05\n",
+ " else:\n",
+ " sample_values[\n",
+ " (sample_codes == 4)\n",
+ " | (sample_codes == 5)\n",
+ " | (sample_codes == 8)\n",
+ " | (sample_codes == 9)\n",
+ " ] += 0.18\n",
+ " sample_values[sample_codes == 1] -= 0.06\n",
+ " sample_values[sample_codes == 0] += 0.08\n",
+ "\n",
+ " sample_values += 0.05 * torch.randn_like(sample_values)\n",
+ " sample_values = torch.clamp(sample_values, 0.0, 1.0)\n",
+ "\n",
+ " codes[i] = sample_codes\n",
+ " values[i] = sample_values\n",
+ "\n",
+ " return codes, values\n",
+ "\n",
+ "\n",
+ "def get_mean_value_for_code(\n",
+ " codes: torch.Tensor,\n",
+ " values: torch.Tensor,\n",
+ " target_code: int,\n",
+ ") -> torch.Tensor:\n",
+ " mask = (codes == target_code).float()\n",
+ " counts = mask.sum(dim=1)\n",
+ " summed = (values * mask).sum(dim=1)\n",
+ " return torch.where(\n",
+ " counts > 0,\n",
+ " summed / counts.clamp(min=1.0),\n",
+ " torch.zeros_like(summed),\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def generate_downstream_labels(\n",
+ " codes: torch.Tensor,\n",
+ " values: torch.Tensor,\n",
+ ") -> torch.Tensor:\n",
+ " code_1_value = get_mean_value_for_code(codes, values, target_code=1)\n",
+ " code_2_value = get_mean_value_for_code(codes, values, target_code=2)\n",
+ " code_7_value = get_mean_value_for_code(codes, values, target_code=7)\n",
+ " return ((code_2_value * code_7_value > 0.20) & (code_1_value < 0.50)).long()\n",
+ "\n",
+ "\n",
+ "def make_downstream_samples(\n",
+ " n_samples: int,\n",
+ " seed: int,\n",
+ " num_labs: int = NUM_LABS,\n",
+ "):\n",
+ " set_seed(seed)\n",
+ " codes, values = generate_structured_fake_batch(\n",
+ " batch_size=n_samples,\n",
+ " num_labs=num_labs,\n",
+ " )\n",
+ " labels = generate_downstream_labels(codes, values)\n",
+ "\n",
+ " samples = []\n",
+ " for i in range(n_samples):\n",
+ " samples.append(\n",
+ " {\n",
+ " \"patient_id\": f\"patient-{i}\",\n",
+ " \"visit_id\": f\"visit-{i}\",\n",
+ " \"lab_codes\": [LAB_CODES[int(code)] for code in codes[i].tolist()],\n",
+ " \"lab_values\": values[i].tolist(),\n",
+ " \"label\": int(labels[i].item()),\n",
+ " }\n",
+ " )\n",
+ " return samples\n",
+ "\n",
+ "\n",
+ "def build_dataset(samples):\n",
+ " return create_sample_dataset(\n",
+ " samples=samples,\n",
+ " input_schema={\"lab_codes\": \"sequence\", \"lab_values\": \"tensor\"},\n",
+ " output_schema={\"label\": \"binary\"},\n",
+ " dataset_name=\"labrador_downstream_demo\",\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def split_samples(samples):\n",
+ " n_samples = len(samples)\n",
+ " train_end = int(n_samples * TRAIN_RATIO)\n",
+ " val_end = int(n_samples * (TRAIN_RATIO + VAL_RATIO))\n",
+ " return samples[:train_end], samples[train_end:val_end], samples[val_end:]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "df0a7a18",
+ "metadata": {},
+ "source": [
+ "## 3) Training and Evaluation Helpers\n",
+ "\n",
+ "This section defines the minimal training loop, evaluation function, and one helper that runs a single Labrador configuration on the downstream task."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "id": "c94ad128",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train_model(model, dataset, lr: float):\n",
+ " model.train()\n",
+ " optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
+ " loader = get_dataloader(dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
+ "\n",
+ " for _ in range(EPOCHS):\n",
+ " for batch in loader:\n",
+ " optimizer.zero_grad()\n",
+ " output = model(**batch)\n",
+ " output[\"loss\"].backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ "\n",
+ "def compute_binary_metrics(y_true: torch.Tensor, y_pred: torch.Tensor):\n",
+ " y_true = y_true.float()\n",
+ " y_pred = y_pred.float()\n",
+ "\n",
+ " tp = ((y_pred == 1) & (y_true == 1)).sum().item()\n",
+ " tn = ((y_pred == 0) & (y_true == 0)).sum().item()\n",
+ " fp = ((y_pred == 1) & (y_true == 0)).sum().item()\n",
+ " fn = ((y_pred == 0) & (y_true == 1)).sum().item()\n",
+ "\n",
+ " total = max(tp + tn + fp + fn, 1)\n",
+ " accuracy = (tp + tn) / total\n",
+ "\n",
+ " precision_den = max(tp + fp, 1)\n",
+ " recall_den = max(tp + fn, 1)\n",
+ " precision = tp / precision_den\n",
+ " recall = tp / recall_den\n",
+ "\n",
+ " f1_den = max(precision + recall, 1e-12)\n",
+ " f1 = 2 * precision * recall / f1_den\n",
+ "\n",
+ " return {\n",
+ " \"accuracy\": float(accuracy),\n",
+ " \"precision\": float(precision),\n",
+ " \"recall\": float(recall),\n",
+ " \"f1\": float(f1),\n",
+ " }\n",
+ "\n",
+ "\n",
+ "def evaluate_model(model, dataset):\n",
+ " model.eval()\n",
+ " loader = get_dataloader(dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
+ " y_true_all = []\n",
+ " y_pred_all = []\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " for batch in loader:\n",
+ " output = model(**batch)\n",
+ " y_true = output[\"y_true\"].detach().cpu().view(-1).float()\n",
+ " y_prob = output[\"y_prob\"].detach().cpu()\n",
+ "\n",
+ " if y_prob.dim() == 2 and y_prob.shape[1] > 1:\n",
+ " y_pred = torch.argmax(y_prob, dim=1).view(-1).float()\n",
+ " else:\n",
+ " y_pred = (y_prob.view(-1) > 0.5).float()\n",
+ "\n",
+ " y_true_all.append(y_true)\n",
+ " y_pred_all.append(y_pred)\n",
+ "\n",
+ " y_true_all = torch.cat(y_true_all)\n",
+ " y_pred_all = torch.cat(y_pred_all)\n",
+ " return compute_binary_metrics(y_true_all, y_pred_all)\n",
+ "\n",
+ "\n",
+ "def run_experiment(train_dataset, test_dataset, embed_dim: int, lr: float):\n",
+ " set_seed(SEED)\n",
+ " model = LabradorModel(\n",
+ " dataset=train_dataset,\n",
+ " code_feature_key=\"lab_codes\",\n",
+ " value_feature_key=\"lab_values\",\n",
+ " embed_dim=embed_dim,\n",
+ " num_heads=2,\n",
+ " num_layers=1,\n",
+ " )\n",
+ " train_model(model, train_dataset, lr=lr)\n",
+ " return evaluate_model(model, test_dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a9dc9e8e",
+ "metadata": {},
+ "source": [
+ "## 4) Create the Synthetic Downstream Dataset\n",
+ "\n",
+ "This section generates one small synthetic dataset, shuffles it, splits it into train/validation/test subsets, and converts each split into a PyHealth dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "id": "1ae4d1a1",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Label label vocab: {0: 0, 1: 1}\n",
+ "Label label vocab: {0: 0, 1: 1}\n",
+ "Label label vocab: {0: 0, 1: 1}\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "train/val/test: 360/120/120\n",
+ "train positive rate: 0.603\n",
+ "test positive rate: 0.600\n"
+ ]
+ }
+ ],
+ "source": [
+ "samples = make_downstream_samples(DOWNSTREAM_SAMPLES, seed=SEED)\n",
+ "random.Random(SEED).shuffle(samples)\n",
+ "train_samples, val_samples, test_samples = split_samples(samples)\n",
+ "\n",
+ "train_dataset = build_dataset(train_samples)\n",
+ "val_dataset = build_dataset(val_samples)\n",
+ "test_dataset = build_dataset(test_samples)\n",
+ "\n",
+ "print(f\"train/val/test: {len(train_dataset)}/{len(val_dataset)}/{len(test_dataset)}\")\n",
+ "print(f\"train positive rate: {np.mean([sample['label'] for sample in train_samples]):.3f}\")\n",
+ "print(f\"test positive rate: {np.mean([sample['label'] for sample in test_samples]):.3f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "94fdba18",
+ "metadata": {},
+ "source": [
+ "## 5) Embedding Dimension Ablation\n",
+ "\n",
+ "We vary model capacity by changing the embedding dimension:\n",
+ "\n",
+ "* `embed_dim = 64`\n",
+ "* `embed_dim = 128`\n",
+ "* `embed_dim = 256`\n",
+ "\n",
+ "The learning rate is fixed at `1e-3`.\n",
+ "\n",
+ "Each configuration is trained and evaluated on the same dataset, allowing direct comparison of model capacity.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "id": "2ecd9e53",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " embed_dim accuracy precision recall f1\n",
+ "0 64 0.466667 0.590909 0.361111 0.448276\n",
+ "1 128 0.500000 0.603448 0.486111 0.538462\n",
+ "2 256 0.425000 0.548387 0.236111 0.330097"
+ ]
+ },
+ "execution_count": 54,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "embed_dim_results = []\n",
+ "for embed_dim in [64, 128, 256]:\n",
+ " metrics = run_experiment(\n",
+ " train_dataset=train_dataset,\n",
+ " test_dataset=test_dataset,\n",
+ " embed_dim=embed_dim,\n",
+ " lr=1e-3,\n",
+ " )\n",
+ " embed_dim_results.append({\"embed_dim\": embed_dim, **metrics})\n",
+ "\n",
+ "embed_dim_df = pd.DataFrame(embed_dim_results)\n",
+ "embed_dim_df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "83deadcb",
+ "metadata": {},
+ "source": [
+ "## 6) Learning Rate Ablation\n",
+ "\n",
+ "We vary the optimizer learning rate while keeping model size fixed (`embed_dim = 128`):\n",
+ "\n",
+ "* `lr = 1e-4`\n",
+ "* `lr = 1e-3`\n",
+ "* `lr = 5e-3`\n",
+ "\n",
+ "This evaluates how sensitive training performance is to optimization settings.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "id": "96cde05b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " learning_rate accuracy precision recall f1\n",
+ "0 0.0001 0.608333 0.610619 0.958333 0.745946\n",
+ "1 0.0010 0.500000 0.603448 0.486111 0.538462\n",
+ "2 0.0050 0.425000 0.551724 0.222222 0.316832"
+ ]
+ },
+ "execution_count": 55,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "learning_rate_results = []\n",
+ "for lr in [1e-4, 1e-3, 5e-3]:\n",
+ " metrics = run_experiment(\n",
+ " train_dataset=train_dataset,\n",
+ " test_dataset=test_dataset,\n",
+ " embed_dim=128,\n",
+ " lr=lr,\n",
+ " )\n",
+ " learning_rate_results.append({\"learning_rate\": lr, **metrics})\n",
+ "\n",
+ "learning_rate_df = pd.DataFrame(learning_rate_results)\n",
+ "learning_rate_df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "62b53d17",
+ "metadata": {},
+ "source": [
+ "## 7) Summary\n",
+ "\n",
+ "This section prints the final metric comparisons in a compact format, including accuracy, precision, recall, and F1."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 65,
+ "id": "df4b9dba",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Embedding Dimension Ablation\n",
+ "embed_dim=64 -> accuracy=0.4667, precision=0.5909, recall=0.3611, f1=0.4483\n",
+ "embed_dim=128 -> accuracy=0.5000, precision=0.6034, recall=0.4861, f1=0.5385\n",
+ "embed_dim=256 -> accuracy=0.4250, precision=0.5484, recall=0.2361, f1=0.3301\n",
+ "\n",
+ "Learning Rate Ablation\n",
+ "lr=1e-04 -> accuracy=0.6083, precision=0.6106, recall=0.9583, f1=0.7459\n",
+ "lr=1e-03 -> accuracy=0.5000, precision=0.6034, recall=0.4861, f1=0.5385\n",
+ "lr=5e-03 -> accuracy=0.4250, precision=0.5517, recall=0.2222, f1=0.3168\n",
+ "\n",
+ "Embedding Dimension Results\n",
+ " embed_dim accuracy precision recall f1\n",
+ " 64 0.466667 0.590909 0.361111 0.448276\n",
+ " 128 0.500000 0.603448 0.486111 0.538462\n",
+ " 256 0.425000 0.548387 0.236111 0.330097\n",
+ "\n",
+ "Learning Rate Results\n",
+ " learning_rate accuracy precision recall f1\n",
+ " 0.0001 0.608333 0.610619 0.958333 0.745946\n",
+ " 0.0010 0.500000 0.603448 0.486111 0.538462\n",
+ " 0.0050 0.425000 0.551724 0.222222 0.316832\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Embedding Dimension Ablation\")\n",
+ "for row in embed_dim_results:\n",
+ " print(\n",
+ " f\"embed_dim={row['embed_dim']} \"\n",
+ " f\"-> accuracy={row['accuracy']:.4f}, precision={row['precision']:.4f}, \"\n",
+ " f\"recall={row['recall']:.4f}, f1={row['f1']:.4f}\"\n",
+ " )\n",
+ "print()\n",
+ "\n",
+ "print(\"Learning Rate Ablation\")\n",
+ "for row in learning_rate_results:\n",
+ " print(\n",
+ " f\"lr={row['learning_rate']:.0e} \"\n",
+ " f\"-> accuracy={row['accuracy']:.4f}, precision={row['precision']:.4f}, \"\n",
+ " f\"recall={row['recall']:.4f}, f1={row['f1']:.4f}\"\n",
+ " )\n",
+ "print()\n",
+ "\n",
+ "print(\"Embedding Dimension Results\")\n",
+ "print(embed_dim_df.to_string(index=False))\n",
+ "print()\n",
+ "\n",
+ "print(\"Learning Rate Results\")\n",
+ "print(learning_rate_df.to_string(index=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9b65316e",
+ "metadata": {},
+ "source": [
+ "## 8) Lightweight Masked-Lab Demonstration\n",
+ "\n",
+ "This section provides a simplified demonstration of the **masked pretraining objective** used in the Labrador paper.\n",
+ "\n",
+ "In Labrador, the model is trained by masking one lab in a patient’s lab set and predicting:\n",
+ "\n",
+ "* the **lab identity (code)**, and\n",
+ "* the **lab value**\n",
+ "\n",
+ "This is similar to masked language modeling, where the model learns to infer missing information from context.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "### Example\n",
+ "\n",
+ "Given labs:\n",
+ "\n",
+ "* Codes: `[2, 7, 1, 9]`\n",
+ "* Values: `[0.8, 0.7, 0.2, 0.9]`\n",
+ "\n",
+ "Mask one position:\n",
+ "\n",
+ "* Codes: `[2, [MASK], 1, 9]`\n",
+ "* Values: `[0.8, 0.0, 0.2, 0.9]`\n",
+ "\n",
+ "The model predicts the missing code (`7`) and value (`0.7`) using the remaining labs.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "### What's implemented here\n",
+ "\n",
+ "* Mask one lab per sample\n",
+ "* Use `LabradorModel` encoder (embedding + transformer)\n",
+ "* Add small heads to predict:\n",
+ "\n",
+ " * masked code (cross-entropy)\n",
+ " * masked value (MSE)\n",
+ "* Report evaluation metrics (accuracy, precision, recall, F1) for masked code prediction and MSE for masked value prediction.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "### Scope note\n",
+ "\n",
+ "This is a **lightweight demo only**:\n",
+ "\n",
+ "* uses synthetic data\n",
+ "* small model and short training\n",
+ "* no transfer to downstream tasks\n",
+ "\n",
+ "It is meant to illustrate the **pretraining mechanism**, not reproduce the full paper results."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "id": "ca0ef6a1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "\n",
+ "def mask_one_lab(codes: torch.Tensor, values: torch.Tensor, mask_token_id: int):\n",
+ " \"\"\"Mask one random lab position per sample.\"\"\"\n",
+ " batch_size, seq_len = codes.shape\n",
+ " mask_idx = torch.randint(0, seq_len, (batch_size,))\n",
+ "\n",
+ " row_idx = torch.arange(batch_size)\n",
+ " target_code = codes[row_idx, mask_idx].clone()\n",
+ " target_value = values[row_idx, mask_idx].clone()\n",
+ "\n",
+ " masked_codes = codes.clone()\n",
+ " masked_values = values.clone()\n",
+ "\n",
+ " masked_codes[row_idx, mask_idx] = mask_token_id\n",
+ " masked_values[row_idx, mask_idx] = 0.0\n",
+ "\n",
+ " return masked_codes, masked_values, mask_idx, target_code, target_value\n",
+ "\n",
+ "\n",
+ "def encode_with_labrador(model: LabradorModel, codes: torch.Tensor, values: torch.Tensor):\n",
+ " \"\"\"Reuse LabradorModel encoder blocks to produce token representations.\"\"\"\n",
+ " code_emb = model.code_embedding(codes)\n",
+ " value_emb = model.value_projection(values.unsqueeze(-1))\n",
+ "\n",
+ " x = code_emb + value_emb\n",
+ " x = model.value_fusion(x)\n",
+ " x = model.value_act(x)\n",
+ " x = model.value_norm(x)\n",
+ "\n",
+ " # Keep all token positions active in this synthetic masked-lab demo.\n",
+ " token_mask = torch.ones_like(codes, dtype=torch.bool)\n",
+ " x = model.transformer(x, src_key_padding_mask=~token_mask)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "def compute_multiclass_metrics(\n",
+ " y_true: torch.Tensor,\n",
+ " y_pred: torch.Tensor,\n",
+ " num_classes: int,\n",
+ "):\n",
+ " \"\"\"Macro precision/recall/F1 for masked code prediction.\"\"\"\n",
+ " y_true = y_true.long()\n",
+ " y_pred = y_pred.long()\n",
+ "\n",
+ " acc = (y_true == y_pred).float().mean().item()\n",
+ "\n",
+ " precisions = []\n",
+ " recalls = []\n",
+ " f1s = []\n",
+ "\n",
+ " for cls in range(num_classes):\n",
+ " tp = ((y_pred == cls) & (y_true == cls)).sum().item()\n",
+ " fp = ((y_pred == cls) & (y_true != cls)).sum().item()\n",
+ " fn = ((y_pred != cls) & (y_true == cls)).sum().item()\n",
+ "\n",
+ " precision = tp / max(tp + fp, 1)\n",
+ " recall = tp / max(tp + fn, 1)\n",
+ " f1 = 0.0 if (precision + recall) == 0 else (2 * precision * recall) / (precision + recall)\n",
+ "\n",
+ " precisions.append(precision)\n",
+ " recalls.append(recall)\n",
+ " f1s.append(f1)\n",
+ "\n",
+ " return {\n",
+ " \"accuracy\": float(acc),\n",
+ " \"precision\": float(np.mean(precisions)),\n",
+ " \"recall\": float(np.mean(recalls)),\n",
+ " \"f1\": float(np.mean(f1s)),\n",
+ " }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 64,
+ "id": "6ca044b2",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " ce_loss mse_loss total_loss accuracy precision recall f1\n",
+ "0 1.920523 0.083962 2.004485 0.3125 0.089286 0.141876 0.104206"
+ ]
+ },
+ "execution_count": 64,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "set_seed(SEED)\n",
+ "\n",
+ "# Eval synthetic batch for masked prediction demo.\n",
+ "eval_codes, eval_values = generate_structured_fake_batch(batch_size=128, num_labs=NUM_LABS)\n",
+ "\n",
+ "# Reuse the same LabradorModel encoder blocks.\n",
+ "masked_model = LabradorModel(\n",
+ " dataset=train_dataset,\n",
+ " code_feature_key=\"lab_codes\",\n",
+ " value_feature_key=\"lab_values\",\n",
+ " embed_dim=64,\n",
+ " num_heads=2,\n",
+ " num_layers=1,\n",
+ ")\n",
+ "\n",
+ "# Introduce a dedicated mask token outside the normal code range [0, num_labs-1].\n",
+ "mask_token_id = masked_model.num_labs\n",
+ "\n",
+ "# Expand embedding table by one row so mask token is valid.\n",
+ "old_embedding = masked_model.code_embedding\n",
+ "expanded_embedding = nn.Embedding(masked_model.num_labs + 1, masked_model.embed_dim)\n",
+ "with torch.no_grad():\n",
+ " expanded_embedding.weight[: masked_model.num_labs].copy_(old_embedding.weight)\n",
+ " expanded_embedding.weight[mask_token_id].copy_(old_embedding.weight.mean(dim=0))\n",
+ "masked_model.code_embedding = expanded_embedding\n",
+ "\n",
+ "# Lightweight masked heads (not added to core model class).\n",
+ "mask_code_head = nn.Linear(masked_model.embed_dim, masked_model.num_labs)\n",
+ "mask_value_head = nn.Sequential(\n",
+ " nn.Linear(masked_model.embed_dim + masked_model.num_labs, masked_model.embed_dim),\n",
+ " nn.ReLU(),\n",
+ " nn.Linear(masked_model.embed_dim, 1),\n",
+ " nn.Sigmoid(),\n",
+ ")\n",
+ "\n",
+ "optimizer = torch.optim.Adam(\n",
+ " list(masked_model.parameters()) + list(mask_code_head.parameters()) + list(mask_value_head.parameters()),\n",
+ " lr=1e-3,\n",
+ ")\n",
+ "ce_loss_fn = nn.CrossEntropyLoss()\n",
+ "mse_loss_fn = nn.MSELoss()\n",
+ "\n",
+ "MASK_STEPS = 500\n",
+ "for _ in range(MASK_STEPS):\n",
+ " train_codes, train_values = generate_structured_fake_batch(batch_size=256, num_labs=NUM_LABS)\n",
+ " masked_codes, masked_values, mask_idx, target_code, target_value = mask_one_lab(\n",
+ " train_codes,\n",
+ " train_values,\n",
+ " mask_token_id,\n",
+ " )\n",
+ "\n",
+ " x = encode_with_labrador(masked_model, masked_codes, masked_values)\n",
+ "\n",
+ " code_logits = mask_code_head(x)\n",
+ " code_probs = torch.softmax(code_logits, dim=-1)\n",
+ " value_pred = mask_value_head(torch.cat([x, code_probs], dim=-1)).squeeze(-1)\n",
+ "\n",
+ " row_idx = torch.arange(masked_codes.size(0))\n",
+ " masked_code_logits = code_logits[row_idx, mask_idx]\n",
+ " masked_value_pred = value_pred[row_idx, mask_idx]\n",
+ "\n",
+ " loss_code = ce_loss_fn(masked_code_logits, target_code)\n",
+ " loss_value = mse_loss_fn(masked_value_pred, target_value)\n",
+ " loss = loss_code + loss_value\n",
+ "\n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ "# Evaluation on held-out synthetic batch\n",
+ "masked_codes, masked_values, mask_idx, target_code, target_value = mask_one_lab(\n",
+ " eval_codes,\n",
+ " eval_values,\n",
+ " mask_token_id,\n",
+ ")\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " x = encode_with_labrador(masked_model, masked_codes, masked_values)\n",
+ " code_logits = mask_code_head(x)\n",
+ " code_probs = torch.softmax(code_logits, dim=-1)\n",
+ " value_pred = mask_value_head(torch.cat([x, code_probs], dim=-1)).squeeze(-1)\n",
+ "\n",
+ "row_idx = torch.arange(masked_codes.size(0))\n",
+ "masked_code_logits = code_logits[row_idx, mask_idx]\n",
+ "masked_value_pred = value_pred[row_idx, mask_idx]\n",
+ "\n",
+ "loss_code_eval = ce_loss_fn(masked_code_logits, target_code)\n",
+ "loss_value_eval = mse_loss_fn(masked_value_pred, target_value)\n",
+ "loss_total_eval = loss_code_eval + loss_value_eval\n",
+ "\n",
+ "pred_code = torch.argmax(masked_code_logits, dim=1)\n",
+ "code_metrics = compute_multiclass_metrics(target_code, pred_code, num_classes=masked_model.num_labs)\n",
+ "\n",
+ "masked_demo_results = pd.DataFrame([\n",
+ " {\n",
+ " \"ce_loss\": float(loss_code_eval.item()),\n",
+ " \"mse_loss\": float(loss_value_eval.item()),\n",
+ " \"total_loss\": float(loss_total_eval.item()),\n",
+ " \"accuracy\": code_metrics[\"accuracy\"],\n",
+ " \"precision\": code_metrics[\"precision\"],\n",
+ " \"recall\": code_metrics[\"recall\"],\n",
+ " \"f1\": code_metrics[\"f1\"],\n",
+ " }\n",
+ "])\n",
+ "\n",
+ "masked_demo_results"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 5233b1726..885af2d98 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -9,6 +9,7 @@
from .embedding import EmbeddingModel
from .gamenet import GAMENet, GAMENetLayer
from .jamba_ehr import JambaEHR, JambaLayer
+from .labrador import LabradorModel
from .logistic_regression import LogisticRegression
from .gan import GAN
from .gnn import GAT, GCN
diff --git a/pyhealth/models/labrador.py b/pyhealth/models/labrador.py
new file mode 100644
index 000000000..961ad90d5
--- /dev/null
+++ b/pyhealth/models/labrador.py
@@ -0,0 +1,288 @@
+from typing import Dict, Optional, Tuple, cast
+
+import torch
+import torch.nn as nn
+
+from pyhealth.datasets import SampleDataset
+from pyhealth.models import BaseModel
+
+
+class LabradorModel(BaseModel):
+ """Transformer-based model for tabular lab code/value inputs.
+
+ The model consumes two aligned feature streams:
+ 1) categorical lab codes (token ids)
+ 2) continuous lab values
+
+ Architecture:
+ - code embedding
+ - value projection
+ - additive fusion -> linear -> ReLU -> LayerNorm
+ - Transformer encoder (no positional encoding)
+ - mean pooling over lab dimension
+ - MLP classifier head
+
+ Args:
+ dataset: The dataset used by PyHealth trainers.
+ code_feature_key: Input feature key containing lab code tokens.
+ value_feature_key: Input feature key containing lab values.
+ embed_dim: Hidden size for embeddings and transformer blocks.
+ num_heads: Number of attention heads.
+ num_layers: Number of transformer encoder layers.
+ dropout: Dropout used by the transformer encoder layer.
+ ff_hidden_dim: Feed-forward width inside each transformer layer.
+ classifier_hidden_dim: Hidden width of classifier MLP head.
+
+ Examples:
+ >>> from pyhealth.datasets import create_sample_dataset
+ >>> from pyhealth.models import LabradorModel
+ >>> samples = [
+ ... {
+ ... "patient_id": "p-0",
+ ... "visit_id": "v-0",
+ ... "lab_codes": ["lab-1", "lab-2"],
+ ... "lab_values": [0.2, 0.8],
+ ... "label": 1,
+ ... }
+ ... ]
+ >>> dataset = create_sample_dataset(
+ ... samples=samples,
+ ... input_schema={"lab_codes": "sequence", "lab_values": "tensor"},
+ ... output_schema={"label": "binary"},
+ ... dataset_name="labrador_demo",
+ ... )
+ >>> model = LabradorModel(
+ ... dataset=dataset,
+ ... code_feature_key="lab_codes",
+ ... value_feature_key="lab_values",
+ ... )
+ """
+
+ def __init__(
+ self,
+ dataset: SampleDataset,
+ code_feature_key: str,
+ value_feature_key: str,
+ embed_dim: int = 128,
+ num_heads: int = 2,
+ num_layers: int = 2,
+ dropout: float = 0.1,
+ ff_hidden_dim: Optional[int] = None,
+ classifier_hidden_dim: Optional[int] = None,
+ ):
+ super().__init__(dataset=dataset)
+
+        if len(self.label_keys) != 1:
+            raise ValueError("LabradorModel supports exactly one label key.")
+        self.label_key = self.label_keys[0]
+
+ if code_feature_key not in self.feature_keys:
+ raise ValueError(
+ f"code_feature_key='{code_feature_key}' not found in dataset input schema: "
+ f"{self.feature_keys}"
+ )
+ if value_feature_key not in self.feature_keys:
+ raise ValueError(
+ f"value_feature_key='{value_feature_key}' not found in dataset input schema: "
+ f"{self.feature_keys}"
+ )
+
+ self.code_feature_key = code_feature_key
+ self.value_feature_key = value_feature_key
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+
+ num_labs = self.dataset.input_processors[self.code_feature_key].size()
+ if num_labs is None:
+ raise ValueError(
+ "LabradorModel requires a tokenized categorical code feature with known vocabulary size. "
+ f"Feature '{self.code_feature_key}' returned size=None."
+ )
+ self.num_labs = int(num_labs)
+
+ ff_hidden_dim = ff_hidden_dim or embed_dim
+ classifier_hidden_dim = classifier_hidden_dim or embed_dim
+
+ self.code_embedding = nn.Embedding(self.num_labs, embed_dim)
+ self.value_projection = nn.Linear(1, embed_dim)
+ self.value_fusion = nn.Linear(embed_dim, embed_dim)
+ self.value_act = nn.ReLU()
+ self.value_norm = nn.LayerNorm(embed_dim)
+
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=embed_dim,
+ nhead=num_heads,
+ dim_feedforward=ff_hidden_dim,
+ dropout=dropout,
+ batch_first=True,
+ activation="relu",
+ )
+ self.transformer = nn.TransformerEncoder(
+ encoder_layer=encoder_layer,
+ num_layers=num_layers,
+ )
+
+ output_size = self.get_output_size()
+ self.classifier = nn.Sequential(
+ nn.Linear(embed_dim, classifier_hidden_dim),
+ nn.ReLU(),
+ nn.Linear(classifier_hidden_dim, output_size),
+ )
+
+ def _extract_value_and_mask(
+ self, feature_key: str, feature: torch.Tensor | Tuple[torch.Tensor, ...]
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Extract value tensor and optional mask from a processed feature.
+
+ Args:
+ feature_key: Name of the dataset feature being decoded.
+ feature: Processor output for one feature. This may be either a
+ tensor or a tuple containing tensors such as value and mask.
+
+ Returns:
+ A tuple ``(value, mask)`` where ``value`` is the feature tensor and
+ ``mask`` is the optional validity mask if provided by the processor.
+
+ Raises:
+ ValueError: If the processor schema does not contain a ``value``
+ field.
+ """
+ if isinstance(feature, torch.Tensor):
+ feature_tuple: Tuple[torch.Tensor, ...] = (feature,)
+ else:
+ feature_tuple = feature
+
+ schema = self.dataset.input_processors[feature_key].schema()
+
+ value = feature_tuple[schema.index("value")] if "value" in schema else None
+ if value is None:
+ raise ValueError(
+ f"Feature '{feature_key}' must contain 'value' in processor schema."
+ )
+
+        mask = feature_tuple[schema.index("mask")] if "mask" in schema else None
+        # Some processors append a trailing mask tensor that is not named in
+        # the schema; fall back to it when no explicit mask was found.
+        if len(feature_tuple) == len(schema) + 1 and mask is None:
+            mask = feature_tuple[-1]
+
+ return value, mask
+
+ @staticmethod
+ def _ensure_2d(tensor: torch.Tensor, name: str) -> torch.Tensor:
+ """Normalize tensor to ``[batch, num_labs]`` for aligned lab streams.
+
+ Args:
+ tensor: Input tensor representing lab codes, values, or masks.
+ name: Human-readable tensor name used in error messages.
+
+ Returns:
+ A 2D tensor of shape ``[batch, num_labs]``.
+
+ Raises:
+ ValueError: If the input tensor cannot be interpreted as a 2D lab
+ matrix.
+ """
+ if tensor.dim() == 2:
+ return tensor
+ if tensor.dim() == 3 and tensor.size(-1) == 1:
+ return tensor.squeeze(-1)
+ raise ValueError(
+ f"Expected {name} to have shape [batch, num_labs] "
+ f"(or [..., 1]), got {tuple(tensor.shape)}"
+ )
+
+ @staticmethod
+ def _mean_pool(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ """Apply masked mean pooling over the lab dimension.
+
+ Args:
+ x: Token representations of shape ``[batch, num_labs, embed_dim]``.
+ mask: Float mask of shape ``[batch, num_labs]`` indicating valid
+ lab positions.
+
+ Returns:
+ Pooled patient embeddings of shape ``[batch, embed_dim]``.
+ """
+ denom = mask.sum(dim=1, keepdim=True).clamp(min=1.0)
+ return (x * mask.unsqueeze(-1)).sum(dim=1) / denom
+
+ def forward(
+ self,
+ **kwargs: torch.Tensor | Tuple[torch.Tensor, ...],
+ ) -> Dict[str, torch.Tensor]:
+ """Forward propagation.
+
+ Args:
+ **kwargs: Keyword arguments containing the code feature,
+ value feature, optional label, and optional ``embed`` flag.
+ The two feature tensors must describe aligned lab sequences for
+ the same samples.
+
+ Returns:
+ A dictionary containing model outputs. This always includes:
+ - ``logit``: classification logits.
+ - ``y_prob``: predicted probabilities.
+ When labels are provided, the dictionary also includes:
+ - ``loss``: supervised task loss.
+ - ``y_true``: ground-truth labels.
+ When ``embed=True`` is passed, the dictionary also includes:
+ - ``embed``: pooled patient embedding.
+
+ Raises:
+ ValueError: If code and value features do not have matching shapes.
+ """
+ code_values, code_mask = self._extract_value_and_mask(
+ self.code_feature_key, kwargs[self.code_feature_key]
+ )
+ lab_values, value_mask = self._extract_value_and_mask(
+ self.value_feature_key, kwargs[self.value_feature_key]
+ )
+
+ codes = self._ensure_2d(code_values.to(self.device), "code feature").long()
+ values = self._ensure_2d(lab_values.to(self.device), "value feature").float()
+
+ if codes.shape != values.shape:
+ raise ValueError(
+ f"Code/value feature shapes must match, got codes={tuple(codes.shape)} "
+ f"and values={tuple(values.shape)}"
+ )
+
+ if code_mask is not None:
+ mask = self._ensure_2d(code_mask.to(self.device), "code mask").bool()
+ elif value_mask is not None:
+ mask = self._ensure_2d(value_mask.to(self.device), "value mask").bool()
+ else:
+ mask = codes != 0
+
+        # Avoid all-masked rows to keep transformer behavior stable. Clone
+        # first so we never mutate a mask tensor owned by the caller in place
+        # (``.to(device).bool()`` can return the input tensor unchanged).
+        invalid_rows = ~mask.any(dim=1)
+        if invalid_rows.any():
+            mask = mask.clone()
+            mask[invalid_rows, 0] = True
+
+ code_emb = self.code_embedding(codes)
+ value_emb = self.value_projection(values.unsqueeze(-1))
+
+ x = code_emb + value_emb
+ x = self.value_fusion(x)
+ x = self.value_act(x)
+ x = self.value_norm(x)
+
+ x = self.transformer(x, src_key_padding_mask=~mask)
+ patient_emb = self._mean_pool(x, mask.float())
+ logits = self.classifier(patient_emb)
+ y_prob = self.prepare_y_prob(logits)
+
+ results: Dict[str, torch.Tensor] = {
+ "logit": logits,
+ "y_prob": y_prob,
+ }
+
+ if self.label_key in kwargs:
+ y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
+ loss = self.get_loss_function()(logits, y_true)
+ results["loss"] = loss
+ results["y_true"] = y_true
+
+ if kwargs.get("embed", False):
+ results["embed"] = patient_emb
+
+ return results
diff --git a/tests/core/test_labrador.py b/tests/core/test_labrador.py
new file mode 100644
index 000000000..5986bf3f9
--- /dev/null
+++ b/tests/core/test_labrador.py
@@ -0,0 +1,125 @@
+import unittest
+import torch
+
+from pyhealth.datasets import create_sample_dataset, get_dataloader
+from pyhealth.models import LabradorModel
+
+
+class TestLabradorModel(unittest.TestCase):
+ """Minimal smoke test for LabradorModel."""
+
+ def setUp(self):
+ """Set up test data and model."""
+ # Create minimal synthetic samples with aligned lab codes and values
+ # Both lab_codes and lab_values represent the same 4 labs per sample
+ self.samples = [
+ {
+ "patient_id": "patient-0",
+ "visit_id": "visit-0",
+ "lab_codes": ["lab-1", "lab-2", "lab-3", "lab-4"],
+ "lab_values": [1.0, 2.5, 3.0, 4.5],
+ "label": 0,
+ },
+ {
+ "patient_id": "patient-1",
+ "visit_id": "visit-1",
+ "lab_codes": ["lab-1", "lab-2", "lab-3", "lab-4"],
+ "lab_values": [2.1, 1.8, 2.9, 3.5],
+ "label": 1,
+ },
+ ]
+
+ # Define input and output schemas
+ self.input_schema = {
+ "lab_codes": "sequence", # Categorical lab codes
+ "lab_values": "tensor", # Continuous lab values
+ }
+ self.output_schema = {"label": "binary"} # Binary classification
+
+ # Create dataset
+ self.dataset = create_sample_dataset(
+ samples=self.samples,
+ input_schema=self.input_schema,
+ output_schema=self.output_schema,
+ dataset_name="test_labrador",
+ )
+
+ # Create model
+ self.model = LabradorModel(
+ dataset=self.dataset,
+ code_feature_key="lab_codes",
+ value_feature_key="lab_values",
+ embed_dim=32,
+ num_heads=2,
+ num_layers=1,
+ )
+
+ def test_model_initialization(self):
+ """Test that LabradorModel initializes correctly."""
+ self.assertIsInstance(self.model, LabradorModel)
+ self.assertEqual(self.model.embed_dim, 32)
+ self.assertEqual(self.model.num_heads, 2)
+ self.assertEqual(self.model.num_layers, 1)
+ self.assertEqual(self.model.code_feature_key, "lab_codes")
+ self.assertEqual(self.model.value_feature_key, "lab_values")
+ self.assertEqual(self.model.label_key, "label")
+
+ def test_model_forward(self):
+ """Test that LabradorModel forward pass works correctly."""
+ # Create data loader
+ train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
+ data_batch = next(iter(train_loader))
+
+ # Forward pass
+ with torch.no_grad():
+ ret = self.model(**data_batch)
+
+ # Check output structure
+ self.assertIn("loss", ret)
+ self.assertIn("y_prob", ret)
+ self.assertIn("y_true", ret)
+ self.assertIn("logit", ret)
+
+ # Check tensor shapes
+ self.assertEqual(ret["y_prob"].shape[0], 2) # batch size
+ self.assertEqual(ret["y_true"].shape[0], 2) # batch size
+ self.assertEqual(ret["logit"].shape[0], 2) # batch size
+
+ # Check that loss is a scalar
+ self.assertEqual(ret["loss"].dim(), 0)
+
+ def test_model_backward(self):
+ """Test that LabradorModel backward pass works correctly."""
+ # Create data loader
+ train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
+ data_batch = next(iter(train_loader))
+
+ # Forward pass
+ ret = self.model(**data_batch)
+
+ # Backward pass
+ ret["loss"].backward()
+
+ # Check that at least one parameter has gradients
+ has_gradient = any(
+ param.requires_grad and param.grad is not None
+ for param in self.model.parameters()
+ )
+ self.assertTrue(has_gradient, "No parameters have gradients after backward pass")
+
+ def test_model_with_embedding(self):
+ """Test that LabradorModel returns embeddings when requested."""
+ train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True)
+ data_batch = next(iter(train_loader))
+ data_batch["embed"] = True
+
+ with torch.no_grad():
+ ret = self.model(**data_batch)
+
+ self.assertIn("embed", ret)
+ self.assertEqual(ret["embed"].shape[0], 2) # batch size
+ self.assertEqual(ret["embed"].shape[1], 32) # embed_dim
+
+
+if __name__ == "__main__":
+ unittest.main()