From 770363dc4492cf545137f59681a62f18ad330c20 Mon Sep 17 00:00:00 2001 From: Matthew Ardi Date: Sun, 5 Apr 2026 15:27:36 -0500 Subject: [PATCH 1/3] HiCu --- docs/api/models.rst | 1 + docs/api/models/pyhealth.models.HiCu.rst | 24 ++ docs/api/tasks.rst | 1 + .../pyhealth.tasks.MIMIC4ICD10Coding.rst | 7 + examples/mimic4_icd10_coding_hicu.py | 209 ++++++++++ pyhealth/models/__init__.py | 1 + pyhealth/models/hicu.py | 393 ++++++++++++++++++ pyhealth/tasks/__init__.py | 2 +- pyhealth/tasks/medical_coding.py | 141 +++---- tests/core/test_hicu.py | 181 ++++++++ tests/core/test_mimic4_icd10_coding.py | 169 ++++++++ 11 files changed, 1045 insertions(+), 84 deletions(-) create mode 100644 docs/api/models/pyhealth.models.HiCu.rst create mode 100644 docs/api/tasks/pyhealth.tasks.MIMIC4ICD10Coding.rst create mode 100644 examples/mimic4_icd10_coding_hicu.py create mode 100644 pyhealth/models/hicu.py create mode 100644 tests/core/test_hicu.py create mode 100644 tests/core/test_mimic4_icd10_coding.py diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..6877cf0ac 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -194,6 +194,7 @@ API Reference models/pyhealth.models.ConCare models/pyhealth.models.Agent models/pyhealth.models.GRASP + models/pyhealth.models.HiCu models/pyhealth.models.MedLink models/pyhealth.models.TCN models/pyhealth.models.TFMTokenizer diff --git a/docs/api/models/pyhealth.models.HiCu.rst b/docs/api/models/pyhealth.models.HiCu.rst new file mode 100644 index 000000000..9891dd93c --- /dev/null +++ b/docs/api/models/pyhealth.models.HiCu.rst @@ -0,0 +1,24 @@ +pyhealth.models.HiCu +=================================== + +HiCu model for automated ICD coding with hierarchical curriculum learning. + +.. autoclass:: pyhealth.models.hicu.AsymmetricLoss + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.hicu.MultiResCNNEncoder + :members: + :undoc-members: + :show-inheritance: + +.. 
autoclass:: pyhealth.models.hicu.HierarchicalDecoder + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.hicu.HiCu + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..c9c6f074c 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -208,6 +208,7 @@ Available Tasks Base Task In-Hospital Mortality (MIMIC-IV) MIMIC-III ICD-9 Coding + MIMIC-IV ICD-10 Coding Cardiology Detection COVID-19 CXR Classification DKA Prediction (MIMIC-IV) diff --git a/docs/api/tasks/pyhealth.tasks.MIMIC4ICD10Coding.rst b/docs/api/tasks/pyhealth.tasks.MIMIC4ICD10Coding.rst new file mode 100644 index 000000000..de7f9e606 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.MIMIC4ICD10Coding.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.MIMIC4ICD10Coding +=========================================== + +.. autoclass:: pyhealth.tasks.medical_coding.MIMIC4ICD10Coding + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/mimic4_icd10_coding_hicu.py b/examples/mimic4_icd10_coding_hicu.py new file mode 100644 index 000000000..da37107b7 --- /dev/null +++ b/examples/mimic4_icd10_coding_hicu.py @@ -0,0 +1,209 @@ +"""HiCu ICD-10 coding example with training experiments. + +Runs in synthetic mode by default, or on real MIMIC-IV data with --data-dir. 
+ +Usage: + python examples/mimic4_icd10_coding_hicu.py + python examples/mimic4_icd10_coding_hicu.py --data-dir /path/to/mimic-iv +""" + +import argparse +import torch +from pyhealth.datasets import MIMIC4Dataset, create_sample_dataset, get_dataloader +from pyhealth.models.hicu import HiCu +from pyhealth.tasks import MIMIC4ICD10Coding +from pyhealth.trainer import Trainer + + +def create_synthetic_dataset(): + """Create a small synthetic ICD-10 multilabel dataset.""" + samples = [ + { + "patient_id": "p0", "visit_id": "v0", + "text": ["patient", "admitted", "with", "type", "two", "diabetes", "and", "hypertension"], + "icd_codes": ["E11.321", "I10", "J44.1"], + }, + { + "patient_id": "p1", "visit_id": "v1", + "text": ["chest", "pain", "shortness", "of", "breath", "elevated", "troponin"], + "icd_codes": ["I21.09", "I11.0", "I10"], + }, + { + "patient_id": "p2", "visit_id": "v2", + "text": ["abdominal", "pain", "nausea", "vomiting", "gastroesophageal", "reflux"], + "icd_codes": ["K21.0", "E11.65"], + }, + { + "patient_id": "p3", "visit_id": "v3", + "text": ["chronic", "obstructive", "pulmonary", "disease", "exacerbation", "with", "respiratory", "failure"], + "icd_codes": ["J44.1", "E11.321", "I10"], + }, + { + "patient_id": "p4", "visit_id": "v4", + "text": ["heart", "failure", "with", "reduced", "ejection", "fraction", "diuretic", "therapy"], + "icd_codes": ["I11.0", "I21.09", "K21.0"], + }, + ] + return create_sample_dataset( + samples=samples, + input_schema={"text": "sequence"}, + output_schema={"icd_codes": "multilabel"}, + dataset_name="mimic4_icd10_synthetic", + ) + + +def load_mimic4_dataset(data_dir: str, dev: bool = False): + """Load MIMIC-IV data and apply the ICD-10 coding task.""" + ds = MIMIC4Dataset( + ehr_root=data_dir, + note_root=data_dir, + ehr_tables=["diagnoses_icd"], + note_tables=["discharge"], + dev=dev, + ) + task = MIMIC4ICD10Coding() + return ds.set_task(task) + + +def train_with_curriculum(model, train_loader, depth_epochs, device="cpu") 
-> float: + """Train through progressively finer hierarchy depths, returning final loss.""" + final_loss = 0.0 + for depth in sorted(depth_epochs.keys()): + model.set_depth(depth) + epochs = depth_epochs[depth] + print(f" Depth {depth} ({model.depth_sizes[depth]} codes): training for {epochs} epochs...") + + trainer = Trainer(model=model, device=device, enable_logging=False) + trainer.train( + train_dataloader=train_loader, + epochs=epochs, + optimizer_class=torch.optim.Adam, + optimizer_params={"lr": 1e-3}, + ) + + model.eval() + with torch.no_grad(): + batch = next(iter(train_loader)) + ret = model(**{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}) + final_loss = ret["loss"].item() + model.train() + print(f" -> Loss at depth {depth}: {final_loss:.4f}") + + return final_loss + + +def train_flat(model, train_loader, epochs, device="cpu") -> float: + """Train at the finest depth only (no curriculum), returning final loss.""" + model.set_depth(2) + print(f" Flat training at depth 2 ({model.depth_sizes[2]} codes): {epochs} epochs...") + + trainer = Trainer(model=model, device=device, enable_logging=False) + trainer.train( + train_dataloader=train_loader, + epochs=epochs, + optimizer_class=torch.optim.Adam, + optimizer_params={"lr": 1e-3}, + ) + + model.eval() + with torch.no_grad(): + batch = next(iter(train_loader)) + ret = model(**{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}) + loss = ret["loss"].item() + model.train() + print(f" -> Final loss: {loss:.4f}") + return loss + + +def main() -> None: + parser = argparse.ArgumentParser(description="HiCu ICD-10 coding example") + parser.add_argument("--data-dir", type=str, default=None, + help="Path to MIMIC-IV data directory (hosp/, note/ subdirs). 
" + "If omitted, uses synthetic data.") + parser.add_argument("--dev", action="store_true", + help="Use dev mode (limit to 1000 patients) for faster iteration.") + parser.add_argument("--epochs-d0", type=int, default=3, help="Epochs at depth 0 (chapters)") + parser.add_argument("--epochs-d1", type=int, default=5, help="Epochs at depth 1 (categories)") + parser.add_argument("--epochs-d2", type=int, default=10, help="Epochs at depth 2 (full codes)") + parser.add_argument("--batch-size", type=int, default=4, help="Training batch size") + parser.add_argument("--num-filter-maps", type=int, default=50, help="CNN filter maps") + parser.add_argument("--embedding-dim", type=int, default=100, help="Word embedding dimension") + args = parser.parse_args() + + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + print(f"Using device: {device}") + + # --- Load dataset --- + if args.data_dir: + print(f"\nLoading real MIMIC-IV data from {args.data_dir}...") + dataset = load_mimic4_dataset(args.data_dir, dev=args.dev) + print(f"Loaded {len(dataset)} samples") + # Use larger defaults for real data + if args.num_filter_maps == 50 and args.embedding_dim == 100: + print("Using default hyperparameters (num_filter_maps=50, embedding_dim=100)") + else: + print("\nUsing synthetic dataset (pass --data-dir for real MIMIC-IV data)") + dataset = create_synthetic_dataset() + # Override to smaller dims for synthetic data + args.num_filter_maps = 16 + args.embedding_dim = 32 + + train_loader = get_dataloader(dataset, batch_size=args.batch_size, shuffle=True) + + base_kwargs = dict( + num_filter_maps=args.num_filter_maps, + embedding_dim=args.embedding_dim, + kernel_sizes=[3, 5, 9], + ) + depth_epochs = {0: args.epochs_d0, 1: args.epochs_d1, 2: args.epochs_d2} + total_flat_epochs = sum(depth_epochs.values()) + + results = {} + + # Experiment 1: Curriculum + ASL (baseline) + print("\n=== Experiment 1: Curriculum + ASL 
(baseline) ===") + model1 = HiCu(dataset, **base_kwargs) + print(f"Hierarchy depths: {model1.depth_sizes}") + results["curriculum+ASL"] = train_with_curriculum(model1, train_loader, depth_epochs, device) + + # Experiment 2: Flat training + ASL + print("\n=== Experiment 2: Flat (no curriculum) + ASL ===") + model2 = HiCu(dataset, **base_kwargs) + results["flat+ASL"] = train_flat(model2, train_loader, total_flat_epochs, device) + + # Experiment 3: Curriculum + BCE (no ASL) + print("\n=== Experiment 3: Curriculum + BCE (no ASL) ===") + model3 = HiCu(dataset, **base_kwargs, asl_gamma_neg=0.0, asl_gamma_pos=0.0, asl_clip=0.0) + results["curriculum+BCE"] = train_with_curriculum(model3, train_loader, depth_epochs, device) + + # Experiment 4: Curriculum + ASL + more filters + more_filters = args.num_filter_maps * 2 + print(f"\n=== Experiment 4: Curriculum + ASL + more filters ({more_filters}) ===") + model4 = HiCu(dataset, num_filter_maps=more_filters, embedding_dim=args.embedding_dim, kernel_sizes=[3, 5, 9]) + results[f"curriculum+ASL+filters{more_filters}"] = train_with_curriculum( + model4, train_loader, depth_epochs, device + ) + + # Summary + print("\n" + "=" * 60) + print("EXPERIMENT RESULTS SUMMARY") + print("=" * 60) + print(f"{'Configuration':<35} {'Final Loss':>12}") + print("-" * 60) + for config, loss in results.items(): + print(f"{config:<35} {loss:>12.4f}") + print("=" * 60) + if not args.data_dir: + print( + "\nNote: These results are on synthetic data." + "\nAbsolute values are not meaningful; the purpose is to demonstrate" + "\nthat all code paths execute correctly." 
+ ) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..481b79c03 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -15,6 +15,7 @@ from .graph_torchvision_model import Graph_TorchvisionModel from .graphcare import GraphCare from .grasp import GRASP, GRASPLayer +from .hicu import HiCu, AsymmetricLoss, MultiResCNNEncoder, HierarchicalDecoder from .medlink import MedLink from .micron import MICRON, MICRONLayer from .mlp import MLP diff --git a/pyhealth/models/hicu.py b/pyhealth/models/hicu.py new file mode 100644 index 000000000..f18977538 --- /dev/null +++ b/pyhealth/models/hicu.py @@ -0,0 +1,393 @@ +from typing import Dict, List, Optional + +import torch +import torch.nn as nn + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel + +ICD10_CHAPTER_MAP: Dict[str, str] = { + "A": "A00-B99", + "B": "A00-B99", + "C": "C00-D49", + "D": "C00-D49", # D50-D89 is a separate chapter; handled below + "E": "E00-E89", + "F": "F01-F99", + "G": "G00-G99", + "H": "H00-H59", # H60-H95 is a separate chapter; handled below + "I": "I00-I99", + "J": "J00-J99", + "K": "K00-K95", + "L": "L00-L99", + "M": "M00-M99", + "N": "N00-N99", + "O": "O00-O9A", + "P": "P00-P96", + "Q": "Q00-Q99", + "R": "R00-R99", + "S": "S00-T88", + "T": "S00-T88", + "V": "V00-Y99", + "W": "V00-Y99", + "X": "V00-Y99", + "Y": "V00-Y99", + "Z": "Z00-Z99", + "U": "U00-U85", +} + +# D and H each span two chapters, split by numeric part. +_D_SPLIT = 50 +_H_SPLIT = 60 + + +def _get_icd10_chapter(code: str) -> str: + """Map an ICD-10-CM code to its chapter range (e.g. 
"E11.321" -> "E00-E89").""" + if not code or not code[0].isalpha(): + raise ValueError(f"Invalid ICD-10 code: {code}") + + first = code[0].upper() + if first == "D": + numeric = int(code[1:3]) if len(code) >= 3 and code[1:3].isdigit() else 0 + return "D50-D89" if numeric >= _D_SPLIT else "C00-D49" + if first == "H": + numeric = int(code[1:3]) if len(code) >= 3 and code[1:3].isdigit() else 0 + return "H60-H95" if numeric >= _H_SPLIT else "H00-H59" + + chapter = ICD10_CHAPTER_MAP.get(first) + if chapter is None: + raise ValueError(f"Unknown ICD-10 chapter for code: {code}") + return chapter + + +def _get_icd10_category(code: str) -> str: + """First 3 characters of the code (e.g. "E11.321" -> "E11").""" + return code.replace(".", "")[:3].upper() + + +def build_icd10_hierarchy(code_list: List[str]) -> Dict: + """Build a 3-level ICD-10-CM hierarchy (chapter -> category -> full code). + + Args: + code_list: ICD-10-CM codes from the dataset label vocabulary. + + Returns: + Dict with depth_to_codes, code_to_index, parent_to_children, + and child_to_parent mappings for the decoder. 
+ """ + if not code_list: + raise ValueError("code_list must not be empty") + + full_codes = sorted(set(code_list)) + chapters = sorted({_get_icd10_chapter(c) for c in full_codes}) + categories = sorted({_get_icd10_category(c) for c in full_codes}) + + depth_to_codes: Dict[int, List[str]] = { + 0: chapters, + 1: categories, + 2: full_codes, + } + code_to_index: Dict[int, Dict[str, int]] = { + d: {code: idx for idx, code in enumerate(codes)} + for d, codes in depth_to_codes.items() + } + + parent_to_children: Dict[int, Dict[int, List[int]]] = {0: {}, 1: {}} + child_to_parent: Dict[int, Dict[int, int]] = {1: {}, 2: {}} + + for cat in categories: + chapter = _get_icd10_chapter(cat) + p_idx = code_to_index[0][chapter] + c_idx = code_to_index[1][cat] + parent_to_children[0].setdefault(p_idx, []).append(c_idx) + child_to_parent[1][c_idx] = p_idx + + for fc in full_codes: + cat = _get_icd10_category(fc) + p_idx = code_to_index[1][cat] + c_idx = code_to_index[2][fc] + parent_to_children[1].setdefault(p_idx, []).append(c_idx) + child_to_parent[2][c_idx] = p_idx + + return { + "depth_to_codes": depth_to_codes, + "code_to_index": code_to_index, + "parent_to_children": parent_to_children, + "child_to_parent": child_to_parent, + } + + +class AsymmetricLoss(nn.Module): + """Asymmetric focal loss for sparse multi-label classification (Ben-Baruch et al., 2020). + + Args: + gamma_neg: Focusing parameter for negatives. + gamma_pos: Focusing parameter for positives. + clip: Probability clipping threshold for negatives. 
+ """ + + def __init__( + self, + gamma_neg: float = 4.0, + gamma_pos: float = 1.0, + clip: float = 0.05, + ): + super().__init__() + self.gamma_neg = gamma_neg + self.gamma_pos = gamma_pos + self.clip = clip + + def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + p = torch.sigmoid(logits) + + loss_pos = targets * torch.clamp( + (1 - p).pow(self.gamma_pos) * torch.log(p + 1e-8), + min=-100, + ) + + p_neg = (p - self.clip).clamp(min=1e-8) + loss_neg = (1 - targets) * torch.clamp( + p_neg.pow(self.gamma_neg) * torch.log(1 - p_neg + 1e-8), + min=-100, + ) + + return -(loss_pos + loss_neg).mean() + + +class _ResidualConvBlock(nn.Module): + """Single convolutional channel with a residual connection.""" + + def __init__(self, in_channels: int, out_channels: int, kernel_size: int): + super().__init__() + padding = kernel_size // 2 + self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding) + self.bn1 = nn.BatchNorm1d(out_channels) + self.relu = nn.ReLU() + self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding) + self.bn2 = nn.BatchNorm1d(out_channels) + + self.downsample = None + if in_channels != out_channels: + self.downsample = nn.Sequential( + nn.Conv1d(in_channels, out_channels, kernel_size=1), + nn.BatchNorm1d(out_channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + out = self.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + if self.downsample is not None: + residual = self.downsample(residual) + return self.relu(out + residual) + + +class MultiResCNNEncoder(nn.Module): + """Parallel Conv1d branches with different kernel sizes, merged via 1x1 conv. + + Args: + input_dim: Input channels (embedding dimension). + num_filter_maps: Output channels after merge. + kernel_sizes: Kernel sizes for parallel branches. 
+ """ + + def __init__( + self, + input_dim: int, + num_filter_maps: int, + kernel_sizes: Optional[List[int]] = None, + ): + super().__init__() + if kernel_sizes is None: + kernel_sizes = [3, 5, 9, 15, 19, 25] + self.branches = nn.ModuleList( + [_ResidualConvBlock(input_dim, num_filter_maps, ks) for ks in kernel_sizes] + ) + self.merge = nn.Conv1d( + num_filter_maps * len(kernel_sizes), num_filter_maps, kernel_size=1 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + branch_outputs = [branch(x) for branch in self.branches] + concatenated = torch.cat(branch_outputs, dim=1) + return self.merge(concatenated) + + +class HierarchicalDecoder(nn.Module): + """Per-label attention decoder with curriculum weight transfer between depths. + + Args: + num_filter_maps: Encoder output channels. + depth_sizes: Number of codes at each depth. + child_to_parent: Child-to-parent index mapping per depth. + """ + + def __init__( + self, + num_filter_maps: int, + depth_sizes: List[int], + child_to_parent: Dict[int, Dict[int, int]], + ): + super().__init__() + self.num_filter_maps = num_filter_maps + self.depth_sizes = depth_sizes + self.child_to_parent = child_to_parent + self.current_depth = len(depth_sizes) - 1 # default: finest + + self.attention = nn.ModuleList() + self.classifiers = nn.ModuleList() + for size in depth_sizes: + self.attention.append(nn.Linear(num_filter_maps, size)) + self.classifiers.append(nn.Linear(num_filter_maps, size)) + + def set_depth(self, depth: int) -> None: + """Switch active depth and copy parent weights to child positions.""" + if depth < 0 or depth >= len(self.depth_sizes): + raise ValueError( + f"depth must be in [0, {len(self.depth_sizes) - 1}], got {depth}" + ) + self.current_depth = depth + + if depth == 0: + return # Nothing to transfer at the coarsest level. 
+ + c2p = self.child_to_parent.get(depth, {}) + if not c2p: + return + + parent_attn = self.attention[depth - 1] + parent_cls = self.classifiers[depth - 1] + child_attn = self.attention[depth] + child_cls = self.classifiers[depth] + + with torch.no_grad(): + for child_idx, parent_idx in c2p.items(): + child_attn.weight.data[child_idx] = parent_attn.weight.data[parent_idx] + child_attn.bias.data[child_idx] = parent_attn.bias.data[parent_idx] + child_cls.weight.data[child_idx] = parent_cls.weight.data[parent_idx] + child_cls.bias.data[child_idx] = parent_cls.bias.data[parent_idx] + + def forward(self, encoded: torch.Tensor) -> torch.Tensor: + d = self.current_depth + + # Per-label attention + attn_scores = self.attention[d](encoded.transpose(1, 2)) # (B, S, C) + attn_weights = torch.softmax(attn_scores, dim=1) + context = attn_weights.transpose(1, 2) @ encoded.transpose(1, 2) # (B, C, F) + + # Per-label classification via element-wise multiply + sum + logits = (context * self.classifiers[d].weight).sum(dim=2) + self.classifiers[d].bias + return logits + + +class HiCu(BaseModel): + """HiCu: Hierarchical Curriculum Learning for ICD coding. + + MultiResCNN encoder + per-label attention decoder with 3-level ICD-10 + hierarchy (chapter -> category -> full code). Call set_depth() between + training stages to transfer weights from coarse to fine codes. + + Paper: Ren et al., MLHC 2022. https://arxiv.org/abs/2208.02301 + + Args: + dataset: SampleDataset with multilabel ICD-10-CM output. + num_filter_maps: CNN output channels. + embedding_dim: Word embedding dimension. + kernel_sizes: Kernel sizes for the multi-resolution CNN. + asl_gamma_neg: ASL focusing parameter for negatives. + asl_gamma_pos: ASL focusing parameter for positives. + asl_clip: ASL probability clipping threshold. 
+ """ + + def __init__( + self, + dataset: SampleDataset, + num_filter_maps: int = 50, + embedding_dim: int = 100, + kernel_sizes: Optional[List[int]] = None, + asl_gamma_neg: float = 4.0, + asl_gamma_pos: float = 1.0, + asl_clip: float = 0.05, + **kwargs, + ): + super(HiCu, self).__init__(dataset=dataset) + + if kernel_sizes is None: + kernel_sizes = [3, 5, 9, 15, 19, 25] + + assert len(self.label_keys) == 1, "HiCu supports exactly one label key" + self.label_key = self.label_keys[0] + assert len(self.feature_keys) == 1, "HiCu expects exactly one text feature" + self.text_key = self.feature_keys[0] + + self.num_filter_maps = num_filter_maps + self.embedding_dim = embedding_dim + self.kernel_sizes = kernel_sizes + + label_processor = self.dataset.output_processors[self.label_key] + label_vocab = list(label_processor.label_vocab.keys()) + self.hierarchy = build_icd10_hierarchy(label_vocab) + + depth_sizes = [len(self.hierarchy["depth_to_codes"][d]) for d in range(3)] + self.depth_sizes = depth_sizes + + input_processor = self.dataset.input_processors[self.text_key] + vocab_size = len(input_processor.code_vocab) + self.word_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0) + + self.encoder = MultiResCNNEncoder(embedding_dim, num_filter_maps, kernel_sizes) + self.decoder = HierarchicalDecoder( + num_filter_maps, depth_sizes, self.hierarchy["child_to_parent"] + ) + self.asl_loss = AsymmetricLoss(asl_gamma_neg, asl_gamma_pos, asl_clip) + + self._build_label_mappings() + self.current_depth = 2 + + def _build_label_mappings(self) -> None: + hierarchy = self.hierarchy + full_codes = hierarchy["depth_to_codes"][2] + full_idx = hierarchy["code_to_index"][2] + + for d in range(2): + n_full = len(full_codes) + n_depth = len(hierarchy["depth_to_codes"][d]) + mapping = torch.zeros(n_full, n_depth) + + for fc in full_codes: + fi = full_idx[fc] + if d == 0: + ancestor = _get_icd10_chapter(fc) + else: + ancestor = _get_icd10_category(fc) + ai = 
hierarchy["code_to_index"][d][ancestor] + mapping[fi, ai] = 1.0 + + self.register_buffer(f"_label_map_{d}", mapping) + + def _remap_labels(self, y_true: torch.Tensor, depth: int) -> torch.Tensor: + if depth == 2: + return y_true + mapping = getattr(self, f"_label_map_{depth}") + return (y_true @ mapping).clamp(max=1.0) + + def set_depth(self, depth: int) -> None: + """Switch hierarchy depth (0=chapters, 1=categories, 2=full codes) and transfer weights.""" + self.current_depth = depth + self.decoder.set_depth(depth) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + text = kwargs[self.text_key] + if isinstance(text, tuple): + text = text[0] + text = text.to(self.device) + + embedded = self.word_embedding(text) + encoded = self.encoder(embedded.permute(0, 2, 1)) + logits = self.decoder(encoded) + + y_true_full = kwargs[self.label_key].to(self.device).float() + y_true = self._remap_labels(y_true_full, self.current_depth) + loss = self.asl_loss(logits, y_true) + y_prob = torch.sigmoid(logits) + + return {"loss": loss, "y_prob": y_prob, "y_true": y_true, "logit": logits} diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..cc534a3c5 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -29,7 +29,7 @@ LengthOfStayPredictionOMOP, ) from .length_of_stay_stagenet_mimic4 import LengthOfStayStageNetMIMIC4 -from .medical_coding import MIMIC3ICD9Coding +from .medical_coding import MIMIC3ICD9Coding, MIMIC4ICD10Coding from .medical_transcriptions_classification import MedicalTranscriptionsClassification from .mortality_prediction import ( MortalityPredictionEICU, diff --git a/pyhealth/tasks/medical_coding.py b/pyhealth/tasks/medical_coding.py index 739c674d1..1c5c41ff8 100644 --- a/pyhealth/tasks/medical_coding.py +++ b/pyhealth/tasks/medical_coding.py @@ -6,7 +6,7 @@ from dataclasses import field from datetime import datetime from typing import Dict, List, Union, Type -from pyhealth.processors import 
TextProcessor, MultiLabelProcessor +from pyhealth.processors import TextProcessor, MultiLabelProcessor, SequenceProcessor import polars as pl from pyhealth.data.data import Patient @@ -139,49 +139,74 @@ def __call__(self, patient: Patient) -> List[Dict]: # return [{"text": text, "icd_codes": list(icd_codes)}] -# @dataclass(frozen=True) -# class MIMIC4ICD10Coding(TaskTemplate): -# """Medical coding task for MIMIC-IV using ICD-10 codes. -# This task uses discharge notes to predict ICD-10 codes for a patient. +MAX_TOKENS = 4000 -# Args: -# task_name: Name of the task -# input_schema: Definition of the input data schema -# output_schema: Definition of the output data schema -# """ -# task_name: str = "mimic4_icd10_coding" -# input_schema: Dict[str, str] = field(default_factory=lambda: {"text": "str"}) -# output_schema: Dict[str, str] = field(default_factory=lambda: {"icd_codes": "List[str]"}) -# def __call__(self, patient: Patient) -> List[Dict]: -# """Process a patient and extract the discharge notes and ICD-9 codes.""" -# text = "" -# icd_codes = set() +def _tokenize_clinical_text(text: str) -> List[str]: + """Lowercase, split on whitespace, and truncate to MAX_TOKENS.""" + return text.lower().split()[:MAX_TOKENS] -# for event in patient.events: -# event_type = event.type.lower() if isinstance(event.type, str) else "" -# # Look for "value" instead of "code" for clinical notes -# if event_type == "clinical_note": -# if "value" in event.attr_dict: -# text += event.attr_dict["value"] +class MIMIC4ICD10Coding(BaseTask): + """MIMIC-IV ICD-10 coding task. 
Filters to icd_version=10 only.""" -# vocabulary = event.attr_dict.get("vocabulary", "").upper() -# if vocabulary == "ICD10CM": -# if event_type == "diagnoses_icd" or event_type == "procedures_icd": -# if "code" in event.attr_dict: -# icd_codes.add(event.attr_dict["code"]) + task_name: str = "mimic4_icd10_coding" + input_schema: Dict[str, Union[str, Type]] = {"text": SequenceProcessor} + output_schema: Dict[str, Union[str, Type]] = {"icd_codes": MultiLabelProcessor} -# if text == "" or len(icd_codes) < 1: -# return [] + def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: + """Keep only patients who have at least one discharge note.""" + filtered_df = df.filter( + pl.col("patient_id").is_in( + df.filter(pl.col("event_type") == "discharge") + .select("patient_id") + .unique() + .collect() + .to_series() + ) + ) + return filtered_df -# return [{"text": text, "icd_codes": list(icd_codes)}] + def __call__(self, patient: Patient) -> List[Dict]: + samples = [] + admissions = patient.get_events(event_type="admissions") + + for admission in admissions: + diagnoses = patient.get_events( + event_type="diagnoses_icd", + filters=[("hadm_id", "==", admission.hadm_id)], + ) + icd_codes = [ + e.icd_code + for e in diagnoses + if str(getattr(e, "icd_version", "")) == "10" + ] + + notes = patient.get_events( + event_type="discharge", + filters=[("hadm_id", "==", admission.hadm_id)], + ) + raw_text = " ".join(note.text for note in notes if hasattr(note, "text")) + tokens = _tokenize_clinical_text(raw_text) + + if not tokens or len(icd_codes) < 1: + continue + + samples.append( + { + "patient_id": patient.patient_id, + "text": tokens, + "icd_codes": list(set(icd_codes)), + } + ) + + return samples def main(): - # Test case for MIMIC4ICD9Coding and MIMIC3 - from pyhealth.datasets import MIMIC3Dataset, MIMIC4Dataset + """Quick smoke test for medical coding tasks.""" + from pyhealth.datasets import MIMIC3Dataset root = "/srv/local/data/MIMIC-III/mimic-iii-clinical-database-1.4" 
print("Testing MIMIC3ICD9Coding task...") @@ -193,61 +218,11 @@ def main(): dev=True, ) mimic3_coding = MIMIC3ICD9Coding() - # print(len(mimic3_coding.samples)) samples = dataset.set_task(mimic3_coding) - # Print sample information print(f"Total samples generated: {len(samples)}") if len(samples) > 0: - print("First sample:") print(f" - Text length: {len(samples[0]['text'])} characters") print(f" - Number of ICD codes: {len(samples[0]['icd_codes'])}") - if len(samples[0]["icd_codes"]) > 0: - print( - f" - Sample ICD codes: {samples[0]['icd_codes'][:5] if len(samples[0]['icd_codes']) > 5 else samples[0]['icd_codes']}" - ) - - # Initialize the dataset with dev mode enabled - print("Testing MIMIC4ICD9Coding task...") - dataset = MIMIC4Dataset( - root="/srv/local/data/MIMIC-IV/2.0/hosp", - tables=["diagnoses_icd", "procedures_icd"], - note_root="/srv/local/data/MIMIC-IV/2.0/note", - dev=True, - ) - # Create the task instance - mimic4_coding = MIMIC4ICD9Coding() - - # Generate samples - samples = dataset.set_task(mimic4_coding) - - # Print sample information - print(f"Total samples generated: {len(samples)}") - if len(samples) > 0: - print("First sample:") - print(f" - Text length: {len(samples[0]['text'])} characters") - print(f" - Number of ICD codes: {len(samples[0]['icd_codes'])}") - if len(samples[0]["icd_codes"]) > 0: - print( - f" - Sample ICD codes: {samples[0]['icd_codes'][:5] if len(samples[0]['icd_codes']) > 5 else samples[0]['icd_codes']}" - ) - - print("Testing MIMIC4ICD10Coding task... 
") - - mimic4_coding = MIMIC4ICD10Coding() - - # Generate samples - samples = dataset.set_task(mimic4_coding) - - # Print sample information - print(f"Total samples generated: {len(samples)}") - if len(samples) > 0: - print("First sample:") - print(f" - Text length: {len(samples[0]['text'])} characters") - print(f" - Number of ICD codes: {len(samples[0]['icd_codes'])}") - if len(samples[0]["icd_codes"]) > 0: - print( - f" - Sample ICD codes: {samples[0]['icd_codes'][:5] if len(samples[0]['icd_codes']) > 5 else samples[0]['icd_codes']}" - ) if __name__ == "__main__": diff --git a/tests/core/test_hicu.py b/tests/core/test_hicu.py new file mode 100644 index 000000000..f3748a70b --- /dev/null +++ b/tests/core/test_hicu.py @@ -0,0 +1,181 @@ +import unittest + +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models.hicu import ( + AsymmetricLoss, + HiCu, + build_icd10_hierarchy, +) + + +class TestHiCu(unittest.TestCase): + """Test cases for the HiCu model.""" + + def setUp(self): + """Create a synthetic multilabel dataset with ICD-10 codes.""" + self.samples = [ + { + "patient_id": "p0", + "visit_id": "v0", + "text": ["patient", "has", "fever", "and", "cough"], + "icd_codes": ["E11.321", "I10", "J44.1"], + }, + { + "patient_id": "p1", + "visit_id": "v1", + "text": ["chest", "pain", "shortness", "of", "breath"], + "icd_codes": ["I21.09", "I11.0"], + }, + { + "patient_id": "p2", + "visit_id": "v2", + "text": ["abdominal", "pain", "nausea"], + "icd_codes": ["K21.0", "E11.65"], + }, + ] + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema={"text": "sequence"}, + output_schema={"icd_codes": "multilabel"}, + dataset_name="test_hicu", + ) + self.model = HiCu( + self.dataset, + num_filter_maps=8, + embedding_dim=16, + ) + + def test_model_initialization(self): + """HiCu instantiates with correct attributes.""" + self.assertIsInstance(self.model, HiCu) + self.assertEqual(self.model.num_filter_maps, 8) + 
self.assertEqual(self.model.embedding_dim, 16) + self.assertEqual(self.model.text_key, "text") + self.assertEqual(self.model.label_key, "icd_codes") + self.assertEqual(len(self.model.depth_sizes), 3) + # Depth 0 = chapters, depth 2 = full codes + self.assertLessEqual(self.model.depth_sizes[0], self.model.depth_sizes[2]) + + def test_forward_pass(self): + """Forward pass returns all required output keys.""" + loader = get_dataloader(self.dataset, batch_size=2) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + for key in ("loss", "y_prob", "y_true", "logit"): + self.assertIn(key, ret, f"Missing key: {key}") + self.assertEqual(ret["loss"].dim(), 0) # scalar + self.assertEqual(ret["y_prob"].shape[0], 2) # batch size + + def test_backward_pass(self): + """Gradients flow after loss.backward().""" + loader = get_dataloader(self.dataset, batch_size=2) + batch = next(iter(loader)) + + ret = self.model(**batch) + ret["loss"].backward() + + has_grad = any( + p.requires_grad and p.grad is not None for p in self.model.parameters() + ) + self.assertTrue(has_grad, "No gradients after backward pass") + + def test_output_shapes(self): + """Output shapes match num_codes at the current depth.""" + loader = get_dataloader(self.dataset, batch_size=2) + batch = next(iter(loader)) + + for depth in range(3): + self.model.set_depth(depth) + with torch.no_grad(): + ret = self.model(**batch) + expected_codes = self.model.depth_sizes[depth] + self.assertEqual( + ret["y_prob"].shape[1], + expected_codes, + f"y_prob width mismatch at depth {depth}", + ) + self.assertEqual( + ret["logit"].shape[1], + expected_codes, + f"logit width mismatch at depth {depth}", + ) + + def test_depth_change(self): + """set_depth changes active decoder and output size.""" + loader = get_dataloader(self.dataset, batch_size=2) + batch = next(iter(loader)) + + self.model.set_depth(0) + with torch.no_grad(): + ret0 = self.model(**batch) + + self.model.set_depth(2) + with 
torch.no_grad(): + ret2 = self.model(**batch) + + # Depth 0 has fewer codes than depth 2. + self.assertLess(ret0["y_prob"].shape[1], ret2["y_prob"].shape[1]) + + def test_weight_transfer(self): + """After set_depth, child weights equal parent weights at mapped positions.""" + c2p = self.model.hierarchy["child_to_parent"].get(1, {}) + if not c2p: + self.skipTest("No child-to-parent mapping at depth 1") + + self.model.set_depth(1) + + parent_w = self.model.decoder.attention[0].weight.data + child_w = self.model.decoder.attention[1].weight.data + + for child_idx, parent_idx in c2p.items(): + torch.testing.assert_close( + child_w[child_idx], + parent_w[parent_idx], + msg=f"Weight mismatch: child {child_idx} != parent {parent_idx}", + ) + + def test_asymmetric_loss(self): + """AsymmetricLoss produces reasonable scalar values.""" + loss_fn = AsymmetricLoss(gamma_neg=4.0, gamma_pos=1.0, clip=0.05) + logits = torch.randn(4, 10) + targets = torch.zeros(4, 10) + targets[:, :3] = 1.0 + + loss = loss_fn(logits, targets) + + self.assertEqual(loss.dim(), 0) + self.assertTrue(loss.item() > 0, "Loss should be positive") + self.assertTrue(torch.isfinite(loss), "Loss should be finite") + + def test_icd10_hierarchy(self): + """Hierarchy builder maps E11.321 to correct ancestors.""" + codes = ["E11.321", "I10", "I11.0", "J44.1"] + h = build_icd10_hierarchy(codes) + + # Depth 0: chapters + self.assertIn("E00-E89", h["depth_to_codes"][0]) + self.assertIn("I00-I99", h["depth_to_codes"][0]) + self.assertIn("J00-J99", h["depth_to_codes"][0]) + + # Depth 1: categories (3-char) + self.assertIn("E11", h["depth_to_codes"][1]) + self.assertIn("I10", h["depth_to_codes"][1]) + self.assertIn("I11", h["depth_to_codes"][1]) + self.assertIn("J44", h["depth_to_codes"][1]) + + # Depth 2: full codes + self.assertEqual(h["depth_to_codes"][2], codes) + + # Check parent-child: E11 (cat) should be child of E00-E89 (chapter) + cat_idx = h["code_to_index"][1]["E11"] + chapter_idx = 
h["code_to_index"][0]["E00-E89"] + self.assertEqual(h["child_to_parent"][1][cat_idx], chapter_idx) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_mimic4_icd10_coding.py b/tests/core/test_mimic4_icd10_coding.py new file mode 100644 index 000000000..3f6a9956d --- /dev/null +++ b/tests/core/test_mimic4_icd10_coding.py @@ -0,0 +1,169 @@ +import unittest +from datetime import datetime + +import polars as pl + +from pyhealth.data.data import Patient +from pyhealth.processors import SequenceProcessor, MultiLabelProcessor +from pyhealth.tasks.medical_coding import MIMIC4ICD10Coding, _tokenize_clinical_text + + +def _make_patient(patient_id, rows): + """Build a Patient from a list of row dicts. + + Each row must have at least event_type and timestamp. Extra keys + become columns prefixed with event_type/ (how Patient.get_events + reconstructs Event.attr_dict). + """ + records = [] + for row in rows: + et = row["event_type"] + ts = row.get("timestamp", datetime(2024, 1, 1)) + rec = {"patient_id": patient_id, "event_type": et, "timestamp": ts} + for k, v in row.items(): + if k not in ("event_type", "timestamp"): + rec[f"{et}/{k}"] = v + records.append(rec) + df = pl.DataFrame(records) + return Patient(patient_id=patient_id, data_source=df) + + +class TestMIMIC4ICD10Coding(unittest.TestCase): + """Tests for the MIMIC4ICD10Coding task.""" + + def test_task_schema_types(self): + """Schema uses SequenceProcessor for text and MultiLabelProcessor for codes.""" + task = MIMIC4ICD10Coding() + self.assertIs(task.input_schema["text"], SequenceProcessor) + self.assertIs(task.output_schema["icd_codes"], MultiLabelProcessor) + + def test_basic_extraction(self): + """One admission with ICD-10 codes and a discharge note produces one sample.""" + task = MIMIC4ICD10Coding() + patient = _make_patient("p1", [ + {"event_type": "admissions", "hadm_id": "100"}, + {"event_type": "diagnoses_icd", "hadm_id": "100", + "icd_code": "E11.321", "icd_version": "10"}, + 
{"event_type": "diagnoses_icd", "hadm_id": "100", + "icd_code": "I10", "icd_version": "10"}, + {"event_type": "discharge", "hadm_id": "100", + "text": "Patient discharged with diabetes and hypertension."}, + ]) + samples = task(patient) + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["patient_id"], "p1") + self.assertIsInstance(samples[0]["text"], list) + self.assertIn("diabetes", samples[0]["text"]) + self.assertEqual(sorted(samples[0]["icd_codes"]), ["E11.321", "I10"]) + + def test_filters_out_icd9(self): + """ICD-9 codes in the same table are excluded from ICD-10 task output.""" + task = MIMIC4ICD10Coding() + patient = _make_patient("p2", [ + {"event_type": "admissions", "hadm_id": "200"}, + {"event_type": "diagnoses_icd", "hadm_id": "200", + "icd_code": "E11.321", "icd_version": "10"}, + {"event_type": "diagnoses_icd", "hadm_id": "200", + "icd_code": "25000", "icd_version": "9"}, + {"event_type": "discharge", "hadm_id": "200", + "text": "Discharge summary text here."}, + ]) + samples = task(patient) + self.assertEqual(len(samples), 1) + self.assertEqual(samples[0]["icd_codes"], ["E11.321"]) + + def test_no_notes_skips_admission(self): + """Admission without discharge notes produces no sample.""" + task = MIMIC4ICD10Coding() + patient = _make_patient("p3", [ + {"event_type": "admissions", "hadm_id": "300"}, + {"event_type": "diagnoses_icd", "hadm_id": "300", + "icd_code": "I10", "icd_version": "10"}, + # discharge note on a different admission so the column exists + {"event_type": "discharge", "hadm_id": "999", + "text": "unrelated"}, + ]) + samples = task(patient) + self.assertEqual(len(samples), 0) + + def test_no_codes_skips_admission(self): + """Admission with notes but no ICD codes produces no sample.""" + task = MIMIC4ICD10Coding() + patient = _make_patient("p4", [ + {"event_type": "admissions", "hadm_id": "400"}, + {"event_type": "discharge", "hadm_id": "400", + "text": "Patient seen and discharged."}, + # diagnosis on a different 
admission so the column exists + {"event_type": "diagnoses_icd", "hadm_id": "999", + "icd_code": "X00", "icd_version": "10"}, + ]) + samples = task(patient) + self.assertEqual(len(samples), 0) + + def test_multiple_admissions(self): + """Each admission produces its own sample with correct codes.""" + task = MIMIC4ICD10Coding() + patient = _make_patient("p5", [ + {"event_type": "admissions", "hadm_id": "500"}, + {"event_type": "admissions", "hadm_id": "501"}, + {"event_type": "diagnoses_icd", "hadm_id": "500", + "icd_code": "J44.1", "icd_version": "10"}, + {"event_type": "diagnoses_icd", "hadm_id": "501", + "icd_code": "K21.0", "icd_version": "10"}, + {"event_type": "discharge", "hadm_id": "500", + "text": "First admission discharge."}, + {"event_type": "discharge", "hadm_id": "501", + "text": "Second admission discharge."}, + ]) + samples = task(patient) + self.assertEqual(len(samples), 2) + codes_by_admission = {s["icd_codes"][0] for s in samples} + self.assertEqual(codes_by_admission, {"J44.1", "K21.0"}) + + def test_duplicate_codes_deduplicated(self): + """Duplicate ICD codes for the same admission are deduplicated.""" + task = MIMIC4ICD10Coding() + patient = _make_patient("p6", [ + {"event_type": "admissions", "hadm_id": "600"}, + {"event_type": "diagnoses_icd", "hadm_id": "600", + "icd_code": "I10", "icd_version": "10"}, + {"event_type": "diagnoses_icd", "hadm_id": "600", + "icd_code": "I10", "icd_version": "10"}, + {"event_type": "discharge", "hadm_id": "600", + "text": "Discharge note."}, + ]) + samples = task(patient) + self.assertEqual(len(samples[0]["icd_codes"]), 1) + + def test_text_is_lowercased(self): + """Output tokens are lowercased.""" + task = MIMIC4ICD10Coding() + patient = _make_patient("p7", [ + {"event_type": "admissions", "hadm_id": "700"}, + {"event_type": "diagnoses_icd", "hadm_id": "700", + "icd_code": "E11.321", "icd_version": "10"}, + {"event_type": "discharge", "hadm_id": "700", + "text": "Patient Has DIABETES."}, + ]) + samples = 
task(patient) + for token in samples[0]["text"]: + self.assertEqual(token, token.lower()) + + +class TestTokenizer(unittest.TestCase): + """Tests for _tokenize_clinical_text.""" + + def test_lowercases(self): + self.assertEqual(_tokenize_clinical_text("Hello World"), ["hello", "world"]) + + def test_truncates(self): + long_text = " ".join(f"word{i}" for i in range(5000)) + tokens = _tokenize_clinical_text(long_text) + self.assertEqual(len(tokens), 4000) + + def test_empty_string(self): + self.assertEqual(_tokenize_clinical_text(""), []) + + +if __name__ == "__main__": + unittest.main() From df7d117a0ed0f336ed0ec9b0620254d3a47a3112 Mon Sep 17 00:00:00 2001 From: Matthew Ardi Date: Sun, 5 Apr 2026 15:36:33 -0500 Subject: [PATCH 2/3] revert --- pyhealth/tasks/medical_coding.py | 54 ++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/pyhealth/tasks/medical_coding.py b/pyhealth/tasks/medical_coding.py index 1c5c41ff8..acd5ee566 100644 --- a/pyhealth/tasks/medical_coding.py +++ b/pyhealth/tasks/medical_coding.py @@ -205,8 +205,8 @@ def __call__(self, patient: Patient) -> List[Dict]: def main(): - """Quick smoke test for medical coding tasks.""" - from pyhealth.datasets import MIMIC3Dataset + # Test case for MIMIC4ICD9Coding and MIMIC3 + from pyhealth.datasets import MIMIC3Dataset, MIMIC4Dataset root = "/srv/local/data/MIMIC-III/mimic-iii-clinical-database-1.4" print("Testing MIMIC3ICD9Coding task...") @@ -218,11 +218,61 @@ def main(): dev=True, ) mimic3_coding = MIMIC3ICD9Coding() + # print(len(mimic3_coding.samples)) samples = dataset.set_task(mimic3_coding) + # Print sample information print(f"Total samples generated: {len(samples)}") if len(samples) > 0: + print("First sample:") print(f" - Text length: {len(samples[0]['text'])} characters") print(f" - Number of ICD codes: {len(samples[0]['icd_codes'])}") + if len(samples[0]["icd_codes"]) > 0: + print( + f" - Sample ICD codes: {samples[0]['icd_codes'][:5] if 
len(samples[0]['icd_codes']) > 5 else samples[0]['icd_codes']}" + ) + + # Initialize the dataset with dev mode enabled + print("Testing MIMIC4ICD9Coding task...") + dataset = MIMIC4Dataset( + root="/srv/local/data/MIMIC-IV/2.0/hosp", + tables=["diagnoses_icd", "procedures_icd"], + note_root="/srv/local/data/MIMIC-IV/2.0/note", + dev=True, + ) + # Create the task instance + mimic4_coding = MIMIC4ICD9Coding() + + # Generate samples + samples = dataset.set_task(mimic4_coding) + + # Print sample information + print(f"Total samples generated: {len(samples)}") + if len(samples) > 0: + print("First sample:") + print(f" - Text length: {len(samples[0]['text'])} characters") + print(f" - Number of ICD codes: {len(samples[0]['icd_codes'])}") + if len(samples[0]["icd_codes"]) > 0: + print( + f" - Sample ICD codes: {samples[0]['icd_codes'][:5] if len(samples[0]['icd_codes']) > 5 else samples[0]['icd_codes']}" + ) + + print("Testing MIMIC4ICD10Coding task... ") + + mimic4_coding = MIMIC4ICD10Coding() + + # Generate samples + samples = dataset.set_task(mimic4_coding) + + # Print sample information + print(f"Total samples generated: {len(samples)}") + if len(samples) > 0: + print("First sample:") + print(f" - Text length: {len(samples[0]['text'])} characters") + print(f" - Number of ICD codes: {len(samples[0]['icd_codes'])}") + if len(samples[0]["icd_codes"]) > 0: + print( + f" - Sample ICD codes: {samples[0]['icd_codes'][:5] if len(samples[0]['icd_codes']) > 5 else samples[0]['icd_codes']}" + ) if __name__ == "__main__": From c2880e813a7aa6e11309c9acc7b4e429ee55aeea Mon Sep 17 00:00:00 2001 From: "anthropic-code-agent[bot]" <242468646+Claude@users.noreply.github.com> Date: Sun, 5 Apr 2026 21:13:31 +0000 Subject: [PATCH 3/3] Address code review feedback: streaming collection, robust text handling, deterministic ordering, visit_id, and memory-efficient label mappings Agent-Logs-Url: https://github.com/matthew-ardi/PyHealth/sessions/4752d079-651a-4fe1-9faa-3a2025813f50 
Co-authored-by: matthew-ardi <25186507+matthew-ardi@users.noreply.github.com> --- pyhealth/models/hicu.py | 29 +++++++++++++++++++++++------ pyhealth/tasks/medical_coding.py | 11 ++++++++--- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/pyhealth/models/hicu.py b/pyhealth/models/hicu.py index f18977538..a516eb7f5 100644 --- a/pyhealth/models/hicu.py +++ b/pyhealth/models/hicu.py @@ -344,14 +344,15 @@ def __init__( self.current_depth = 2 def _build_label_mappings(self) -> None: + """Build sparse full_to_ancestor index mappings instead of dense matrices.""" hierarchy = self.hierarchy full_codes = hierarchy["depth_to_codes"][2] full_idx = hierarchy["code_to_index"][2] for d in range(2): n_full = len(full_codes) - n_depth = len(hierarchy["depth_to_codes"][d]) - mapping = torch.zeros(n_full, n_depth) + # Store as 1D index tensor: full_to_ancestor[i] = ancestor_index + full_to_ancestor = torch.zeros(n_full, dtype=torch.long) for fc in full_codes: fi = full_idx[fc] @@ -360,15 +361,31 @@ def _build_label_mappings(self) -> None: else: ancestor = _get_icd10_category(fc) ai = hierarchy["code_to_index"][d][ancestor] - mapping[fi, ai] = 1.0 + full_to_ancestor[fi] = ai - self.register_buffer(f"_label_map_{d}", mapping) + self.register_buffer(f"_label_map_{d}", full_to_ancestor) def _remap_labels(self, y_true: torch.Tensor, depth: int) -> torch.Tensor: + """Remap full-code labels to ancestor labels at the given depth.""" if depth == 2: return y_true - mapping = getattr(self, f"_label_map_{depth}") - return (y_true @ mapping).clamp(max=1.0) + + # y_true: (batch_size, n_full_codes) + # full_to_ancestor: (n_full_codes,) - maps each full code to its ancestor index + full_to_ancestor = getattr(self, f"_label_map_{depth}") + n_ancestor = len(self.hierarchy["depth_to_codes"][depth]) + + # Use scatter_add to accumulate labels at ancestor positions + batch_size = y_true.shape[0] + y_ancestor = torch.zeros( + batch_size, n_ancestor, dtype=y_true.dtype, device=y_true.device 
+ ) + + # Expand full_to_ancestor for batch processing + indices = full_to_ancestor.unsqueeze(0).expand(batch_size, -1) + y_ancestor.scatter_add_(1, indices, y_true) + + return y_ancestor.clamp(max=1.0) def set_depth(self, depth: int) -> None: """Switch hierarchy depth (0=chapters, 1=categories, 2=full codes) and transfer weights.""" diff --git a/pyhealth/tasks/medical_coding.py b/pyhealth/tasks/medical_coding.py index acd5ee566..ee134693b 100644 --- a/pyhealth/tasks/medical_coding.py +++ b/pyhealth/tasks/medical_coding.py @@ -162,7 +162,7 @@ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: df.filter(pl.col("event_type") == "discharge") .select("patient_id") .unique() - .collect() + .collect(engine="streaming") .to_series() ) ) @@ -187,7 +187,11 @@ def __call__(self, patient: Patient) -> List[Dict]: event_type="discharge", filters=[("hadm_id", "==", admission.hadm_id)], ) - raw_text = " ".join(note.text for note in notes if hasattr(note, "text")) + raw_text = " ".join( + str(note.text) + for note in notes + if hasattr(note, "text") and note.text is not None + ) tokens = _tokenize_clinical_text(raw_text) if not tokens or len(icd_codes) < 1: @@ -196,8 +200,9 @@ def __call__(self, patient: Patient) -> List[Dict]: samples.append( { "patient_id": patient.patient_id, + "visit_id": admission.hadm_id, "text": tokens, - "icd_codes": list(set(icd_codes)), + "icd_codes": sorted(set(icd_codes)), } )