From 98d6a67d776c5a88b86eba17b9129a25f93f2928 Mon Sep 17 00:00:00 2001 From: mducducd Date: Fri, 13 Feb 2026 14:01:16 +0000 Subject: [PATCH 01/11] add multi-target support; tests, fixes, and docs --- src/stamp/__main__.py | 4 +- src/stamp/config.yaml | 24 +- src/stamp/encoding/encoder/__init__.py | 2 +- src/stamp/encoding/encoder/chief.py | 2 +- src/stamp/encoding/encoder/eagle.py | 2 +- src/stamp/encoding/encoder/gigapath.py | 2 +- src/stamp/encoding/encoder/madeleine.py | 2 +- src/stamp/encoding/encoder/titan.py | 2 +- src/stamp/heatmaps/__init__.py | 8 +- src/stamp/modeling/config.py | 26 +- src/stamp/modeling/crossval.py | 230 ++++---- src/stamp/modeling/data.py | 512 +++++++++++++----- src/stamp/modeling/deploy.py | 321 +++++++++-- src/stamp/modeling/models/__init__.py | 79 ++- src/stamp/modeling/models/barspoon.py | 367 +++++++++++++ src/stamp/modeling/registry.py | 10 + src/stamp/modeling/train.py | 156 +++--- src/stamp/preprocessing/__init__.py | 2 +- .../extractor/chief_ctranspath.py | 2 +- .../preprocessing/extractor/ctranspath.py | 2 +- .../preprocessing/extractor/dinobloom.py | 2 +- src/stamp/preprocessing/tiling.py | 2 +- src/stamp/statistics/__init__.py | 40 +- src/stamp/statistics/survival.py | 12 +- src/stamp/types.py | 1 + src/stamp/{ => utils}/cache.py | 0 src/stamp/{ => utils}/config.py | 0 src/stamp/{ => utils}/seed.py | 0 src/stamp/utils/target_file.py | 351 ++++++++++++ tests/random_data.py | 86 +++ tests/test_cache_tiles.py | 2 +- tests/test_config.py | 2 +- tests/test_crossval.py | 2 +- tests/test_data.py | 2 +- tests/test_deployment.py | 31 +- tests/test_encoders.py | 2 +- tests/test_feature_extractors.py | 2 +- tests/test_model.py | 112 ++++ tests/test_statistics.py | 151 ++++++ tests/test_train_deploy.py | 99 +++- uv.lock | 9 +- 41 files changed, 2210 insertions(+), 453 deletions(-) create mode 100644 src/stamp/modeling/models/barspoon.py rename src/stamp/{ => utils}/cache.py (100%) rename src/stamp/{ => utils}/config.py (100%) rename 
src/stamp/{ => utils}/seed.py (100%) create mode 100644 src/stamp/utils/target_file.py diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 4ab8416f..ffa98bae 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -6,14 +6,14 @@ import yaml -from stamp.config import StampConfig from stamp.modeling.config import ( AdvancedConfig, MlpModelParams, ModelParams, VitModelParams, ) -from stamp.seed import Seed +from stamp.utils.config import StampConfig +from stamp.utils.seed import Seed STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml") diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 796140a5..4f16dcb2 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -4,7 +4,7 @@ preprocessing: # Extractor to use for feature extractor. Possible options are "ctranspath", # "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom", # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", - # "virchow-full", "musk", "mstar", "plip" + # "virchow-full", "musk", "mstar", "plip", "ticon" # Some of them require requesting access to the respective authors beforehand. extractor: "chief-ctranspath" @@ -73,6 +73,8 @@ crossval: # Name of the column from the clini table to train on. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -130,6 +132,8 @@ training: # Name of the column from the clini table to train on. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -172,6 +176,8 @@ deployment: # Name of the column from the clini to compare predictions to. 
ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -197,6 +203,8 @@ statistics: # Name of the target label. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # A lot of the statistics are computed "one-vs-all", i.e. there needs to be # a positive class to calculate the statistics for. @@ -316,7 +324,7 @@ advanced_config: max_lr: 1e-4 div_factor: 25. # Select a model regardless of task - model_name: "vit" # or mlp, trans_mil + model_name: "vit" # or mlp, trans_mil, barspoon model_params: vit: # Vision Transformer @@ -335,3 +343,15 @@ advanced_config: dim_hidden: 512 num_layers: 2 dropout: 0.25 + + # NOTE: Only the `barspoon` model supports multi-target classification + # (i.e. `ground_truth_label` can be a list of column names). Other + # models expect a single target column. 
+ barspoon: # Encoder-Decoder Transformer for multi-target classification + d_model: 512 + num_encoder_heads: 8 + num_decoder_heads: 8 + num_encoder_layers: 2 + num_decoder_layers: 2 + dim_feedforward: 2048 + positional_encoding: true \ No newline at end of file diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 5827e884..86daa54a 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -12,11 +12,11 @@ from tqdm import tqdm import stamp -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.modeling.data import CoordsInfo, get_coords, read_table from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/chief.py b/src/stamp/encoding/encoder/chief.py index 2ad4b91b..924ceebb 100644 --- a/src/stamp/encoding/encoder/chief.py +++ b/src/stamp/encoding/encoder/chief.py @@ -10,11 +10,11 @@ from numpy import ndarray from tqdm import tqdm -from stamp.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/eagle.py b/src/stamp/encoding/encoder/eagle.py index 03bc833e..9266f315 100644 --- a/src/stamp/encoding/encoder/eagle.py +++ b/src/stamp/encoding/encoder/eagle.py @@ -9,13 +9,13 @@ from torch import Tensor from tqdm import tqdm -from stamp.cache import get_processing_code_hash from 
stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.encoding.encoder.chief import CHIEF from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/gigapath.py b/src/stamp/encoding/encoder/gigapath.py index 9cb3f6f5..4c0a2f6b 100644 --- a/src/stamp/encoding/encoder/gigapath.py +++ b/src/stamp/encoding/encoder/gigapath.py @@ -9,12 +9,12 @@ from gigapath import slide_encoder from tqdm import tqdm -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import PandasLabel, SlideMPP +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/madeleine.py b/src/stamp/encoding/encoder/madeleine.py index 5798a592..a0c74dcd 100644 --- a/src/stamp/encoding/encoder/madeleine.py +++ b/src/stamp/encoding/encoder/madeleine.py @@ -3,10 +3,10 @@ import torch from numpy import ndarray -from stamp.cache import STAMP_CACHE_DIR from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.preprocessing.config import ExtractorName +from stamp.utils.cache import STAMP_CACHE_DIR try: from madeleine.models.factory import create_model_from_pretrained diff --git a/src/stamp/encoding/encoder/titan.py b/src/stamp/encoding/encoder/titan.py index 1012d98f..568254ca 100644 --- a/src/stamp/encoding/encoder/titan.py +++ b/src/stamp/encoding/encoder/titan.py @@ -10,12 +10,12 @@ from tqdm import tqdm from transformers import 
AutoModel -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, Microns, PandasLabel, SlideMPP +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index 446d85d6..22fb5250 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -70,14 +70,14 @@ def _attention_rollout_single( device = feats.device - # --- 1. Forward pass to fill attn_weights in each SelfAttention layer --- + # 1. Forward pass to fill attn_weights in each SelfAttention layer _ = model( bags=feats.unsqueeze(0), coords=coords.unsqueeze(0), mask=torch.zeros(1, len(feats), dtype=torch.bool, device=device), ) - # --- 2. Rollout computation --- + # 2. Rollout computation attn_rollout: torch.Tensor | None = None for layer in model.transformer.layers: # type: ignore attn = getattr(layer[0], "attn_weights", None) # SelfAttention.attn_weights @@ -96,10 +96,10 @@ def _attention_rollout_single( if attn_rollout is None: raise RuntimeError("No attention maps collected from transformer layers.") - # --- 3. Extract CLS → tiles attention --- + # 3. Extract CLS → tiles attention cls_attn = attn_rollout[0, 1:] # [tile] - # --- 4. Normalize for visualization consistency --- + # 4. 
Normalize for visualization consistency cls_attn = cls_attn - cls_attn.min() cls_attn = cls_attn / (cls_attn.max().clamp(min=1e-8)) diff --git a/src/stamp/modeling/config.py b/src/stamp/modeling/config.py index 21ce69db..5b9a6bcc 100644 --- a/src/stamp/modeling/config.py +++ b/src/stamp/modeling/config.py @@ -21,7 +21,7 @@ class TrainConfig(BaseModel): ) feature_dir: Path = Field(description="Directory containing feature files") - ground_truth_label: PandasLabel | None = Field( + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = Field( default=None, description="Name of categorical column in clinical table to train on", ) @@ -64,7 +64,7 @@ class DeploymentConfig(BaseModel): slide_table: Path feature_dir: Path - ground_truth_label: PandasLabel | None = None + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None patient_label: PandasLabel = "PATIENT" filename_label: PandasLabel = "FILENAME" @@ -99,8 +99,29 @@ class TransMILModelParams(BaseModel): dim_hidden: int = 512 +class BarspoonParams(BaseModel): + model_config = ConfigDict(extra="forbid") + d_model: int = 512 + num_encoder_heads: int = 8 + num_decoder_heads: int = 8 + num_encoder_layers: int = 2 + num_decoder_layers: int = 2 + dim_feedforward: int = 2048 + positional_encoding: bool = True + # Other hparams + learning_rate: float = 1e-4 + + class LinearModelParams(BaseModel): model_config = ConfigDict(extra="forbid") + num_encoder_heads: int = 8 + num_decoder_heads: int = 8 + num_encoder_layers: int = 2 + num_decoder_layers: int = 2 + dim_feedforward: int = 2048 + positional_encoding: bool = True + # Other hparams + learning_rate: float = 1e-4 class ModelParams(BaseModel): @@ -109,6 +130,7 @@ class ModelParams(BaseModel): trans_mil: TransMILModelParams = Field(default_factory=TransMILModelParams) mlp: MlpModelParams = Field(default_factory=MlpModelParams) linear: LinearModelParams = Field(default_factory=LinearModelParams) + barspoon: BarspoonParams = 
Field(default_factory=BarspoonParams) class AdvancedConfig(BaseModel): diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 43e76f01..2caccecb 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -1,8 +1,10 @@ import logging +from collections import Counter from collections.abc import Mapping, Sequence -from typing import Any, Final +from typing import Any, cast import numpy as np +import torch from pydantic import BaseModel from sklearn.model_selection import KFold, StratifiedKFold @@ -10,13 +12,8 @@ from stamp.modeling.data import ( PatientData, create_dataloader, - detect_feature_type, - filter_complete_patient_data_, - load_patient_level_data, + load_patient_data_, log_patient_class_summary, - patient_to_ground_truth_from_clini_table_, - patient_to_survival_from_clini_table_, - slide_to_patient_from_slide_table_, ) from stamp.modeling.deploy import ( _predict, @@ -28,7 +25,6 @@ from stamp.modeling.train import setup_model_for_training, train_model_ from stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import ( - FeaturePath, GroundTruth, PatientId, ) @@ -53,67 +49,31 @@ def categorical_crossval_( config: CrossvalConfig, advanced: AdvancedConfig, ) -> None: - feature_type = detect_feature_type(config.feature_dir) + if config.task is None: + raise ValueError( + "task must be set to 'classification' | 'regression' | 'survival'" + ) + + patient_to_data, feature_type = load_patient_data_( + feature_dir=config.feature_dir, + clini_table=config.clini_table, + slide_table=config.slide_table, + task=config.task, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + patient_label=config.patient_label, + filename_label=config.filename_label, + drop_patients_with_missing_ground_truth=True, + ) _logger.info(f"Detected feature type: {feature_type}") - if feature_type in ("tile", "slide"): - if config.slide_table is None: - raise 
ValueError("A slide table is required for modeling") - if config.task == "survival": - if config.time_label is None or config.status_label is None: - raise ValueError( - "Both time_label and status_label are is required for survival modeling" - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = ( - patient_to_survival_from_clini_table_( - clini_table_path=config.clini_table, - time_label=config.time_label, - status_label=config.status_label, - patient_label=config.patient_label, - ) - ) - else: - if config.ground_truth_label is None: - raise ValueError( - "Ground truth label is required for classification or regression modeling" - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = ( - patient_to_ground_truth_from_clini_table_( - clini_table_path=config.clini_table, - ground_truth_label=config.ground_truth_label, - patient_label=config.patient_label, - ) - ) - slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( - slide_to_patient_from_slide_table_( - slide_table_path=config.slide_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - filename_label=config.filename_label, - ) - ) - patient_to_data: Mapping[PatientId, PatientData] = ( - filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, - ) - ) - elif feature_type == "patient": - patient_to_data: Mapping[PatientId, PatientData] = load_patient_level_data( - task=config.task, - clini_table=config.clini_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = { - pid: pd.ground_truth for pid, pd in patient_to_data.items() - } - else: - raise RuntimeError(f"Unsupported feature type: {feature_type}") + patient_to_ground_truth = { + pid: pd.ground_truth for pid, pd 
in patient_to_data.items() + } + + if feature_type not in ("tile", "slide", "patient"): + raise ValueError(f"Unknown feature type: {feature_type}") config.output_dir.mkdir(parents=True, exist_ok=True) splits_file = config.output_dir / "splits.json" @@ -158,18 +118,51 @@ def categorical_crossval_( f"{ground_truths_not_in_split}" ) + categories_for_export: ( + dict[str, list] | list + ) = [] # declare upfront to avoid unbound variable warnings + categories: Sequence[GroundTruth] | list | None = [] # type: ignore # declare upfront to avoid unbound variable warnings + if config.task == "classification": - categories = config.categories or sorted( - { - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - } - ) - log_patient_class_summary( - patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, - categories=categories, - ) + # Determine categories for training (single-target) and for export (supports multi-target) + if isinstance(config.ground_truth_label, str): + categories = config.categories or sorted( + { + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + } + ) + log_patient_class_summary( + patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, + categories=categories, + ) + categories_for_export = cast(list, categories) + else: + # Multi-target: build a mapping from target label -> sorted list of categories + categories_accum: dict[str, set[GroundTruth]] = {} + for patient_data in patient_to_data.values(): + gt = patient_data.ground_truth + if isinstance(gt, dict): + for k, v in gt.items(): + if v is not None: + categories_accum.setdefault(k, set()).add(v) + categories_for_export = {k: sorted(v) for k, v in categories_accum.items()} + # Log summary per target + for t, cats in categories_for_export.items(): + ground_truths = [ + pd.ground_truth.get(t) + for pd in patient_to_data.values() + if 
isinstance(pd.ground_truth, dict) + and pd.ground_truth.get(t) is not None + ] + counter = Counter(ground_truths) + _logger.info( + f"{t} | Total patients: {len(ground_truths)} | " + + " | ".join([f"Class {c}: {counter.get(c, 0)}" for c in cats]) + ) + # For training, categories can remain None (inferred later) + categories = config.categories or None else: categories = [] @@ -206,12 +199,18 @@ def categorical_crossval_( }, categories=( categories - or sorted( - { - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - } + if categories is not None + else ( + sorted( + { + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + and not isinstance(patient_data.ground_truth, dict) + } + ) + if not isinstance(config.ground_truth_label, Sequence) + else None ) ), train_transform=( @@ -263,30 +262,48 @@ def categorical_crossval_( ) if config.task == "survival": - _to_survival_prediction_df( - patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, - patient_label=config.patient_label, - cut_off=getattr(model.hparams, "train_pred_median", None), - ).to_csv(split_dir / "patient-preds.csv", index=False) + if isinstance(config.ground_truth_label, str): + _to_survival_prediction_df( + patient_to_ground_truth=cast( + Mapping[PatientId, str | None], patient_to_ground_truth + ), + predictions=cast(Mapping[PatientId, torch.Tensor], predictions), + patient_label=config.patient_label, + cut_off=getattr(model.hparams, "train_pred_median", None), + ).to_csv(split_dir / "patient-preds.csv", index=False) + else: + _logger.warning( + "Multi-target survival prediction export not yet supported; skipping CSV save" + ) elif config.task == "regression": if config.ground_truth_label is None: raise RuntimeError("Grounf truth label is required for regression") - _to_regression_prediction_df( - patient_to_ground_truth=patient_to_ground_truth, - 
predictions=predictions, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - ).to_csv(split_dir / "patient-preds.csv", index=False) + if isinstance(config.ground_truth_label, str): + _to_regression_prediction_df( + patient_to_ground_truth=cast( + Mapping[PatientId, str | None], patient_to_ground_truth + ), + predictions=cast(Mapping[PatientId, torch.Tensor], predictions), + patient_label=config.patient_label, + ground_truth_label=config.ground_truth_label, + ).to_csv(split_dir / "patient-preds.csv", index=False) + else: + _logger.warning( + "Multi-target regression prediction export not yet supported; skipping CSV save" + ) else: if config.ground_truth_label is None: raise RuntimeError( "Grounf truth label is required for classification" ) _to_prediction_df( - categories=categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, + predictions=cast( + Mapping[PatientId, torch.Tensor] + | Mapping[PatientId, dict[str, torch.Tensor]], + predictions, + ), patient_label=config.patient_label, ground_truth_label=config.ground_truth_label, ).to_csv(split_dir / "patient-preds.csv", index=False) @@ -296,6 +313,16 @@ def _get_splits( *, patient_to_data: Mapping[PatientId, PatientData[Any]], n_splits: int, spliter ) -> _Splits: patients = np.array(list(patient_to_data.keys())) + + # Extract ground truth for stratification. 
+ # For multi-target (dict), use the first target's value + y_strat = np.array( + [ + next(iter(gt.values())) if isinstance(gt, dict) else gt + for gt in [patient.ground_truth for patient in patient_to_data.values()] + ] + ) + skf = spliter(n_splits=n_splits, shuffle=True, random_state=0) splits = _Splits( splits=[ @@ -303,12 +330,7 @@ def _get_splits( train_patients=set(patients[train_indices]), test_patients=set(patients[test_indices]), ) - for train_indices, test_indices in skf.split( - patients, - np.array( - [patient.ground_truth for patient in patient_to_data.values()] - ), - ) + for train_indices, test_indices in skf.split(patients, y_strat) ] ) return splits diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 6cabec64..8b30ab11 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -5,19 +5,29 @@ from dataclasses import KW_ONLY, dataclass from itertools import groupby from pathlib import Path -from typing import IO, BinaryIO, Counter, Generic, TextIO, TypeAlias, Union, cast +from typing import ( + IO, + Any, + BinaryIO, + Dict, + Final, + Generic, + List, + TextIO, + TypeAlias, + Union, + cast, +) import h5py import numpy as np import pandas as pd import torch -from jaxtyping import Float from packaging.version import Version from torch import Tensor from torch.utils.data import DataLoader, Dataset import stamp -from stamp.seed import Seed from stamp.types import ( Bags, BagSize, @@ -35,6 +45,7 @@ Task, TilePixels, ) +from stamp.utils.seed import Seed _logger = logging.getLogger("stamp") @@ -43,14 +54,17 @@ __copyright__ = "Copyright (C) 2022-2025 Marko van Treeck, Minh Duc Nguyen" __license__ = "MIT" -_Bag: TypeAlias = Float[Tensor, "tile feature"] -_EncodedTarget: TypeAlias = Float[Tensor, "category_is_hot"] | Float[Tensor, "1"] # noqa: F821 +_Bag: TypeAlias = Tensor +_EncodedTarget: TypeAlias = ( + Tensor | dict[str, Tensor] +) # Union of encoded targets or multi-target dict _BinaryIOLike: TypeAlias = 
Union[BinaryIO, IO[bytes]] """The ground truth, encoded numerically - classification: one-hot float [C] - regression: float [1] +- multi-target: dict[target_name -> one-hot/regression value] """ -_Coordinates: TypeAlias = Float[Tensor, "tile 2"] +_Coordinates: TypeAlias = Tensor @dataclass @@ -64,7 +78,7 @@ class PatientData(Generic[GroundTruthType]): def tile_bag_dataloader( *, - patient_data: Sequence[PatientData[GroundTruth | None]], + patient_data: Sequence[PatientData[GroundTruth | None | dict]], bag_size: int | None, task: Task, categories: Sequence[Category] | None = None, @@ -74,7 +88,7 @@ def tile_bag_dataloader( transform: Callable[[Tensor], Tensor] | None, ) -> tuple[ DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - Sequence[Category], + Sequence[Category] | Mapping[str, Sequence[Category]], ]: """Creates a dataloader from patient data for tile-level (bagged) features. @@ -86,103 +100,139 @@ def tile_bag_dataloader( task='regression': returns float targets """ - if task == "classification": - raw_ground_truths = np.array([patient.ground_truth for patient in patient_data]) - categories = ( - categories if categories is not None else list(np.unique(raw_ground_truths)) - ) - # one_hot = torch.tensor(raw_ground_truths.reshape(-1, 1) == categories) - one_hot = torch.tensor( - raw_ground_truths.reshape(-1, 1) == categories, dtype=torch.float32 - ) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=one_hot, - transform=transform, - ) - cats_out: Sequence[Category] = list(categories) - elif task == "regression": - raw_targets = np.array( - [ - np.nan if p.ground_truth is None else float(p.ground_truth) - for p in patient_data - ], - dtype=np.float32, - ) - y = torch.from_numpy(raw_targets).reshape(-1, 1) + targets, cats_out = _parse_targets( + patient_data=patient_data, + task=task, + categories=categories, + ) - ds = BagDataset( - bags=[patient.feature_files for patient in 
patient_data], - bag_size=bag_size, - ground_truths=y, - transform=transform, - ) - cats_out = [] + is_multitarget = isinstance(targets[0], dict) - elif task == "survival": # Not yet support logistic-harzard - times: list[float] = [] - events: list[float] = [] + collate_fn = _collate_multitarget if is_multitarget else _collate_to_tuple - for p in patient_data: - if p.ground_truth is None: - times.append(np.nan) - events.append(np.nan) - continue + ds = BagDataset( + bags=[patient.feature_files for patient in patient_data], + bag_size=bag_size, + ground_truths=targets, + transform=transform, + ) + dl = DataLoader( + ds, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=collate_fn, + worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, + ) - try: - time_str, status_str = p.ground_truth.split(" ", 1) + return ( + cast( + DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], + dl, + ), + cats_out, + ) - # Handle missing values encoded as "nan" - if time_str.lower() == "nan": - times.append(np.nan) - else: - times.append(float(time_str)) - - if status_str.lower() == "nan": - events.append(np.nan) - elif status_str.lower() in {"dead", "event", "1", "Yes", "yes"}: - events.append(1.0) - elif status_str.lower() in {"alive", "censored", "0", "No", "no"}: - events.append(0.0) - else: - events.append(np.nan) # unknown status → mark missing - except Exception: +def _parse_targets( + *, + patient_data: Sequence, + task: Task, + categories: Sequence[Category] | None = None, + target_spec: dict[str, Any] | None = None, + target_label: str | None = None, +) -> tuple[ + Union[torch.Tensor, list[dict[str, torch.Tensor]]], + Sequence[Category] | Mapping[str, Sequence[Category]], +]: + """ + Parse raw GroundTruth (str) into model-ready tensors. + This is the ONLY place task semantics live. 
+ """ + + gts = [p.ground_truth for p in patient_data] + + if task == "classification": + if any(isinstance(gt, dict) for gt in gts if gt is not None): + # infer target names from the first non-None dict + first_dict = next(gt for gt in gts if isinstance(gt, dict)) + target_names = list(first_dict.keys()) + + # infer categories per target (ignore None patients, ignore None values) + categories_out: dict[str, list[str]] = {t: [] for t in target_names} + for gt in gts: + if not isinstance(gt, dict): + continue + for t in target_names: + v = gt.get(t) + if v is not None: + categories_out[t].append(v) + + # make unique + sorted + categories_out = { + t: sorted(set(vals)) for t, vals in categories_out.items() + } + + # encode per patient; if gt missing -> all zeros + encoded: list[dict[str, Tensor]] = [] + for gt in gts: + patient_encoded: dict[str, Tensor] = {} + for t in target_names: + cats = categories_out[t] + if not isinstance(gt, dict) or gt.get(t) is None: + one_hot = torch.zeros(len(cats), dtype=torch.float32) + else: + one_hot = torch.tensor( + [gt[t] == c for c in cats], + dtype=torch.float32, + ) + patient_encoded[t] = one_hot + encoded.append(patient_encoded) + + # IMPORTANT: return categories as mapping, not list-of-target-names + return encoded, categories_out + + # single target + unique = {gt for gt in gts if gt is not None} + if len(unique) >= 2 or categories is not None: + raw = np.array([p.ground_truth for p in patient_data]) + categories = categories or list(sorted(unique)) + labels = torch.tensor( + raw.reshape(-1, 1) == categories, + dtype=torch.float32, + ) + return labels, categories + + raise ValueError( + "Only one unique class found in classification task. " + "This is usually a data or configuration error." 
+ ) + + elif task == "regression": + y = torch.tensor( + [np.nan if gt is None else float(gt) for gt in gts], + dtype=torch.float32, + ).reshape(-1, 1) + return y, [] + + elif task == "survival": + times, events = [], [] + for gt in gts: + if gt is None: times.append(np.nan) events.append(np.nan) + continue - # Final tensor shape: (N, 2) - y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) + time_str, status_str = gt.split(" ", 1) + times.append(np.nan if time_str.lower() == "nan" else float(time_str)) + events.append(_parse_survival_status(status_str)) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=y, - transform=transform, - ) - cats_out: Sequence[Category] = [] # survival has no categories + y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) + return y, [] else: - raise ValueError(f"Unknown task: {task}") - - return ( - cast( - DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - DataLoader( - ds, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - collate_fn=_collate_to_tuple, - worker_init_fn=Seed.get_loader_worker_init() - if Seed._is_set() - else None, - ), - ), - cats_out, - ) + raise ValueError(f"Unsupported task: {task}") def _collate_to_tuple( @@ -210,6 +260,24 @@ def _collate_to_tuple( return (bags, coords, bag_sizes, encoded_targets) +def _collate_multitarget( + items: list[tuple[_Bag, _Coordinates, BagSize, Dict[str, Tensor]]], +) -> tuple[Bags, CoordinatesBatch, BagSizes, Dict[str, Tensor]]: + bags = torch.stack([b for b, _, _, _ in items]) + coords = torch.stack([c for _, c, _, _ in items]) + bag_sizes = torch.tensor([s for _, _, s, _ in items]) + + acc: Dict[str, List[Tensor]] = {} + + for _, _, _, tdict in items: + for k, v in tdict.items(): + acc.setdefault(k, []).append(v) + + targets: Dict[str, Tensor] = {k: torch.stack(v, dim=0) for k, v in acc.items()} + + return bags, coords, bag_sizes, targets 
+ + def patient_feature_dataloader( *, patient_data: Sequence[PatientData[GroundTruth | None]], @@ -237,21 +305,31 @@ def create_dataloader( *, feature_type: str, task: Task, - patient_data: Sequence[PatientData[GroundTruth | None]], + patient_data: Sequence[PatientData[GroundTruth | None | dict]], bag_size: int | None = None, batch_size: int, shuffle: bool, num_workers: int, transform: Callable[[Tensor], Tensor] | None, - categories: Sequence[Category] | None = None, -) -> tuple[DataLoader, Sequence[Category]]: + categories: Sequence[Category] | Mapping[str, Sequence[Category]] | None = None, +) -> tuple[DataLoader, Sequence[Category] | Mapping[str, Sequence[Category]]]: """Unified dataloader for all feature types and tasks.""" if feature_type == "tile": + # For multi-target classification, categories may be a mapping from + # target name to per-target categories. _parse_targets (used inside + # tile_bag_dataloader) only consumes explicit categories for the + # single-target case, so we pass a sequence or None here. 
+ cats_arg: Sequence[Category] | None + if isinstance(categories, Mapping): + cats_arg = None + else: + cats_arg = categories + return tile_bag_dataloader( patient_data=patient_data, bag_size=bag_size, task=task, - categories=categories, + categories=cats_arg, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, @@ -263,21 +341,32 @@ def create_dataloader( if task == "classification": raw = np.array([p.ground_truth for p in patient_data]) - categories = categories or list(np.unique(raw)) - labels = torch.tensor(raw.reshape(-1, 1) == categories, dtype=torch.float32) - elif task == "regression": + categories_out = categories or list(np.unique(raw)) labels = torch.tensor( - [ - float(gt) - for gt in (p.ground_truth for p in patient_data) - if gt is not None - ], - dtype=torch.float32, - ).reshape(-1, 1) + raw.reshape(-1, 1) == categories_out, dtype=torch.float32 + ) + elif task == "regression": + values: list[float] = [] + for gt in (p.ground_truth for p in patient_data): + if gt is None: + continue + if isinstance(gt, dict): + # Use first value for multi-target regression + first_val = next(iter(gt.values())) + values.append(float(first_val)) + else: + values.append(float(gt)) + + labels = torch.tensor(values, dtype=torch.float32).reshape(-1, 1) elif task == "survival": times, events = [], [] for p in patient_data: - t, e = (p.ground_truth or "nan nan").split(" ", 1) + if isinstance(p.ground_truth, dict): + # Multi-target survival: use first target + val = list(p.ground_truth.values())[0] + t, e = (val or "nan nan").split(" ", 1) + else: + t, e = (p.ground_truth or "nan nan").split(" ", 1) times.append(float(t) if t.lower() != "nan" else np.nan) events.append(_parse_survival_status(e)) @@ -340,9 +429,9 @@ def load_patient_level_data( clini_table: Path, feature_dir: Path, patient_label: PandasLabel, - ground_truth_label: PandasLabel | None = None, # <- now optional - time_label: PandasLabel | None = None, # <- for survival - status_label: PandasLabel | 
None = None, # <- for survival + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None, + time_label: PandasLabel | None = None, + status_label: PandasLabel | None = None, feature_ext: str = ".h5", ) -> dict[PatientId, PatientData]: """ @@ -419,7 +508,7 @@ class BagDataset(Dataset[tuple[_Bag, _Coordinates, BagSize, _EncodedTarget]]): If `bag_size` is None, all the samples will be used. """ - ground_truths: Float[Tensor, "index category_is_hot"] | Float[Tensor, "index 1"] + ground_truths: Tensor | list[dict[str, Tensor]] # ground_truths: Bool[Tensor, "index category_is_hot"] # """The ground truth for each bag, one-hot encoded.""" @@ -529,7 +618,7 @@ def mpp(self) -> SlideMPP: def get_coords(feature_h5: h5py.File) -> CoordsInfo: - # --- NEW: handle missing coords ----multiplex data bypass: no coords found; generated fake coords + # NEW: handle missing coords - multiplex data bypass: no coords found; generated fake coords if "coords" not in feature_h5: feats_obj = feature_h5["patch_embeddings"] @@ -627,33 +716,71 @@ def patient_to_ground_truth_from_clini_table_( *, clini_table_path: Path | TextIO, patient_label: PandasLabel, - ground_truth_label: PandasLabel, -) -> dict[PatientId, GroundTruth]: - """Loads the patients and their ground truths from a clini table.""" + ground_truth_label: PandasLabel | Sequence[PandasLabel], +) -> ( + dict[PatientId, GroundTruth | None] | dict[PatientId, dict[str, GroundTruth | None]] +): + """Loads the patients and their ground truths from a clini table. + + `ground_truth_label` may be either a single column name (str) or a sequence + of column names. In the latter case the returned mapping will contain a + dict mapping column -> value for each patient (supporting multi-target + setups). 
+ """ + # Normalize to list for uniform handling + if isinstance(ground_truth_label, str): + cols = [patient_label, ground_truth_label] + multi = False + target_cols_inner: list[PandasLabel] = [] + else: + cols = [patient_label, *list(ground_truth_label)] + multi = True + target_cols_inner = [c for c in cols if c != patient_label] + clini_df = read_table( clini_table_path, - usecols=[patient_label, ground_truth_label], + usecols=cols, dtype=str, - ).dropna() + ) + + # If multi-target, keep rows where at least one target is present; for + # single target behave like before and drop rows missing the value. + if multi: + clini_df = clini_df.dropna(subset=target_cols_inner, how="all") + else: + clini_df = clini_df.dropna(subset=[ground_truth_label]) + try: - patient_to_ground_truth: Mapping[PatientId, GroundTruth] = clini_df.set_index( - patient_label, verify_integrity=True - )[ground_truth_label].to_dict() + if multi: + # Build mapping patient -> {col: value} + result: dict[PatientId, dict[str, GroundTruth | None]] = {} + for _, row in clini_df.iterrows(): + pid = row[patient_label] + # Convert pandas nan to None and keep strings otherwise + result[pid] = { + col: (None if pd.isna(row[col]) else str(row[col])) + for col in target_cols_inner + } + return result + else: + patient_to_ground_truth: Mapping[PatientId, str] = cast( + Mapping[PatientId, str], + clini_df.set_index(patient_label, verify_integrity=True)[ + cast(PandasLabel, ground_truth_label) + ].to_dict(), + ) + return cast(dict[PatientId, GroundTruth | None], patient_to_ground_truth) except KeyError as e: if patient_label not in clini_df: raise ValueError( f"{patient_label} was not found in clini table " f"(columns in clini table: {clini_df.columns})" ) from e - elif ground_truth_label not in clini_df: + else: raise ValueError( - f"{ground_truth_label} was not found in clini table " + f"One or more ground truth columns were not found in clini table " f"(columns in clini table: {clini_df.columns})" ) from e - 
else: - raise e from e - - return patient_to_ground_truth def patient_to_survival_from_clini_table_( @@ -667,7 +794,6 @@ def patient_to_survival_from_clini_table_( Loads patients and their survival ground truths (time + event) from a clini table. Returns - ------- dict[PatientId, GroundTruth] Mapping patient_id -> "time status" (e.g. "302 dead", "476 alive"). """ @@ -780,7 +906,9 @@ def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: def filter_complete_patient_data_( *, - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], + patient_to_ground_truth: Mapping[ + PatientId, GroundTruth | dict[str, GroundTruth] | None + ], slide_to_patient: Mapping[FeaturePath, PatientId], drop_patients_with_missing_ground_truth: bool, ) -> Mapping[PatientId, PatientData]: @@ -865,7 +993,7 @@ def _log_patient_slide_feature_inconsistencies( ) -def get_stride(coords: Float[Tensor, "tile 2"]) -> float: +def get_stride(coords: Tensor) -> float: """Gets the minimum step width between any two coordintes.""" xs: Tensor = coords[:, 0].unique(sorted=True) ys: Tensor = coords[:, 1].unique(sorted=True) @@ -927,26 +1055,130 @@ def _parse_survival_status(value) -> int | None: ) +def load_patient_data_( + *, + feature_dir: Path, + clini_table: Path, + slide_table: Path | None, + task: Task, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + time_label: PandasLabel | None, + status_label: PandasLabel | None, + patient_label: PandasLabel, + filename_label: PandasLabel, + drop_patients_with_missing_ground_truth: bool = True, +) -> tuple[Mapping[PatientId, PatientData], str]: + """Load patient data based on feature type (tile, slide, or patient). + + This consolidates the common data loading logic used across train, crossval, and deploy. 
+ + Returns: + (patient_to_data, feature_type) + """ + feature_type = detect_feature_type(feature_dir) + + if feature_type in ("tile", "slide"): + if slide_table is None: + raise ValueError("A slide table is required for tile/slide-level features") + + # Load ground truth based on task + if task == "survival": + if time_label is None or status_label is None: + raise ValueError( + "Both time_label and status_label are required for survival modeling" + ) + patient_to_ground_truth = patient_to_survival_from_clini_table_( + clini_table_path=clini_table, + time_label=time_label, + status_label=status_label, + patient_label=patient_label, + ) + else: + if ground_truth_label is None: + raise ValueError( + "Ground truth label is required for classification or regression modeling" + ) + patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( + clini_table_path=clini_table, + ground_truth_label=ground_truth_label, + patient_label=patient_label, + ) + + # Link slides to patients + slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( + slide_to_patient_from_slide_table_( + slide_table_path=slide_table, + feature_dir=feature_dir, + patient_label=patient_label, + filename_label=filename_label, + ) + ) + + # Filter to complete patient data + patient_to_data = filter_complete_patient_data_( + patient_to_ground_truth=cast( + Mapping[PatientId, GroundTruth | dict[str, GroundTruth] | None], + patient_to_ground_truth, + ), + slide_to_patient=slide_to_patient, + drop_patients_with_missing_ground_truth=drop_patients_with_missing_ground_truth, + ) + elif feature_type == "patient": + patient_to_data = load_patient_level_data( + task=task, + clini_table=clini_table, + feature_dir=feature_dir, + patient_label=patient_label, + ground_truth_label=ground_truth_label, + time_label=time_label, + status_label=status_label, + ) + else: + raise RuntimeError(f"Unknown feature type: {feature_type}") + + return patient_to_data, feature_type + + def log_patient_class_summary( *, 
patient_to_data: Mapping[PatientId, PatientData], categories: Sequence[Category] | None, - prefix: str = "", ) -> None: + """ + Logs class distribution. + Supports both single-target and multi-target classification. + """ + ground_truths = [ - pd.ground_truth - for pd in patient_to_data.values() - if pd.ground_truth is not None + p.ground_truth for p in patient_to_data.values() if p.ground_truth is not None ] if not ground_truths: - _logger.warning(f"{prefix}No ground truths available to summarize.") + _logger.warning("No ground truths available for summary.") return - cats = categories or sorted(set(ground_truths)) - counter = Counter(ground_truths) + # Multi-target + if isinstance(ground_truths[0], dict): + # Collect per-target values + per_target: dict[str, list] = {} - _logger.info( - f"{prefix}Total patients: {len(ground_truths)} | " - + " | ".join([f"Class {c}: {counter.get(c, 0)}" for c in cats]) - ) + for gt in ground_truths: + for key, value in gt.items(): + per_target.setdefault(key, []).append(value) + + for target_name, values in per_target.items(): + counts = {} + for v in values: + counts[v] = counts.get(v, 0) + 1 + + _logger.info( + f"[Multi-target] Target '{target_name}' distribution: {counts}" + ) + + # Single-target + else: + counts = {} + for gt in ground_truths: + counts[gt] = counts.get(gt, 0) + 1 + + _logger.info(f"Class distribution: {counts}") diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 905c6005..10272328 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd import torch -from jaxtyping import Float +import torch.nn.functional as F from lightning.pytorch.accelerators.accelerator import Accelerator from stamp.modeling.data import ( @@ -20,7 +20,7 @@ slide_to_patient_from_slide_table_, ) from stamp.modeling.registry import ModelName, load_model_class -from stamp.types import GroundTruth, PandasLabel, PatientId +from stamp.types 
import Category, GroundTruth, PandasLabel, PatientId __all__ = ["deploy_categorical_model_"] @@ -32,6 +32,13 @@ Logit: TypeAlias = float +# Prediction type aliases +PredictionSingle: TypeAlias = torch.Tensor +PredictionMulti: TypeAlias = dict[str, torch.Tensor] +PredictionsType: TypeAlias = Mapping[ + PatientId, Union[PredictionSingle, PredictionMulti] +] + def load_model_from_ckpt(path: Union[str, Path]): ckpt = torch.load(path, map_location="cpu", weights_only=False) @@ -50,7 +57,7 @@ def deploy_categorical_model_( clini_table: Path | None, slide_table: Path | None, feature_dir: Path, - ground_truth_label: PandasLabel | None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, time_label: PandasLabel | None, status_label: PandasLabel | None, patient_label: PandasLabel, @@ -126,7 +133,12 @@ def deploy_categorical_model_( # classification/regression: still use ground_truth_label if ( len( - ground_truth_labels := set(model.ground_truth_label for model in models) + ground_truth_labels := { + tuple(model.ground_truth_label) + if isinstance(model.ground_truth_label, list) + else (model.ground_truth_label,) + for model in models + } ) != 1 ): @@ -145,17 +157,21 @@ def deploy_categorical_model_( f"{ground_truth_label} vs {model_ground_truth_label}" ) - ground_truth_label = ground_truth_label or model_ground_truth_label + ground_truth_label = ground_truth_label or cast( + PandasLabel, model_ground_truth_label + ) output_dir.mkdir(exist_ok=True, parents=True) model_categories = None if task == "classification": # Ensure the categories were the same between all models - category_sets = {tuple(m.categories) for m in models} + category_sets = { + tuple(cast(Sequence[GroundTruth], m.categories)) for m in models + } if len(category_sets) != 1: raise RuntimeError(f"Categories differ between models: {category_sets}") - model_categories = list(models[0].categories) + model_categories = list(cast(Sequence[GroundTruth], models[0].categories)) # Data loading logic if 
feature_type in ("tile", "slide"): @@ -171,6 +187,15 @@ def deploy_categorical_model_( ) if clini_table is not None: if task == "survival": + if not hasattr(models[0], "time_label") or not isinstance( + models[0].time_label, str + ): + raise AttributeError("Model is missing valid 'time_label' (str).") + if not hasattr(models[0], "status_label") or not isinstance( + models[0].status_label, str + ): + raise AttributeError("Model is missing valid 'status_label' (str).") + patient_to_ground_truth = patient_to_survival_from_clini_table_( clini_table_path=clini_table, patient_label=patient_label, @@ -192,7 +217,10 @@ def deploy_categorical_model_( patient_id: None for patient_id in set(slide_to_patient.values()) } patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, + patient_to_ground_truth=cast( + Mapping[PatientId, GroundTruth | None], + patient_to_ground_truth, + ), slide_to_patient=slide_to_patient, drop_patients_with_missing_ground_truth=False, ) @@ -241,7 +269,10 @@ def deploy_categorical_model_( "regression": _to_regression_prediction_df, "survival": _to_survival_prediction_df, }[task] - all_predictions: list[Mapping[PatientId, Float[torch.Tensor, "category"]]] = [] # noqa: F821 + all_predictions: list[PredictionsType] = [] + categories_for_export: ( + Sequence[Category] | Mapping[str, Sequence[Category]] | None + ) = cast(Sequence[Category] | Mapping[str, Sequence[Category]] | None, None) for model_i, model in enumerate(models): predictions = _predict( model=model, @@ -251,6 +282,26 @@ def deploy_categorical_model_( ) all_predictions.append(predictions) + if isinstance(next(iter(predictions.values())), dict): + # Multi-target case: gather categories across all targets for export (use model categories if available, else infer from GT) + categories_accum: dict[str, set[GroundTruth]] = {} + + for pd_item in patient_to_data.values(): + gt = pd_item.ground_truth + if isinstance(gt, dict): + for k, v in gt.items(): + if v 
is not None: + categories_accum.setdefault(k, set()).add(v) + + categories_for_export = {k: sorted(v) for k, v in categories_accum.items()} + + else: + # Single-target case: use categories from model if available, else infer from GT + if task == "classification": + categories_for_export = models[0].categories + else: + categories_for_export = [] + # cut-off values from survival ckpt cut_off = ( getattr(model.hparams, "train_pred_median", None) @@ -261,7 +312,7 @@ def deploy_categorical_model_( # Only save individual model files when deploying multiple models (ensemble) if len(models) > 1: df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, patient_label=patient_label, @@ -270,7 +321,7 @@ def deploy_categorical_model_( ).to_csv(output_dir / f"patient-preds-{model_i}.csv", index=False) else: df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, patient_label=patient_label, @@ -279,17 +330,29 @@ def deploy_categorical_model_( ).to_csv(output_dir / "patient-preds.csv", index=False) if task == "classification": - # TODO we probably also want to save the 95% confidence interval in addition to the mean + # compute mean prediction across models (supports single- and multi-target) + mean_preds: dict[PatientId, object] = {} + for pid in patient_ids: + model_preds = cast( + list[torch.Tensor], [preds[pid] for preds in all_predictions] + ) + firstp = model_preds[0] + if isinstance(firstp, dict): + # per-target averaging + mean_preds[pid] = { + t: torch.stack([p[t] for p in model_preds]).mean(dim=0) + for t in firstp.keys() + } + else: + mean_preds[pid] = torch.stack(model_preds).mean(dim=0) + + assert categories_for_export is not None, ( + "categories_for_export must be set before use" + ) df_builder( - categories=model_categories, + categories=categories_for_export, 
patient_to_ground_truth=patient_to_ground_truth, - predictions={ - # Mean prediction - patient_id: torch.stack( - [predictions[patient_id] for predictions in all_predictions] - ).mean(dim=0) - for patient_id in patient_ids - }, + predictions=mean_preds, patient_label=patient_label, ground_truth_label=ground_truth_label, ).to_csv(output_dir / "patient-preds_95_confidence_interval.csv", index=False) @@ -301,7 +364,7 @@ def _predict( test_dl: torch.utils.data.DataLoader, patient_ids: Sequence[PatientId], accelerator: str | Accelerator, -) -> Mapping[PatientId, Float[torch.Tensor, "..."]]: +) -> PredictionsType: model = model.eval() torch.set_float32_matmul_precision("medium") @@ -310,8 +373,11 @@ def _predict( getattr(model, "train_patients", []) ) | set(getattr(model, "valid_patients", [])) if overlap := patients_used_for_training & set(patient_ids): - raise ValueError( - f"some of the patients in the validation set were used during training: {overlap}" + _logger.critical( + "DATA LEAKAGE DETECTED: %d patient(s) in deployment set were used " + "during training/validation. 
Overlapping IDs: %s", + len(overlap), + sorted(overlap), ) trainer = lightning.Trainer( @@ -320,51 +386,190 @@ def _predict( logger=False, ) - raw_preds = torch.concat(cast(list[torch.Tensor], trainer.predict(model, test_dl))) + outs = trainer.predict(model, test_dl) + + if not outs: + return {} + + first = outs[0] + + # Multi-target case: each element of outs is a dict[target_label -> tensor] + if isinstance(first, dict): + per_target_lists: dict[str, list[torch.Tensor]] = {} + for out in outs: + if not isinstance(out, dict): + raise RuntimeError("Mixed prediction output types from model") + for k, v in out.items(): + per_target_lists.setdefault(k, []).append(v) + + per_target_tensors: dict[str, torch.Tensor] = { + k: torch.cat(vlist, dim=0) for k, vlist in per_target_lists.items() + } + + if getattr(model.hparams, "task", None) == "classification": + for k in list(per_target_tensors.keys()): + per_target_tensors[k] = torch.softmax(per_target_tensors[k], dim=1) + + # build per-patient dicts + num_preds = next(iter(per_target_tensors.values())).shape[0] + predictions: dict[PatientId, dict[str, torch.Tensor]] = {} + for i, pid in enumerate(patient_ids[:num_preds]): + predictions[pid] = { + k: per_target_tensors[k][i] for k in per_target_tensors.keys() + } + + return predictions + + # Single-target case: each element of outs is a tensor + outs_single = cast(list[torch.Tensor], outs) + + raw_preds = torch.cat(outs_single, dim=0) if getattr(model.hparams, "task", None) == "classification": - predictions = torch.softmax(raw_preds, dim=1) + raw_preds = torch.softmax(raw_preds, dim=1) elif getattr(model.hparams, "task", None) == "survival": - predictions = raw_preds.squeeze(-1) # (N,) risk scores - else: # regression - predictions = raw_preds + raw_preds = raw_preds.squeeze(-1) + + result: dict[PatientId, torch.Tensor] = { + pid: raw_preds[i] for i, pid in enumerate(patient_ids) + } - return dict(zip(patient_ids, predictions, strict=True)) + return result def 
_to_prediction_df( *, - categories: Sequence[GroundTruth], - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], - predictions: Mapping[PatientId, torch.Tensor], + categories: Sequence[GroundTruth] | Mapping[str, Sequence[GroundTruth]], + patient_to_ground_truth: Mapping[PatientId, GroundTruth | None] + | Mapping[PatientId, dict[str, GroundTruth | None]], + predictions: Mapping[PatientId, torch.Tensor] + | Mapping[PatientId, dict[str, torch.Tensor]], patient_label: PandasLabel, - ground_truth_label: PandasLabel, + ground_truth_label: PandasLabel | Sequence[PandasLabel], **kwargs, ) -> pd.DataFrame: - """Compiles deployment results into a DataFrame.""" - return pd.DataFrame( - [ - { - patient_label: patient_id, - ground_truth_label: patient_to_ground_truth.get(patient_id), - "pred": categories[int(prediction.argmax())], - **{ - f"{ground_truth_label}_{category}": prediction[i_cat].item() - for i_cat, category in enumerate(categories) - }, - "loss": ( - torch.nn.functional.cross_entropy( - prediction.reshape(1, -1), - torch.tensor(np.where(np.array(categories) == ground_truth)[0]), - ).item() - if (ground_truth := patient_to_ground_truth.get(patient_id)) - is not None - else None - ), - } - for patient_id, prediction in predictions.items() - ] - ).sort_values(by="loss") + """Compiles deployment results into a DataFrame. + + Supports single-target and multi-target classification. + - Single-target: `predictions` maps patient -> tensor and `categories` is a sequence. + - Multi-target: `predictions` maps patient -> dict[target_label -> tensor] and + `categories` is a mapping from target_label -> sequence of category names. 
+ """ + first_pred = next(iter(predictions.values())) + + # Multi-target predictions: dict per patient + if isinstance(first_pred, dict): + # determine target labels + target_labels = list(cast(dict, first_pred).keys()) + + # prepare categories mapping + if isinstance(categories, dict): + cats_map = categories + else: + # try infer categories list ordering: assume categories is a sequence-of-sequences + cats_map = {} + if isinstance(categories, Sequence): + try: + for i, t in enumerate(target_labels): + cats_map[t] = list( + cast(Sequence[Sequence[GroundTruth]], categories)[i] + ) + except Exception: + cats_map = {} + + # infer missing category lists from ground truth + if any(t not in cats_map for t in target_labels): + inferred: dict[str, set] = {t: set() for t in target_labels} + for pid, gt in patient_to_ground_truth.items(): + if isinstance(gt, dict): + for t in target_labels: + val = gt.get(t) + if val is not None: + inferred[t].add(val) + for t in target_labels: + if t not in cats_map: + cats_map[t] = sorted(inferred.get(t, [])) + + rows = [] + for pid, pred_dict in predictions.items(): + row: dict = {patient_label: pid} + gt_entry = patient_to_ground_truth.get(pid) + # ground truths per target + for t in target_labels: + if isinstance(gt_entry, dict): + row[t] = gt_entry.get(t) + else: + row[t] = gt_entry + + total_loss = 0.0 + has_loss = False + for t in target_labels: + tensor = cast(dict[str, torch.Tensor], pred_dict)[t] + probs = tensor.detach().cpu() + cats: Sequence[GroundTruth] = cast( + Sequence[GroundTruth], + cats_map.get(t, []), + ) + if probs.numel() == 1: + row[f"pred_{t}"] = float(probs.item()) + else: + pred_idx = int(probs.argmax().item()) + row[f"pred_{t}"] = ( + cats[pred_idx] if pred_idx < len(cats) else pred_idx + ) + for i_cat, cat in enumerate(cats): + if i_cat < probs.shape[0]: + row[f"{t}_{cat}"] = float(probs[i_cat].item()) + else: + row[f"{t}_{cat}"] = None + + if isinstance(gt_entry, dict) and (gt := gt_entry.get(t)) is not None: 
+ try: + target_index = int(np.where(np.array(cats) == gt)[0][0]) + loss = torch.nn.functional.cross_entropy( + probs.reshape(1, -1), torch.tensor([target_index]) + ).item() + total_loss += loss + has_loss = True + except Exception: + pass + + row["loss"] = total_loss if has_loss else None + rows.append(row) + + return pd.DataFrame(rows) + + # Single-target (original behaviour) + if not all(isinstance(p, torch.Tensor) for p in predictions.values()): + raise TypeError("Single-target block received multi-target dict predictions.") + + predictions = cast(Mapping[PatientId, torch.Tensor], predictions) + + rows = [] + for pid, prediction in predictions.items(): + gt = patient_to_ground_truth.get(pid) + cats = cast(Sequence[GroundTruth], categories) + pred_idx = int(prediction.argmax()) + row = { + patient_label: pid, + ground_truth_label: gt, + "pred": cats[pred_idx], + **{ + f"{ground_truth_label}_{category}": float(prediction[i_cat].item()) + for i_cat, category in enumerate(cats) + }, + "loss": ( + torch.nn.functional.cross_entropy( + prediction.reshape(1, -1), + torch.tensor(np.where(np.array(cats) == gt)[0]), + ).item() + if gt is not None + else None + ), + } + rows.append(row) + + return pd.DataFrame(rows).sort_values(by="loss") def _to_regression_prediction_df( @@ -383,8 +588,6 @@ def _to_regression_prediction_df( - pred (float) - loss (per-sample L1 loss if GT available, else None) """ - import torch.nn.functional as F - return pd.DataFrame( [ { diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index 59a0a3aa..86003c52 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -3,7 +3,7 @@ import inspect from abc import ABC from collections.abc import Iterable, Sequence -from typing import Any, TypeAlias +from typing import Any, Mapping, TypeAlias import lightning import numpy as np @@ -14,6 +14,11 @@ from torchmetrics.classification import MulticlassAUROC import stamp +from 
stamp.modeling.models.barspoon import ( + EncDecTransformer, + LitMilClassificationMixin, + TargetLabel, +) from stamp.modeling.models.cox import neg_partial_log_likelihood from stamp.types import ( Bags, @@ -818,3 +823,75 @@ class LitPatientSurvival(LitSlideSurvival): """ supported_features = ["patient"] + + +class LitEncDecTransformer(LitMilClassificationMixin): + def __init__( + self, + *, + dim_input: int, + category_weights: Mapping[TargetLabel, torch.Tensor], + model_class: type[nn.Module] | None = None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + categories: Mapping[str, Sequence[Category]], + # Model parameters + d_model: int = 512, + num_encoder_heads: int = 8, + num_decoder_heads: int = 8, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + dim_feedforward: int = 2048, + positional_encoding: bool = True, + # Other hparams + learning_rate: float = 1e-4, + **hparams: Any, + ) -> None: + weights_dict: dict[TargetLabel, torch.Tensor] = dict(category_weights) + super().__init__( + weights=weights_dict, + learning_rate=learning_rate, + ) + _ = hparams # so we don't get unused parameter warnings + + self.model = EncDecTransformer( + d_features=dim_input, + target_n_outs={t: len(w) for t, w in category_weights.items()}, + d_model=d_model, + num_encoder_heads=num_encoder_heads, + num_decoder_heads=num_decoder_heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dim_feedforward=dim_feedforward, + positional_encoding=positional_encoding, + ) + + self.hparams["supported_features"] = "tile" + self.hparams.update({"task": "classification"}) + # ---- Normalize categories into strict mapping[str, list[str]] ---- + if not isinstance(categories, Mapping): + raise ValueError( + "Multi-target classification requires categories as Mapping[str, Sequence[str]]." 
+ ) + + normalized_categories: dict[str, list[str]] = { + str(k): list(v) for k, v in categories.items() + } + + # Sanity check: head size must match category size + for t, w in category_weights.items(): + if t not in normalized_categories: + raise ValueError(f"Missing categories for target '{t}'") + if len(normalized_categories[t]) != len(w): + raise ValueError( + f"Category mismatch for target '{t}': " + f"{len(normalized_categories[t])} categories " + f"but head has {len(w)} outputs." + ) + + self.ground_truth_label = ground_truth_label + self.categories = normalized_categories + + self.save_hyperparameters() + + def forward(self, *args): + return self.model(*args) diff --git a/src/stamp/modeling/models/barspoon.py b/src/stamp/modeling/models/barspoon.py new file mode 100644 index 00000000..f841bb3d --- /dev/null +++ b/src/stamp/modeling/models/barspoon.py @@ -0,0 +1,367 @@ +""" +Port from https://github.com/KatherLab/barspoon-transformer +""" + +import re +from typing import Any, TypeAlias + +import lightning +import torch +import torch.nn.functional as F +import torchmetrics +from packaging.version import Version +from torch import nn +from torchmetrics.classification import MulticlassAUROC +from torchmetrics.utilities.data import dim_zero_cat + +import stamp +from stamp.types import Bags, BagSizes, CoordinatesBatch + +__all__ = [ + "EncDecTransformer", + "LitMilClassificationMixin", + "SafeMulticlassAUROC", +] + + +TargetLabel: TypeAlias = str + + +class EncDecTransformer(nn.Module): + """An encoder decoder architecture for multilabel classification tasks + + This architecture is a modified version of the one found in [Attention Is + All You Need][1]: First, we project the features into a lower-dimensional + feature space, to prevent the transformer architecture's complexity from + exploding for high-dimensional features. We add sinusodial [positional + encodings][1]. We then encode these projected input tokens using a + transformer encoder stack. 
Next, we decode these tokens using a set of + class tokens, one per output label. Finally, we forward each of the decoded + tokens through a fully connected layer to get a label-wise prediction. + + PE1 + | + +--+ v +---+ + t1 ->|FC|--+-->| |--+ + . +--+ | E | | + . | x | | + . +--+ | m | | + tn ->|FC|--+-->| |--+ + +--+ ^ +---+ | + | | + PEn v + +---+ +---+ + c1 ---------------->| |-->|FC1|--> s1 + . | D | +---+ . + . | x | . + . | l | +---+ . + ck ---------------->| |-->|FCk|--> sk + +---+ +---+ + + We opted for this architecture instead of a more traditional [Vision + Transformer][2] to improve performance for multi-label predictions with many + labels. Our experiments have shown that adding too many class tokens to a + vision transformer decreases its performance, as the same weights have to + both process the tiles' information and the class token's processing. Using + an encoder-decoder architecture alleviates these issues, as the data-flow of + the class tokens is completely independent of the encoding of the tiles. + Furthermore, analysis has shown that there is almost no interaction between + the different classes in the decoder. While this points to the decoder + being more powerful than needed in practice, this also means that each + label's prediction is mostly independent of the others. As a consequence, + noisy labels will not negatively impact the accuracy of non-noisy ones. + + In our experiments so far we did not see any improvement by adding + positional encodings. We tried + + 1. [Sinusodal encodings][1] + 2. Adding absolute positions to the feature vector, scaled down so the + maximum value in the training dataset is 1. + + Since neither reduced performance and the author percieves the first one to + be more elegant (as the magnitude of the positional encodings is bounded), + we opted to keep the positional encoding regardless in the hopes of it + improving performance on future tasks. 
+ + The architecture _differs_ from the one descibed in [Attention Is All You + Need][1] as follows: + + 1. There is an initial projection stage to reduce the dimension of the + feature vectors and allow us to use the transformer with arbitrary + features. + 2. Instead of the language translation task described in [Attention Is All + You Need][1], where the tokens of the words translated so far are used + to predict the next word in the sequence, we use a set of fixed, learned + class tokens in conjunction with equally as many independent fully + connected layers to predict multiple labels at once. + + [1]: https://arxiv.org/abs/1706.03762 "Attention Is All You Need" + [2]: https://arxiv.org/abs/2010.11929 + "An Image is Worth 16x16 Words: + Transformers for Image Recognition at Scale" + """ + + def __init__( + self, + d_features: int, + target_n_outs: dict[str, int], + *, + d_model: int = 512, + num_encoder_heads: int = 8, + num_decoder_heads: int = 8, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + dim_feedforward: int = 2048, + positional_encoding: bool = True, + ) -> None: + super().__init__() + + self.projector = nn.Sequential(nn.Linear(d_features, d_model), nn.ReLU()) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=num_encoder_heads, + dim_feedforward=dim_feedforward, + batch_first=True, + norm_first=True, + ) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layer, num_layers=num_encoder_layers, enable_nested_tensor=False + ) + + self.target_labels = target_n_outs.keys() + + # One class token per output label + self.class_tokens = nn.ParameterDict( + { + sanitize(target_label): torch.rand(d_model) + for target_label in target_n_outs + } + ) + + decoder_layer = nn.TransformerDecoderLayer( + d_model=d_model, + nhead=num_decoder_heads, + dim_feedforward=dim_feedforward, + batch_first=True, + norm_first=True, + ) + self.transformer_decoder = nn.TransformerDecoder( + decoder_layer, 
num_layers=num_decoder_layers + ) + + self.heads = nn.ModuleDict( + { + sanitize(target_label): nn.Linear( + in_features=d_model, out_features=n_out + ) + for target_label, n_out in target_n_outs.items() + } + ) + + self.positional_encoding = positional_encoding + + def forward( + self, + tile_tokens: torch.Tensor, + tile_positions: torch.Tensor, + ) -> dict[str, torch.Tensor]: + batch_size, _, _ = tile_tokens.shape + + tile_tokens = self.projector(tile_tokens) # shape: [bs, seq_len, d_model] + + if self.positional_encoding: + # Add positional encodings + d_model = tile_tokens.size(-1) + x = tile_positions.unsqueeze(-1) / 100_000 ** ( + torch.arange(d_model // 4).type_as(tile_positions) / d_model + ) + positional_encodings = torch.cat( + [ + torch.sin(x).flatten(start_dim=-2), + torch.cos(x).flatten(start_dim=-2), + ], + dim=-1, + ) + tile_tokens = tile_tokens + positional_encodings + + tile_tokens = self.transformer_encoder(tile_tokens) + + class_tokens = torch.stack( + [self.class_tokens[sanitize(t)] for t in self.target_labels] + ).expand(batch_size, -1, -1) + class_tokens = self.transformer_decoder(tgt=class_tokens, memory=tile_tokens) + + # Apply the corresponding head to each class token + logits = { + target_label: self.heads[sanitize(target_label)](class_token) + for target_label, class_token in zip( + self.target_labels, + class_tokens.permute(1, 0, 2), # Permute to [target, batch, d_model] + strict=True, + ) + } + + return logits + + +class LitMilClassificationMixin(lightning.LightningModule): + """Makes a module into a multilabel, multiclass Lightning one""" + + supported_features = ["tile"] + + def __init__( + self, + *, + weights: dict[TargetLabel, torch.Tensor], + # Other hparams + learning_rate: float = 1e-4, + stamp_version: Version = Version(stamp.__version__), + **hparams: Any, + ) -> None: + super().__init__() + _ = hparams # So we don't get unused parameter warnings + + # Check if version is compatible. 
class LitMilClassificationMixin(lightning.LightningModule):
    """Makes a module into a multilabel, multiclass Lightning one.

    Subclasses provide `forward`; this mixin contributes the weighted
    cross-entropy loss summed over all targets, per-target AUROC tracking
    for train/validation/test, and the optimizer configuration.
    """

    supported_features = ["tile"]

    def __init__(
        self,
        *,
        weights: dict[TargetLabel, torch.Tensor],
        # Other hparams
        learning_rate: float = 1e-4,
        stamp_version: Version = Version(stamp.__version__),
        **hparams: Any,
    ) -> None:
        """Set up losses, metrics and hyperparameters.

        Args:
            weights: Per-target class-weight vectors, one entry per class;
                the keys also define the set of targets trained and logged.
            learning_rate: Adam learning rate.
            stamp_version: stamp version the checkpoint was created with.
            **hparams: Extra hyperparameters, persisted but otherwise unused.

        Raises:
            ValueError: If the checkpoint's stamp version is incompatible
                with the installed one.
        """
        super().__init__()
        _ = hparams  # So we don't get unused parameter warnings

        # Check if version is compatible.
        if stamp_version < Version("2.4.0"):
            # Update this as we change our model in incompatible ways!
            raise ValueError(
                f"model has been built with stamp version {stamp_version} "
                f"which is incompatible with the current version."
            )
        elif stamp_version > Version(stamp.__version__):
            # Let's be strict with models "from the future",
            # better fail deadly than have broken results.
            raise ValueError(
                "model has been built with a stamp version newer than the installed one "
                f"({stamp_version} > {stamp.__version__}). "
                "Please upgrade stamp to a compatible version."
            )

        self.hparams.update({"task": "classification"})

        self.learning_rate = learning_rate

        # One AUROC metric per target, cloned per step type so that train /
        # validation / test statistics do not mix.
        target_aurocs = torchmetrics.MetricCollection(
            {
                sanitize(target_label): SafeMulticlassAUROC(num_classes=len(weight))
                for target_label, weight in weights.items()
            }
        )
        for step_name in ["train", "validation", "test"]:
            setattr(
                self,
                f"{step_name}_target_aurocs",
                target_aurocs.clone(prefix=f"{step_name}_"),
            )

        self.weights = weights

        self.save_hyperparameters()

    def step(
        self,
        batch: tuple[Bags, CoordinatesBatch, BagSizes, dict[str, torch.Tensor]],
        step_name=None,
    ):
        """Process a batch with structure (feats, coords, bag_sizes, targets).

        Args:
            batch: Tuple of (feats, coords, bag_sizes, targets) where:
                - feats: bag features [batch, bag_size, feature_dim]
                - coords: tile coordinates [batch, bag_size, 2]
                - bag_sizes: number of tiles per bag [batch]
                - targets: dict mapping target names to one-hot encoded
                  tensors [batch, num_classes]
            step_name: Optional step name for logging ('train', 'validation',
                'test').  When None, only the loss is computed; no logging or
                metric updates happen.
        """
        feats: Bags
        coords: CoordinatesBatch
        bag_sizes: BagSizes
        targets: dict[str, torch.Tensor]
        feats, coords, bag_sizes, targets = batch
        logits = self(feats, coords)

        # Calculate the cross entropy loss for each target, then sum them
        loss = sum(
            F.cross_entropy(
                (logit := logits[target_label]),
                targets[target_label].type_as(logit),
                weight=weight.type_as(logit),
            )
            for target_label, weight in self.weights.items()
        )

        if step_name:
            self.log(
                f"{step_name}_loss",
                loss,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                sync_dist=True,
            )

            # Update target-wise metrics.  BUGFIX: this block previously ran
            # even when step_name was None (the default), crashing on the
            # lookup of "None_target_aurocs".
            for target_label in self.weights:
                target_auroc = getattr(self, f"{step_name}_target_aurocs")[
                    sanitize(target_label)
                ]
                # All-zero one-hot rows mark missing ground truth; exclude
                # them from the metric update.
                is_na = (targets[target_label] == 0).all(dim=1)
                target_auroc.update(
                    logits[target_label][~is_na],
                    targets[target_label][~is_na].argmax(dim=1),
                )
                self.log(
                    f"{step_name}_{target_label}_auroc",
                    target_auroc,
                    on_step=False,
                    on_epoch=True,
                    sync_dist=True,
                )

        return loss

    def training_step(self, batch, batch_idx):  # pyright: ignore[reportIncompatibleMethodOverride]
        return self.step(batch, step_name="train")

    def validation_step(self, batch, batch_idx):  # pyright: ignore[reportIncompatibleMethodOverride]
        return self.step(batch, step_name="validation")

    def test_step(self, batch, batch_idx):  # pyright: ignore[reportIncompatibleMethodOverride]
        return self.step(batch, step_name="test")

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # Deployment dataloaders may yield (feats, positions) only; training
        # style batches additionally carry bag sizes and targets, which are
        # not needed for prediction.
        if len(batch) == 2:
            feats, positions = batch
        else:
            feats, positions, _, _ = batch

        logits = self(feats, positions)

        # Per-target softmax scores.
        softmaxed = {
            target_label: torch.softmax(x, 1) for target_label, x in logits.items()
        }
        return softmaxed

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


def sanitize(x: str) -> str:
    """Replace every non-identifier character with "_" (for metric/module keys)."""
    return re.sub(r"[^A-Za-z0-9_]", "_", x)
class SafeMulticlassAUROC(MulticlassAUROC):
    """A Multiclass AUROC that doesn't blow up when no targets are given."""

    def compute(self) -> torch.Tensor:
        """Compute AUROC, inserting a neutral faux sample if no real updates
        were recorded so that compute() returns a value instead of raising."""
        if len(self.preds) == 0:
            # update() was never called.  BUGFIX: build the faux entry on the
            # metric's own device so GPU runs don't hit a device mismatch.
            self.update(
                torch.zeros(1, self.num_classes, device=self.device),
                torch.zeros(1, dtype=torch.long, device=self.device),
            )
        elif len(dim_zero_cat(self.preds)) == 0:
            # update() was only ever called with empty batches; match the
            # dtype/device of the recorded (empty) state.
            self.update(
                torch.zeros(1, self.num_classes).type_as(self.preds[0]),
                torch.zeros(1).long().type_as(self.target[0]),
            )
        return super().compute()
import Any, cast import lightning import torch @@ -18,13 +18,7 @@ PatientData, PatientFeatureDataset, create_dataloader, - detect_feature_type, - filter_complete_patient_data_, - load_patient_level_data, - log_patient_class_summary, - patient_to_ground_truth_from_clini_table_, - patient_to_survival_from_clini_table_, - slide_to_patient_from_slide_table_, + load_patient_data_, ) from stamp.modeling.registry import ModelName, load_model_class from stamp.modeling.transforms import VaryPrecisionTransform @@ -53,66 +47,25 @@ def train_categorical_model_( advanced: AdvancedConfig, ) -> None: """Trains a model based on the feature type.""" - feature_type = detect_feature_type(config.feature_dir) - _logger.info(f"Detected feature type: {feature_type}") - - if feature_type in ("tile", "slide"): - if config.slide_table is None: - raise ValueError("A slide table is required for modeling") - if config.task == "survival": - if config.time_label is None or config.status_label is None: - raise ValueError( - "Both time_label and status_label is required for survival modeling" - ) - patient_to_ground_truth = patient_to_survival_from_clini_table_( - clini_table_path=config.clini_table, - time_label=config.time_label, - status_label=config.status_label, - patient_label=config.patient_label, - ) - else: - if config.ground_truth_label is None: - raise ValueError( - "Ground truth label is required for tile-level modeling" - ) - patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( - clini_table_path=config.clini_table, - ground_truth_label=config.ground_truth_label, - patient_label=config.patient_label, - ) - slide_to_patient = slide_to_patient_from_slide_table_( - slide_table_path=config.slide_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - filename_label=config.filename_label, - ) - patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - 
drop_patients_with_missing_ground_truth=True, - ) - elif feature_type == "patient": - # Patient-level: ignore slide_table - if config.slide_table is not None: - _logger.warning("slide_table is ignored for patient-level features.") - - patient_to_data = load_patient_level_data( - task=config.task, - clini_table=config.clini_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - ) - else: - raise RuntimeError(f"Unknown feature type: {feature_type}") - if config.task is None: raise ValueError( "task must be set to 'classification' | 'regression' | 'survival'" ) + patient_to_data, feature_type = load_patient_data_( + feature_dir=config.feature_dir, + clini_table=config.clini_table, + slide_table=config.slide_table, + task=config.task, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + patient_label=config.patient_label, + filename_label=config.filename_label, + drop_patients_with_missing_ground_truth=True, + ) + _logger.info(f"Detected feature type: {feature_type}") + # Train the model (the rest of the logic is unchanged) model, train_dl, valid_dl = setup_model_for_training( patient_to_data=patient_to_data, @@ -145,14 +98,14 @@ def train_categorical_model_( def setup_model_for_training( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], task: Task, categories: Sequence[Category] | None, train_transform: Callable[[torch.Tensor], torch.Tensor] | None, feature_type: str, advanced: AdvancedConfig, # Metadata, has no effect on model training - ground_truth_label: PandasLabel | None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, time_label: PandasLabel | None, status_label: PandasLabel | None, clini_table: Path, @@ -193,10 +146,6 @@ def 
setup_model_for_training( feature_type=feature_type, train_categories=train_categories, ) - log_patient_class_summary( - patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, - categories=categories, - ) # 1. Default to a model if none is specified if advanced.model_name is None: @@ -209,6 +158,7 @@ def setup_model_for_training( LitModelClass, ModelClass = load_model_class( task, feature_type, advanced.model_name ) + print(f"Using Lightning wrapper class: {LitModelClass}") # 3. Validate that the chosen model supports the feature type if feature_type not in LitModelClass.supported_features: @@ -272,7 +222,7 @@ def setup_model_for_training( def setup_dataloaders_for_training( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], task: Task, categories: Sequence[Category] | None, bag_size: int, @@ -283,7 +233,7 @@ def setup_dataloaders_for_training( ) -> tuple[ DataLoader, DataLoader, - Sequence[Category], + Sequence[Category] | Mapping[str, Sequence[Category]], int, Sequence[PatientId], Sequence[PatientId], @@ -310,10 +260,25 @@ def setup_dataloaders_for_training( ) if task == "classification": - stratify = ground_truths + # Handle both single and multi-target cases + if ground_truths and isinstance(ground_truths[0], dict): + # Multi-target: use first target for stratification + first_key = list(ground_truths[0].keys())[0] + stratify = [cast(dict, gt)[first_key] for gt in ground_truths] + else: + stratify = ground_truths elif task == "survival": - # Extract event indicator (status) - statuses = [int(gt.split()[1]) for gt in ground_truths] + # Extract event indicator (status) - handle both single and multi-target + statuses = [] + for gt in ground_truths: + if isinstance(gt, dict): + # Multi-target survival: extract from first target + first_key = list(gt.keys())[0] + val = cast(dict, gt)[first_key] + if val: + statuses.append(int(val.split()[1])) + else: + 
statuses.append(int(gt.split()[1])) stratify = statuses elif task == "regression": stratify = None @@ -321,7 +286,10 @@ def setup_dataloaders_for_training( train_patients, valid_patients = cast( tuple[Sequence[PatientId], Sequence[PatientId]], train_test_split( - list(patient_to_data), stratify=stratify, shuffle=True, random_state=0 + list(patient_to_data), + stratify=cast(Any, stratify), + shuffle=True, + random_state=0, ), ) @@ -441,28 +409,50 @@ def _compute_class_weights_and_check_categories( *, train_dl: DataLoader, feature_type: str, - train_categories: Sequence[str], -) -> torch.Tensor: + train_categories: Sequence[str] | Mapping[str, Sequence[str]], +) -> torch.Tensor | dict[str, torch.Tensor]: """ Computes class weights and checks for category issues. Logs warnings if there are too few or underpopulated categories. Returns normalized category weights as a torch.Tensor. """ if feature_type == "tile": - category_counts = cast(BagDataset, train_dl.dataset).ground_truths.sum(dim=0) + dataset = cast(BagDataset, train_dl.dataset) + + if isinstance(dataset.ground_truths, list): + # Multi-target case: compute weights per target head + weights_per_target: dict[str, torch.Tensor] = {} + + target_keys = dataset.ground_truths[0].keys() + + for key in target_keys: + stacked = torch.stack([gt[key] for gt in dataset.ground_truths], dim=0) + counts = stacked.sum(dim=0) + w = counts.sum() / counts + weights_per_target[key] = w / w.sum() + + return weights_per_target + else: + category_counts = dataset.ground_truths.sum(dim=0) else: - category_counts = cast( - PatientFeatureDataset, train_dl.dataset - ).ground_truths.sum(dim=0) + dataset = cast(PatientFeatureDataset, train_dl.dataset) + category_counts = dataset.ground_truths.sum(dim=0) cat_ratio_reciprocal = category_counts.sum() / category_counts category_weights = cat_ratio_reciprocal / cat_ratio_reciprocal.sum() if len(train_categories) <= 1: raise ValueError(f"not enough categories to train on: {train_categories}") - 
elif any(category_counts < 16): + elif (category_counts < 16).any(): + category_counts_list = ( + category_counts.tolist() + if category_counts.dim() > 0 + else [category_counts.item()] + ) underpopulated_categories = { category: int(count) - for category, count in zip(train_categories, category_counts, strict=True) + for category, count in zip( + train_categories, category_counts_list, strict=True + ) if count < 16 } _logger.warning( diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index a1844526..84cb48e9 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -16,7 +16,6 @@ from tqdm import tqdm import stamp -from stamp.cache import get_processing_code_hash from stamp.preprocessing.config import ExtractorName from stamp.preprocessing.extractor import Extractor from stamp.preprocessing.tiling import ( @@ -32,6 +31,7 @@ SlidePixels, TilePixels, ) +from stamp.utils.cache import get_processing_code_hash __author__ = "Marko van Treeck" __copyright__ = "Copyright (C) 2022-2024 Marko van Treeck" diff --git a/src/stamp/preprocessing/extractor/chief_ctranspath.py b/src/stamp/preprocessing/extractor/chief_ctranspath.py index 2d2e6b9b..03f5bba6 100644 --- a/src/stamp/preprocessing/extractor/chief_ctranspath.py +++ b/src/stamp/preprocessing/extractor/chief_ctranspath.py @@ -1,6 +1,6 @@ from pathlib import Path -from stamp.cache import STAMP_CACHE_DIR, file_digest +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest try: import gdown diff --git a/src/stamp/preprocessing/extractor/ctranspath.py b/src/stamp/preprocessing/extractor/ctranspath.py index ba9a277c..387e7947 100644 --- a/src/stamp/preprocessing/extractor/ctranspath.py +++ b/src/stamp/preprocessing/extractor/ctranspath.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Optional, TypeVar, cast -from stamp.cache import STAMP_CACHE_DIR, file_digest +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest try: import gdown 
diff --git a/src/stamp/preprocessing/extractor/dinobloom.py b/src/stamp/preprocessing/extractor/dinobloom.py index fb7713ce..54b79c53 100644 --- a/src/stamp/preprocessing/extractor/dinobloom.py +++ b/src/stamp/preprocessing/extractor/dinobloom.py @@ -8,9 +8,9 @@ from torch import nn from torchvision import transforms -from stamp.cache import STAMP_CACHE_DIR from stamp.preprocessing.config import ExtractorName from stamp.preprocessing.extractor import Extractor +from stamp.utils.cache import STAMP_CACHE_DIR __author__ = "Marko van Treeck" __copyright__ = "Copyright (C) 2022-2025 Marko van Treeck" diff --git a/src/stamp/preprocessing/tiling.py b/src/stamp/preprocessing/tiling.py index 82a3efba..e09a06fa 100644 --- a/src/stamp/preprocessing/tiling.py +++ b/src/stamp/preprocessing/tiling.py @@ -461,7 +461,7 @@ def _extract_mpp_from_metadata(slide: openslide.AbstractSlide) -> SlideMPP | Non return None doc = minidom.parseString(xml_path) collection = doc.documentElement - images = collection.getElementsByTagName("Image") + images = collection.getElementsByTagName("Image") # pyright: ignore[reportOptionalMemberAccess] pixels = images[0].getElementsByTagName("Pixels") mpp = float(pixels[0].getAttribute("PhysicalSizeX")) except Exception: diff --git a/src/stamp/statistics/__init__.py b/src/stamp/statistics/__init__.py index ec09e1e0..bdbef1fa 100644 --- a/src/stamp/statistics/__init__.py +++ b/src/stamp/statistics/__init__.py @@ -1,3 +1,11 @@ +"""Statistics utilities (wrappers) for classification, regression and survival. + +This module provides a small, stable wrapper `compute_stats_` that dispatches +to the task-specific statistic implementations found in the submodules. 
+""" + +from __future__ import annotations + from collections.abc import Sequence from pathlib import Path from typing import NewType @@ -17,23 +25,25 @@ plot_multiple_decorated_roc_curves, plot_single_decorated_roc_curve, ) -from stamp.statistics.survival import ( - _plot_km, - _survival_stats_for_csv, -) +from stamp.statistics.survival import _plot_km, _survival_stats_for_csv from stamp.types import PandasLabel, Task +__all__ = ["StatsConfig", "compute_stats_"] + + __author__ = "Marko van Treeck, Minh Duc Nguyen" __copyright__ = "Copyright (C) 2022-2024 Marko van Treeck, Minh Duc Nguyen" __license__ = "MIT" def _read_table(file: Path, **kwargs) -> pd.DataFrame: - """Loads a dataframe from a file.""" + """Load a dataframe from CSV or XLSX file path. + + This small helper centralizes file IO formatting and keeps callers simple. + """ if isinstance(file, Path) and file.suffix == ".xlsx": return pd.read_excel(file, **kwargs) - else: - return pd.read_csv(file, **kwargs) + return pd.read_csv(file, **kwargs) class StatsConfig(BaseModel): @@ -60,6 +70,11 @@ def compute_stats_( time_label: str | None = None, status_label: str | None = None, ) -> None: + """Compute and save statistics for the provided task and prediction CSVs. + + This wrapper keeps the external API stable while delegating the detailed + computations and plotting to the submodules under `stamp.statistics.*`. 
+ """ match task: case "classification": if true_class is None or ground_truth_label is None: @@ -105,7 +120,6 @@ def compute_stats_( n_bootstrap_samples=n_bootstrap_samples, threshold_cmap=threshold_cmap, ) - else: plot_multiple_decorated_roc_curves( ax=ax, @@ -116,9 +130,7 @@ def compute_stats_( ) fig.tight_layout() - if not output_dir.exists(): - output_dir.mkdir(parents=True, exist_ok=True) - + output_dir.mkdir(parents=True, exist_ok=True) fig.savefig(output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg") plt.close(fig) @@ -134,7 +146,6 @@ def compute_stats_( title=f"{ground_truth_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, ) - else: plot_multiple_decorated_precision_recall_curves( ax=ax, @@ -203,12 +214,7 @@ def compute_stats_( cut_off=cut_off, ) - # ------------------------------------------------------------------ # # Save individual and aggregated CSVs - # ------------------------------------------------------------------ # stats_df = pd.DataFrame(per_fold).transpose() stats_df.index.name = "fold_name" # label the index column stats_df.to_csv(output_dir / "survival-stats_individual.csv", index=True) - - # agg_df = _aggregate_with_ci(stats_df) - # agg_df.to_csv(output_dir / "survival-stats_aggregated.csv", index=True) diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index 063793cf..7c298a54 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -46,7 +46,7 @@ def _survival_stats_for_csv( if risk_label is None: risk_label = "pred_score" - # --- Clean NaNs and invalid events before computing stats --- + # Clean NaNs and invalid events before computing stats df = df.dropna(subset=[time_label, status_label, risk_label]).copy() df = df[df[status_label].isin([0, 1])] if len(df) == 0: @@ -56,10 +56,10 @@ def _survival_stats_for_csv( event = np.asarray(df[status_label], dtype=int) risk = np.asarray(df[risk_label], dtype=float) - # --- Concordance index --- + # Concordance 
index c_index, n_pairs = _cindex(time, event, risk) - # --- Log-rank test (median split) --- + # Log-rank test (median split) median_risk = float(cut_off) if cut_off is not None else float(np.nanmedian(risk)) low_mask = risk <= median_risk high_mask = risk > median_risk @@ -101,7 +101,7 @@ def _plot_km( if risk_label is None: risk_label = "pred_score" - # --- Clean NaNs and invalid entries --- + # Clean NaNs and invalid entries df = df.replace(["NaN", "nan", "None", "Inf", "inf"], np.nan) df = df.dropna(subset=[time_label, status_label, risk_label]).copy() df = df[df[status_label].isin([0, 1])] @@ -113,7 +113,7 @@ def _plot_km( event = np.asarray(df[status_label], dtype=int) risk = np.asarray(df[risk_label], dtype=float) - # --- split groups --- + # split groups median_risk = float(cut_off) if cut_off is not None else np.nanmedian(risk) low_mask = risk <= median_risk high_mask = risk > median_risk @@ -138,7 +138,7 @@ def _plot_km( add_at_risk_counts(kmf_low, kmf_high, ax=ax) - # --- log-rank and c-index --- + # log-rank and c-index res = logrank_test( low_df[time_label], high_df[time_label], diff --git a/src/stamp/types.py b/src/stamp/types.py index f1f571cc..c1ff6873 100644 --- a/src/stamp/types.py +++ b/src/stamp/types.py @@ -37,6 +37,7 @@ PatientId: TypeAlias = str GroundTruth: TypeAlias = str +MultiClassGroundTruth: TypeAlias = tuple[str, ...] 
FeaturePath = NewType("FeaturePath", Path) Category: TypeAlias = str diff --git a/src/stamp/cache.py b/src/stamp/utils/cache.py similarity index 100% rename from src/stamp/cache.py rename to src/stamp/utils/cache.py diff --git a/src/stamp/config.py b/src/stamp/utils/config.py similarity index 100% rename from src/stamp/config.py rename to src/stamp/utils/config.py diff --git a/src/stamp/seed.py b/src/stamp/utils/seed.py similarity index 100% rename from src/stamp/seed.py rename to src/stamp/utils/seed.py diff --git a/src/stamp/utils/target_file.py b/src/stamp/utils/target_file.py new file mode 100644 index 00000000..08c30b6b --- /dev/null +++ b/src/stamp/utils/target_file.py @@ -0,0 +1,351 @@ +"""Automatically generate target information from clini table + +# The `barspoon-targets 2.0` File Format + +A barspoon target file is a [TOML][1] file with the following entries: + + - A `version` key mapping to a version string `"barspoon-targets "`, where + `` is a [PEP-440 version string][2] compatible with `2.0`. + - A `targets` table, the keys of which are target labels (as found in the + clinical table) and the values specify exactly one of the following: + 1. A categorical target label, marked by the presence of a `categories` + key-value pair. + 2. A target label to quantize, marked by the presence of a `thresholds` + key-value pair. + 3. A target format defined in in a later version of barspoon targets. + A target may only ever have one of the fields `categories` or `thresholds`. + A definition of these entries can be found below. + +[1]: https://toml.io "Tom's Obvious Minimal Language" +[2]: https://peps.python.org/pep-0440/ + "PEP 440 - Version Identification and Dependency Specification" + +## Categorical Target Label + +A categorical target is a target table with a key-value pair `categories`. +`categories` contains a list of lists of literal strings. 
Each list of strings +will be treated as one category, with all literal strings within that list being +treated as one representative for that category. This allows the user to easily +group related classes into one large class (i.e. `"True", "1", "Yes"` could all +be unified into the same category). + +### Category Weights + +It is possible to assign a weight to each category, to e.g. weigh rarer classes +more heavily. The weights are stored in a table `targets.LABEL.class_weights`, +whose keys is the first representative of each category, and the values of which +is the weight of the category as a floating point number. + +## Target Label to Quantize + +If a target has the `thresholds` option key set, it is interpreted as a +continuous target which has to be quantized. `thresholds` has to be a list of +floating point numbers [t_0, t_n], n > 1 containing the thresholds of the bins +to quantize the values into. A categorical target will be quantized into bins + +```asciimath +b_0 = [t_0; t_1], b_1 = (t_1; b_2], ... b_(n-1) = (t_(n-1); t_n] +``` + +The bins will be treated as categories with names +`f"[{t_0:+1.2e};{t_1:+1.2e}]"` for the first bin and +`f"({t_i:+1.2e};{t_(i+1):+1.2e}]"` for all other bins + +To avoid confusion, we recommend to also format the `thresholds` list the same +way. + +The bins can also be weighted. See _Categorical Target Label: Category Weights_ +for details. + + > Experience has shown that many labels contain non-negative values with a + > disproportionate amount (more than n_samples/n_bins) of zeroes. We thus + > decided to make the _right_ side of each bin inclusive, as the bin (-A,0] + > then naturally includes those zero values. 
+""" + +import logging +from pathlib import Path +from typing import ( + Any, + Dict, + List, + NamedTuple, + Optional, + Sequence, + TextIO, + Tuple, +) + +import numpy as np +import numpy.typing as npt +import pandas as pd +import torch +import torch.nn.functional as F +from packaging.specifiers import Specifier + + +def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: + if not isinstance(path, Path): + return pd.read_csv(path, **kwargs) + elif path.suffix == ".xlsx": + return pd.read_excel(path, **kwargs) + elif path.suffix == ".csv": + return pd.read_csv(path, **kwargs) + else: + raise ValueError( + "table to load has to either be an excel (`*.xlsx`) or csv (`*.csv`) file." + ) + + +__all__ = ["build_targets", "decode_targets"] + + +class TargetSpec(NamedTuple): + version: str + targets: Dict[str, Dict[str, Any]] + + +class EncodedTarget(NamedTuple): + categories: List[str] + encoded: torch.Tensor + weight: torch.Tensor + + +def encode_category( + *, + clini_df: pd.DataFrame, + target_label: str, + categories: Sequence[List[str]], + class_weights: Optional[Dict[str, float]] = None, + **ignored, +) -> Tuple[List[str], torch.Tensor, torch.Tensor]: + # Map each category to its index + category_map = {member: idx for idx, cat in enumerate(categories) for member in cat} + + # Map each item to it's category's index, mapping nans to num_classes+1 + # This way we can easily discard the NaN column later + indexes = clini_df[target_label].map(lambda c: category_map.get(c, len(categories))) + indexes = torch.tensor(indexes.values) + + # Discard nan column + one_hot = F.one_hot(indexes, num_classes=len(categories) + 1)[:, :-1] + + # Class weights + if class_weights is not None: + weight = torch.tensor([class_weights[c[0]] for c in categories]) + else: + # No class weights given; use normalized inverse frequency + counts = one_hot.sum(dim=0) + weight = (w := (counts.sum() / counts)) / w.sum() + + # Warn user of unused labels + if ignored: + logging.warn(f"ignored 
labels in target {target_label}: {ignored}") + + return [c[0] for c in categories], one_hot, weight + + +def encode_quantize( + *, + clini_df: pd.DataFrame, + target_label: str, + thresholds: npt.NDArray[np.floating[Any]], + class_weights: Optional[Dict[str, float]] = None, + **ignored, +) -> Tuple[List[str], torch.Tensor, torch.Tensor]: + # Warn user of unused labels + if ignored: + logging.warn(f"ignored labels in target {target_label}: {ignored}") + + n_bins = len(thresholds) - 1 + numeric_vals = torch.tensor(pd.to_numeric(clini_df[target_label]).values).reshape( + -1, 1 + ) + + # Map each value to a class index as follows: + # 1. If the value is NaN or less than the left-most threshold, use class + # index 0 + # 2. If it is between the left-most and the right-most threshold, set it to + # the bin number (starting from 1) + # 3. If it is larger than the right-most threshold, set it to N_bins + 1 + bin_index = ( + (numeric_vals > torch.tensor(thresholds).reshape(1, -1)).count_nonzero(1) + # For the first bucket, we have to include the lower threshold + + (numeric_vals.reshape(-1) == thresholds[0]) + ) + # One hot encode and discard nan columns (first and last col) + one_hot = F.one_hot(bin_index, num_classes=n_bins + 2)[:, 1:-1] + + # Class weights + categories = [ + f"[{thresholds[0]:+1.2e};{thresholds[1]:+1.2e}]", + *( + f"({lower:+1.2e};{upper:+1.2e}]" + for lower, upper in zip(thresholds[1:-1], thresholds[2:], strict=True) + ), + ] + + if class_weights is not None: + weight = torch.tensor([class_weights[c] for c in categories]) + else: + # No class weights given; use normalized inverse frequency + counts = one_hot.sum(0) + weight = (w := (np.divide(counts.sum(), counts, where=counts > 0))) / w.sum() + + return categories, one_hot, weight + + +def decode_targets( + encoded: torch.Tensor, + *, + target_labels: Sequence[str], + targets: Dict[str, Any], + version: str = "barspoon-targets 2.0", + **ignored, +) -> List[np.ndarray]: + name, version = version.split(" 
") + spec = Specifier("~=2.0") + + if not (name == "barspoon-targets" and spec.contains(version)): + raise ValueError( + f"incompatible target file: expected barspoon-targets{spec}, found `{name} {version}`" + ) + + # Warn user of unused labels + if ignored: + logging.warn(f"ignored parameters: {ignored}") + + decoded_targets = [] + curr_col = 0 + for target_label in target_labels: + info = targets[target_label] + + if (categories := info.get("categories")) is not None: + # Add another column which is one iff all the other values are zero + encoded_target = encoded[:, curr_col : curr_col + len(categories)] + is_none = ~encoded_target.any(dim=1).view(-1, 1) + encoded_target = torch.cat([encoded_target, is_none], dim=1) + + # Decode to class labels + representatives = np.array([c[0] for c in categories] + [None]) + category_index = encoded_target.argmax(dim=1) + decoded = representatives[category_index] + decoded_targets.append(decoded) + + curr_col += len(categories) + + elif (thresholds := info.get("thresholds")) is not None: + n_bins = len(thresholds) - 1 + encoded_target = encoded[:, curr_col : curr_col + n_bins] + is_none = ~encoded_target.any(dim=1).view(-1, 1) + encoded_target = torch.cat([encoded_target, is_none], dim=1) + + bin_edges = [-np.inf, *thresholds, np.inf] + representatives = np.array( + [ + f"[{lower:+1.2e};{upper:+1.2e})" + for lower, upper in zip(bin_edges[:-1], bin_edges[1:]) + ] + ) + decoded = representatives[encoded_target.argmax(dim=1)] + + decoded_targets.append(decoded) + + curr_col += n_bins + + else: + raise ValueError(f"cannot decode {target_label}: no target info") + + return decoded_targets + + +def build_targets( + *, + clini_tables: Sequence[Path], + categorical_labels: Sequence[str], + category_min_count: int = 32, + quantize: Sequence[tuple[str, int]] = (), +) -> Dict[str, EncodedTarget]: + clini_df = pd.concat([read_table(c) for c in clini_tables]) + encoded_targets: Dict[str, EncodedTarget] = {} + + # categorical targets + for 
target_label in categorical_labels: + counts = clini_df[target_label].value_counts() + well_supported = counts[counts >= category_min_count] + + if len(well_supported) <= 1: + continue + + categories = [[str(cat)] for cat in well_supported.index] + + weights = well_supported.sum() / well_supported + weights /= weights.sum() + + representatives, encoded, weight = encode_category( + clini_df=clini_df, + target_label=target_label, + categories=categories, + class_weights=weights.to_dict(), + ) + + encoded_targets[target_label] = EncodedTarget( + categories=representatives, + encoded=encoded, + weight=weight, + ) + + # quantized targets + for target_label, bincount in quantize: + vals = pd.to_numeric(clini_df[target_label]).dropna() + + if vals.empty: + continue + + vals_clamped = vals.replace( + { + -np.inf: vals[vals != -np.inf].min(), + np.inf: vals[vals != np.inf].max(), + } + ) + + thresholds = np.array( + [ + -np.inf, + *np.quantile(vals_clamped, q=np.linspace(0, 1, bincount + 1))[1:-1], + np.inf, + ], + dtype=float, + ) + + representatives, encoded, weight = encode_quantize( + clini_df=clini_df, + target_label=target_label, + thresholds=thresholds, + ) + + if encoded.shape[1] <= 1: + continue + + encoded_targets[target_label] = EncodedTarget( + categories=representatives, + encoded=encoded, + weight=weight, + ) + + return encoded_targets + + +if __name__ == "__main__": + encoded = build_targets( + clini_tables=[ + Path( + "/mnt/bulk-neptune/nguyenmin/stamp-dev/experiments/survival_prediction/TCGA-CRC-DX_CLINI.xlsx" + ) + ], + categorical_labels=["BRAF", "KRAS", "NRAS"], + category_min_count=32, + quantize=[], + ) + for name, enc in encoded.items(): + print(name, enc.encoded.shape) diff --git a/tests/random_data.py b/tests/random_data.py index bd95d1bc..c7c36880 100644 --- a/tests/random_data.py +++ b/tests/random_data.py @@ -254,6 +254,92 @@ def create_random_patient_level_dataset( return clini_path, slide_path, feat_dir, categories +def 
create_random_multi_target_dataset( + *, + dir: Path, + n_patients: int, + max_slides_per_patient: int, + min_tiles_per_slide: int, + max_tiles_per_slide: int, + feat_dim: int, + target_labels: Sequence[str], + categories_per_target: Sequence[Sequence[str]], + extractor_name: str = "random-test-generator", + min_slides_per_patient: int = 1, +) -> tuple[Path, Path, Path, Sequence[Sequence[str]]]: + """ + Create a random multi-target tile-level dataset. + + Args: + dir: Directory to create dataset in + n_patients: Number of patients + max_slides_per_patient: Maximum slides per patient + min_tiles_per_slide: Minimum tiles per slide + max_tiles_per_slide: Maximum tiles per slide + feat_dim: Feature dimension + target_labels: Names of the target columns (e.g., ["subtype", "grade"]) + categories_per_target: Categories for each target (e.g., [["A", "B"], ["1", "2", "3"]]) + extractor_name: Name of the extractor + min_slides_per_patient: Minimum slides per patient + + Returns: + Tuple of (clini_path, slide_path, feat_dir, categories_per_target) + """ + if len(target_labels) != len(categories_per_target): + raise ValueError( + "target_labels and categories_per_target must have same length" + ) + + slide_path_to_patient: Mapping[Path, PatientId] = {} + patient_to_ground_truths: Mapping[PatientId, dict[str, str]] = {} + clini_path = dir / "clini.csv" + slide_path = dir / "slide.csv" + + feat_dir = dir / "feats" + feat_dir.mkdir() + + for _ in range(n_patients): + # Random patient ID + patient_id = random_string(16) + + # Generate ground truths for each target + ground_truths = {} + for target_label, categories in zip(target_labels, categories_per_target): + ground_truths[target_label] = random.choice(categories) + + patient_to_ground_truths[patient_id] = ground_truths + + # Generate some slides + for _ in range(random.randint(min_slides_per_patient, max_slides_per_patient)): + slide_path_to_patient[ + create_random_feature_file( + tmp_path=feat_dir, + 
min_tiles=min_tiles_per_slide, + max_tiles=max_tiles_per_slide, + feat_dim=feat_dim, + extractor_name=extractor_name, + ).relative_to(feat_dir) + ] = patient_id + + # Create clinical table with multiple target columns + clini_data = [] + for patient_id, ground_truths in patient_to_ground_truths.items(): + row = {"patient": patient_id} + row.update(ground_truths) + clini_data.append(row) + + clini_df = pd.DataFrame(clini_data) + clini_df.to_csv(clini_path, index=False) + + slide_df = pd.DataFrame( + slide_path_to_patient.items(), + columns=["slide_path", "patient"], + ) + slide_df.to_csv(slide_path, index=False) + + return clini_path, slide_path, feat_dir, categories_per_target + + def create_random_feature_file( *, tmp_path: Path, diff --git a/tests/test_cache_tiles.py b/tests/test_cache_tiles.py index a665e92c..d9f1411d 100644 --- a/tests/test_cache_tiles.py +++ b/tests/test_cache_tiles.py @@ -6,10 +6,10 @@ import numpy as np import pytest -from stamp.cache import download_file from stamp.preprocessing import Microns, TilePixels from stamp.preprocessing.tiling import _Tile, tiles_with_cache from stamp.types import ImageExtension, SlidePixels +from stamp.utils.cache import download_file def _get_tiles_and_images( diff --git a/tests/test_config.py b/tests/test_config.py index 15b5dd80..fbb53a6c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,6 @@ # %% from pathlib import Path -from stamp.config import StampConfig from stamp.heatmaps.config import HeatmapConfig from stamp.modeling.config import ( AdvancedConfig, @@ -21,6 +20,7 @@ TilePixels, ) from stamp.statistics import StatsConfig +from stamp.utils.config import StampConfig def test_config_parsing() -> None: diff --git a/tests/test_crossval.py b/tests/test_crossval.py index 184a5c23..5e53f50c 100644 --- a/tests/test_crossval.py +++ b/tests/test_crossval.py @@ -13,7 +13,7 @@ VitModelParams, ) from stamp.modeling.crossval import categorical_crossval_ -from stamp.seed import Seed +from 
stamp.utils.seed import Seed @pytest.mark.slow diff --git a/tests/test_data.py b/tests/test_data.py index 564d6426..3c86a931 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -23,7 +23,6 @@ get_coords, slide_to_patient_from_slide_table_, ) -from stamp.seed import Seed from stamp.types import ( BagSize, FeaturePath, @@ -33,6 +32,7 @@ SlideMPP, TilePixels, ) +from stamp.utils.seed import Seed @pytest.mark.filterwarnings("ignore:some patients have no associated slides") diff --git a/tests/test_deployment.py b/tests/test_deployment.py index de20ea12..4e1570cc 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import cast import numpy as np import pytest @@ -24,8 +25,8 @@ ) from stamp.modeling.models.mlp import MLP from stamp.modeling.models.vision_tranformer import VisionTransformer -from stamp.seed import Seed from stamp.types import GroundTruth, PatientId, Task +from stamp.utils.seed import Seed def test_predict_patient_level( @@ -83,7 +84,8 @@ def test_predict_patient_level( assert len(predictions) == len(patient_to_data) for pid in patient_ids: - assert predictions[pid].shape == torch.Size([3]), "expected one score per class" + pred = cast(torch.Tensor, predictions[pid]) + assert pred.shape == torch.Size([3]), "expected one score per class" # Check if scores are consistent between runs and different for different patients more_patient_ids = [PatientId(f"pat{i}") for i in range(8, 11)] @@ -124,11 +126,13 @@ def test_predict_patient_level( assert len(more_predictions) == len(all_patient_ids) # Different patients should give different results assert not torch.allclose( - more_predictions[more_patient_ids[0]], more_predictions[more_patient_ids[1]] + cast(torch.Tensor, more_predictions[more_patient_ids[0]]), + cast(torch.Tensor, more_predictions[more_patient_ids[1]]), ), "different inputs should give different results" # The same patient should yield the same result assert torch.allclose( 
- predictions[patient_ids[0]], more_predictions[patient_ids[0]] + cast(torch.Tensor, predictions[patient_ids[0]]), + cast(torch.Tensor, more_predictions[patient_ids[0]]), ), "the same inputs should repeatedly yield the same results" @@ -295,7 +299,7 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: div_factor=25.0, ) - # ---- Build tile-level feature file so batch = (bags, coords, bag_sizes, gt) + # Build tile-level feature file so batch = (bags, coords, bag_sizes, gt) if task == "classification": feature_file = make_old_feature_file( feats=torch.rand(23, dim_feats), coords=torch.rand(23, 2) @@ -319,7 +323,7 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: ) } - # ---- Use tile_bag_dataloader for ALL tasks (so batch has 4 elements) + # Use tile_bag_dataloader for ALL tasks (so batch has 4 elements) test_dl, _ = tile_bag_dataloader( task=task, # "classification" | "regression" | "survival" patient_data=list(patient_to_data.values()), @@ -341,12 +345,17 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: assert len(predictions) == 1 pred = list(predictions.values())[0] if task == "classification": - assert pred.shape == torch.Size([len(categories)]) + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.shape == torch.Size([len(categories)]) elif task == "regression": - assert pred.shape == torch.Size([1]) + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.shape == torch.Size([1]) else: # survival # Cox model → scalar log-risk, KM → vector or matrix - assert pred.ndim in (0, 1, 2), f"unexpected survival output shape: {pred.shape}" + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.ndim in (0, 1, 2), ( + f"unexpected survival output shape: {pred_tensor.shape}" + ) # Repeatability predictions2 = _predict( @@ -356,4 +365,6 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: accelerator="cpu", ) for pid in predictions: - assert torch.allclose(predictions[pid], 
predictions2[pid]) + assert torch.allclose( + cast(torch.Tensor, predictions[pid]), cast(torch.Tensor, predictions2[pid]) + ) diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 3edef575..28d7c2c1 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -10,13 +10,13 @@ from huggingface_hub.errors import GatedRepoError from random_data import create_random_dataset, create_random_feature_file, random_string -from stamp.cache import download_file from stamp.encoding import ( EncoderName, init_patient_encoder_, init_slide_encoder_, ) from stamp.preprocessing.config import ExtractorName +from stamp.utils.cache import download_file # Contains an accepted input patch-level feature encoder # TODO: Make a class for each extractor instead of a function. This class diff --git a/tests/test_feature_extractors.py b/tests/test_feature_extractors.py index 699c10f6..6323ee8d 100644 --- a/tests/test_feature_extractors.py +++ b/tests/test_feature_extractors.py @@ -7,8 +7,8 @@ import torch from huggingface_hub.errors import GatedRepoError -from stamp.cache import download_file from stamp.preprocessing import ExtractorName, Microns, TilePixels, extract_ +from stamp.utils.cache import download_file @pytest.mark.slow diff --git a/tests/test_model.py b/tests/test_model.py index 1aa6d80a..f74e0a99 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,6 @@ import torch +from stamp.modeling.models.barspoon import EncDecTransformer from stamp.modeling.models.mlp import MLP from stamp.modeling.models.trans_mil import TransMIL from stamp.modeling.models.vision_tranformer import VisionTransformer @@ -162,3 +163,114 @@ def test_trans_mil_inference_reproducibility( ) assert logits1.allclose(logits2) + + +def test_enc_dec_transformer_dims( + batch_size: int = 6, + n_tiles: int = 75, + input_dim: int = 456, + d_model: int = 128, +) -> None: + target_n_outs = {"subtype": 3, "grade": 4} + model = EncDecTransformer( + d_features=input_dim, + 
target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=256, + positional_encoding=True, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + assert set(logits.keys()) == set(target_n_outs.keys()) + for target_label, n_out in target_n_outs.items(): + assert logits[target_label].shape == (batch_size, n_out) + + +def test_enc_dec_transformer_single_target( + batch_size: int = 4, + n_tiles: int = 50, + input_dim: int = 256, + d_model: int = 64, +) -> None: + target_n_outs = {"label": 5} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=1, + num_decoder_layers=1, + dim_feedforward=128, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + assert list(logits.keys()) == ["label"] + assert logits["label"].shape == (batch_size, 5) + + +def test_enc_dec_transformer_no_positional_encoding( + batch_size: int = 4, + n_tiles: int = 30, + input_dim: int = 128, + d_model: int = 64, +) -> None: + target_n_outs = {"a": 2, "b": 3} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=1, + num_decoder_layers=1, + dim_feedforward=128, + positional_encoding=False, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + for target_label, n_out in target_n_outs.items(): + assert logits[target_label].shape == (batch_size, n_out) + + +def test_enc_dec_transformer_inference_reproducibility( + batch_size: int = 5, + n_tiles: int = 40, + input_dim: int = 200, + d_model: int 
= 64, +) -> None: + target_n_outs = {"subtype": 3, "grade": 4} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=128, + ) + model = model.eval() + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + + with torch.inference_mode(): + logits1 = model.forward(bags, coords) + logits2 = model.forward(bags, coords) + + for target_label in target_n_outs: + assert logits1[target_label].allclose(logits2[target_label]) diff --git a/tests/test_statistics.py b/tests/test_statistics.py index 790b98ab..b600d7e7 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -2,6 +2,7 @@ from pathlib import Path import numpy as np +import pandas as pd import torch from random_data import random_patient_preds, random_string @@ -47,3 +48,153 @@ def test_statistics_integration( def test_statistics_integration_for_multiple_patient_preds(tmp_path: Path) -> None: return test_statistics_integration(tmp_path=tmp_path, n_patient_preds=5) + + +def test_statistics_survival_integration( + *, + tmp_path: Path, + n_folds: int = 1, + n_patients: int = 200, +) -> None: + """Check that survival statistics run without crashing.""" + random.seed(0) + np.random.seed(0) + + for fold_i in range(n_folds): + times = np.random.uniform(30, 2000, size=n_patients) + statuses = np.random.choice([0, 1], size=n_patients, p=[0.3, 0.7]) + risks = np.random.randn(n_patients) + df = pd.DataFrame( + { + "patient": [random_string(8) for _ in range(n_patients)], + "day": times, + "status": statuses, + "pred_score": risks, + } + ) + df.to_csv(tmp_path / f"survival-preds-{fold_i}.csv", index=False) + + compute_stats_( + task="survival", + output_dir=tmp_path / "output", + pred_csvs=[tmp_path / f"survival-preds-{i}.csv" for i in range(n_folds)], + time_label="day", + status_label="status", + ) + + 
assert (tmp_path / "output" / "survival-stats_individual.csv").is_file() + + +def test_statistics_survival_integration_multiple_folds(tmp_path: Path) -> None: + return test_statistics_survival_integration(tmp_path=tmp_path, n_folds=5) + + +def test_statistics_regression_integration( + *, + tmp_path: Path, + n_folds: int = 1, + n_patients: int = 200, +) -> None: + """Check that regression statistics run without crashing.""" + random.seed(0) + np.random.seed(0) + + for fold_i in range(n_folds): + y_true = np.random.uniform(0, 100, size=n_patients) + y_pred = y_true + np.random.randn(n_patients) * 10 # noisy predictions + df = pd.DataFrame( + { + "patient": [random_string(8) for _ in range(n_patients)], + "target": y_true, + "pred": y_pred, + } + ) + df.to_csv(tmp_path / f"regression-preds-{fold_i}.csv", index=False) + + compute_stats_( + task="regression", + output_dir=tmp_path / "output", + pred_csvs=[tmp_path / f"regression-preds-{i}.csv" for i in range(n_folds)], + ground_truth_label="target", + ) + + assert (tmp_path / "output" / "target_regression-stats_individual.csv").is_file() + assert (tmp_path / "output" / "target_regression-stats_aggregated.csv").is_file() + + +def test_statistics_regression_integration_multiple_folds(tmp_path: Path) -> None: + return test_statistics_regression_integration(tmp_path=tmp_path, n_folds=5) + + +def test_statistics_multi_target_classification_integration( + *, + tmp_path: Path, + n_patient_preds: int = 1, +) -> None: + """Check that multi-target classification statistics run without crashing. + + Multi-target predictions produce separate ground-truth columns per target. + We run compute_stats_ once per target, as the statistics pipeline handles + one target at a time. 
+ """ + random.seed(0) + np.random.seed(0) + torch.random.manual_seed(0) + + categories_per_target = {"subtype": ["A", "B"], "grade": ["1", "2", "3"]} + + for pred_i in range(n_patient_preds): + n_patients = random.randint(100, 500) + data: dict[str, list] = { + "patient": [random_string(8) for _ in range(n_patients)], + } + + for target_label, cats in categories_per_target.items(): + data[target_label] = [random.choice(cats) for _ in range(n_patients)] + probs = torch.softmax(torch.rand(len(cats), n_patients), dim=0) + for j, cat in enumerate(cats): + data[f"{target_label}_{cat}"] = probs[j].tolist() + + pd.DataFrame(data).to_csv( + tmp_path / f"multi-target-preds-{pred_i}.csv", index=False + ) + + # Run statistics per target (as the pipeline would do) + for target_label, cats in categories_per_target.items(): + true_class = cats[0] + compute_stats_( + task="classification", + output_dir=tmp_path / "output" / target_label, + pred_csvs=[ + tmp_path / f"multi-target-preds-{i}.csv" for i in range(n_patient_preds) + ], + ground_truth_label=target_label, + true_class=true_class, + ) + + assert ( + tmp_path + / "output" + / target_label + / f"{target_label}_categorical-stats_aggregated.csv" + ).is_file() + assert ( + tmp_path + / "output" + / target_label + / f"roc-curve_{target_label}={true_class}.svg" + ).is_file() + assert ( + tmp_path + / "output" + / target_label + / f"pr-curve_{target_label}={true_class}.svg" + ).is_file() + + +def test_statistics_multi_target_classification_multiple_preds( + tmp_path: Path, +) -> None: + return test_statistics_multi_target_classification_integration( + tmp_path=tmp_path, n_patient_preds=3 + ) diff --git a/tests/test_train_deploy.py b/tests/test_train_deploy.py index 0180d171..ea5547f1 100644 --- a/tests/test_train_deploy.py +++ b/tests/test_train_deploy.py @@ -8,6 +8,7 @@ import torch from random_data import ( create_random_dataset, + create_random_multi_target_dataset, create_random_patient_level_dataset, 
create_random_patient_level_survival_dataset, create_random_regression_dataset, @@ -22,8 +23,9 @@ VitModelParams, ) from stamp.modeling.deploy import deploy_categorical_model_ +from stamp.modeling.registry import ModelName from stamp.modeling.train import train_categorical_model_ -from stamp.seed import Seed +from stamp.utils.seed import Seed @pytest.mark.slow @@ -123,6 +125,9 @@ def test_train_deploy_integration( pytest.param(False, True, id="use vary_precision_transform"), ], ) +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_integration( *, tmp_path: Path, @@ -356,6 +361,9 @@ def test_train_deploy_survival_integration( @pytest.mark.slow +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_regression_integration( *, tmp_path: Path, @@ -465,6 +473,9 @@ def test_train_deploy_patient_level_regression_integration( @pytest.mark.slow +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_survival_integration( *, tmp_path: Path, @@ -531,3 +542,89 @@ def test_train_deploy_patient_level_survival_integration( accelerator="gpu" if torch.cuda.is_available() else "cpu", num_workers=min(os.cpu_count() or 1, 16), ) + + +@pytest.mark.slow +@pytest.mark.filterwarnings("ignore:No positive samples in targets") +def test_train_deploy_multi_target_integration( + *, + tmp_path: Path, + feat_dim: int = 25, +) -> None: + """Integration test: train + deploy a multi-target tile-level classification model.""" + Seed.set(42) + + (tmp_path / "train").mkdir() + (tmp_path / "deploy").mkdir() + + # Define multi-target setup: subtype (2 categories) and grade (3 categories) + target_labels = ["subtype", "grade"] + categories_per_target = [["A", "B"], ["1", "2", "3"]] + + # Create random multi-target tile-level dataset + train_clini_path, 
train_slide_path, train_feature_dir, _ = ( + create_random_multi_target_dataset( + dir=tmp_path / "train", + n_patients=400, + max_slides_per_patient=3, + min_tiles_per_slide=20, + max_tiles_per_slide=600, + feat_dim=feat_dim, + target_labels=target_labels, + categories_per_target=categories_per_target, + ) + ) + deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( + create_random_multi_target_dataset( + dir=tmp_path / "deploy", + n_patients=50, + max_slides_per_patient=3, + min_tiles_per_slide=20, + max_tiles_per_slide=600, + feat_dim=feat_dim, + target_labels=target_labels, + categories_per_target=categories_per_target, + ) + ) + + # Build config objects + config = TrainConfig( + task="classification", + clini_table=train_clini_path, + slide_table=train_slide_path, + feature_dir=train_feature_dir, + output_dir=tmp_path / "train_output", + patient_label="patient", + ground_truth_label=target_labels, + filename_label="slide_path", + categories=[cat for cats in categories_per_target for cat in cats], + ) + + advanced = AdvancedConfig( + bag_size=500, + num_workers=min(os.cpu_count() or 1, 16), + batch_size=8, + max_epochs=2, + patience=1, + accelerator="gpu" if torch.cuda.is_available() else "cpu", + model_params=ModelParams(), + model_name=ModelName.BARSPOON, + ) + + # Train + deploy multi-target model + train_categorical_model_(config=config, advanced=advanced) + + deploy_categorical_model_( + output_dir=tmp_path / "deploy_output", + checkpoint_paths=[tmp_path / "train_output" / "model.ckpt"], + clini_table=deploy_clini_path, + slide_table=deploy_slide_path, + feature_dir=deploy_feature_dir, + patient_label="patient", + ground_truth_label=target_labels, + time_label=None, + status_label=None, + filename_label="slide_path", + accelerator="gpu" if torch.cuda.is_available() else "cpu", + num_workers=min(os.cpu_count() or 1, 16), + ) diff --git a/uv.lock b/uv.lock index c4015d9f..96b4b73a 100644 --- a/uv.lock +++ b/uv.lock @@ -3699,13 +3699,14 @@ wheels = [ 
[[package]] name = "stamp" -version = "2.3.0" +version = "2.4.0" source = { editable = "." } dependencies = [ { name = "beartype" }, { name = "einops" }, { name = "h5py" }, { name = "jaxtyping" }, + { name = "lifelines" }, { name = "lightning" }, { name = "matplotlib" }, { name = "numpy" }, @@ -3807,7 +3808,6 @@ gigapath = [ { name = "fvcore" }, { name = "gigapath" }, { name = "iopath" }, - { name = "lifelines" }, { name = "monai" }, { name = "scikit-image" }, { name = "scikit-survival" }, @@ -3828,7 +3828,6 @@ gpu = [ { name = "huggingface-hub" }, { name = "iopath" }, { name = "jinja2" }, - { name = "lifelines" }, { name = "madeleine" }, { name = "mamba-ssm" }, { name = "monai" }, @@ -3920,7 +3919,7 @@ requires-dist = [ { name = "iopath", marker = "extra == 'gigapath'" }, { name = "jaxtyping", specifier = ">=0.3.2" }, { name = "jinja2", marker = "extra == 'cobra'", specifier = ">=3.1.4" }, - { name = "lifelines", marker = "extra == 'gigapath'" }, + { name = "lifelines", specifier = ">=0.28.0" }, { name = "lightning", specifier = ">=2.5.2" }, { name = "madeleine", marker = "extra == 'madeleine'", git = "https://github.com/mahmoodlab/MADELEINE.git?rev=de7c85acc2bdad352e6df8eee5694f8b6f288012" }, { name = "mamba-ssm", marker = "extra == 'cobra'", specifier = ">=2.2.6.post3" }, @@ -4747,4 +4746,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/24/fd/725b8e73ac2a50e78a4534ac43c6addf5c1c2d65380dd48a9169cc6739a9/yarl-1.20.1-cp313-cp313t-win32.whl", hash = "sha256:b121ff6a7cbd4abc28985b6028235491941b9fe8fe226e6fdc539c977ea1739d", size = 86591, upload-time = "2025-06-10T00:45:25.793Z" }, { url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" }, { url = 
"https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" }, -] \ No newline at end of file +] From d8cc268200732f851c36ab59e4cbb1bae6f5c782 Mon Sep 17 00:00:00 2001 From: mducducd Date: Fri, 13 Feb 2026 14:16:16 +0000 Subject: [PATCH 02/11] add multi-target support --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3069dddf..9a0ed68b 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ STAMP is an **end‑to‑end, weakly‑supervised deep‑learning pipeline** tha * 🎓 **Beginner‑friendly & expert‑ready**: Zero‑code CLI and YAML config for routine use; optional code‑level customization for advanced research. * 🧩 **Model‑rich**: Out‑of‑the‑box support for **+20 foundation models** at [tile level](getting-started.md#feature-extraction) (e.g., *Virchow‑v2*, *UNI‑v2*) and [slide level](getting-started.md#slide-level-encoding) (e.g., *TITAN*, *COBRA*). * 🔬 **Weakly‑supervised**: End‑to‑end MIL with Transformer aggregation for training, cross‑validation and external deployment; no pixel‑level labels required. -* 🧮 **Multi-task learning**: Unified framework for **classification**, **regression**, and **cox-based survival analysis**. +* 🧮 **Multi-task learning**: Unified framework for **classification**, **multi-target classification**, **regression**, and **cox-based survival analysis**. * 📊 **Stats & results**: Built‑in metrics and patient‑level predictions, ready for analysis and reporting. * 🖼️ **Explainable**: Generates heatmaps and top‑tile exports out‑of‑the‑box for transparent model auditing and publication‑ready figures. 
* 🤝 **Collaborative by design**: Clinicians drive hypothesis & interpretation while engineers handle compute; STAMP’s modular CLI mirrors real‑world workflows and tracks every step for full reproducibility. From f7fa2c599a2e34ddd8cf67ae764e887995ec80f7 Mon Sep 17 00:00:00 2001 From: mducducd Date: Mon, 16 Feb 2026 14:12:38 +0000 Subject: [PATCH 03/11] add multi-target statistics --- src/stamp/statistics/categorical.py | 82 +++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/src/stamp/statistics/categorical.py b/src/stamp/statistics/categorical.py index 2b5c859e..30a03a86 100755 --- a/src/stamp/statistics/categorical.py +++ b/src/stamp/statistics/categorical.py @@ -21,6 +21,30 @@ ] +def _detect_targets_from_columns(columns: Sequence[str]) -> list[str]: + """Detect target columns from CSV column names. + + Assumes multi-target format where each target has: + - A ground truth column (target name) + - A prediction column (pred_{target}) + - Probability columns ({target}_{class1}, {target}_{class2}, ...) + + Returns: + List of target names detected. + """ + # Convert to list to handle pandas Index + columns = list(columns) + targets = [] + for col in columns: + # Look for columns that start with "pred_" + if col.startswith("pred_"): + target_name = col[5:] # Remove "pred_" prefix + # Verify the target column exists + if target_name in columns: + targets.append(target_name) + return sorted(targets) + + def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: """Calculates some stats for categorical prediction tables. 
@@ -110,3 +134,61 @@ def categorical_aggregated_( preds_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_individual.csv") stats_df = _aggregate_categorical_stats(preds_df.reset_index()) stats_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_aggregated.csv") + + +def categorical_aggregated_multitarget_( + *, + preds_csvs: Sequence[Path], + outpath: Path, + target_labels: Sequence[str], +) -> None: + """Calculate statistics for multi-target categorical deployments. + + Args: + preds_csvs: CSV files containing predictions. + outpath: Path to save the results to. + target_labels: List of target labels to compute statistics for. + + This will apply `_categorical` to each target in the multi-target setup, + calculate statistics per target, and save both individual and aggregated results. + """ + outpath.mkdir(parents=True, exist_ok=True) + + all_target_stats = {} + + for target_label in target_labels: + # Process each target separately + preds_dfs = {} + for p in preds_csvs: + df = pd.read_csv(p, dtype=str) + # Drop rows where this target's ground truth is missing + df_clean = df.dropna(subset=[target_label]) + if len(df_clean) > 0: + preds_dfs[Path(p).parent.name] = _categorical(df_clean, target_label) + + if not preds_dfs: + continue + + # Concatenate and save individual stats for this target + preds_df = pd.concat(preds_dfs).sort_index() + preds_df.to_csv(outpath / f"{target_label}_categorical-stats_individual.csv") + + # Aggregate stats for this target + stats_df = _aggregate_categorical_stats(preds_df.reset_index()) + stats_df.to_csv(outpath / f"{target_label}_categorical-stats_aggregated.csv") + + # Store for summary + all_target_stats[target_label] = stats_df + + # Create a combined summary across all targets + if all_target_stats: + summary_dfs = [] + for target_name, stats_df in all_target_stats.items(): + stats_copy = stats_df.copy() + stats_copy.index = pd.MultiIndex.from_product( + [[target_name], stats_copy.index], names=["target", 
"class"] + ) + summary_dfs.append(stats_copy) + + combined_summary = pd.concat(summary_dfs) + combined_summary.to_csv(outpath / "multitarget_categorical-stats_summary.csv") From 747143f65b57edd50df8725b63f78b089f6a5a3f Mon Sep 17 00:00:00 2001 From: mducducd Date: Mon, 16 Feb 2026 15:28:08 +0000 Subject: [PATCH 04/11] refactor --- src/stamp/encoding/__init__.py | 2 +- src/stamp/encoding/encoder/__init__.py | 4 +- src/stamp/heatmaps/__init__.py | 100 ++--- src/stamp/modeling/crossval.py | 2 +- src/stamp/modeling/data.py | 53 ++- src/stamp/modeling/deploy.py | 2 +- src/stamp/modeling/models/__init__.py | 125 ++++--- src/stamp/modeling/models/mlp.py | 2 +- .../preprocessing/extractor/ctranspath.py | 2 +- src/stamp/preprocessing/tiling.py | 5 +- src/stamp/statistics/__init__.py | 271 +++++++++++--- src/stamp/statistics/categorical.py | 14 +- src/stamp/utils/target_file.py | 351 ------------------ tests/random_data.py | 6 +- tests/test_data.py | 27 ++ tests/test_deployment.py | 10 +- 16 files changed, 432 insertions(+), 544 deletions(-) delete mode 100644 src/stamp/utils/target_file.py diff --git a/src/stamp/encoding/__init__.py b/src/stamp/encoding/__init__.py index 9cb873bb..02d46594 100644 --- a/src/stamp/encoding/__init__.py +++ b/src/stamp/encoding/__init__.py @@ -73,7 +73,7 @@ def init_slide_encoder_( selected_encoder = encoder case _ as unreachable: - assert_never(unreachable) # type: ignore + assert_never(unreachable) selected_encoder.encode_slides_( output_dir=output_dir, diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 86daa54a..4720ef9b 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from pathlib import Path from tempfile import NamedTemporaryFile +from typing import cast import h5py import numpy as np @@ -183,7 +184,8 @@ def _read_h5( elif not h5_path.endswith(".h5"): raise ValueError(f"File is not of type 
.h5: {os.path.basename(h5_path)}") with h5py.File(h5_path, "r") as f: - feats: Tensor = torch.tensor(f["feats"][:], dtype=self.precision) # type: ignore + feats_ds = cast(h5py.Dataset, f["feats"]) + feats: Tensor = torch.tensor(feats_ds[:], dtype=self.precision) coords: CoordsInfo = get_coords(f) extractor: str = f.attrs.get("extractor", "") if extractor == "": diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index 22fb5250..fb704fe6 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -5,7 +5,7 @@ import logging from collections.abc import Collection, Iterable from pathlib import Path -from typing import cast, no_type_check +from typing import cast import h5py import matplotlib.pyplot as plt @@ -19,7 +19,7 @@ from packaging.version import Version from PIL import Image from torch import Tensor -from torch.func import jacrev # pyright: ignore[reportPrivateImportUsage] +from torch.func import jacrev from stamp.modeling.data import get_coords, get_stride from stamp.modeling.deploy import load_model_from_ckpt @@ -29,6 +29,8 @@ _logger = logging.getLogger("stamp") +_SlideLike = openslide.OpenSlide | openslide.ImageSlide + def _gradcam_per_category( model: torch.nn.Module, @@ -37,23 +39,19 @@ def _gradcam_per_category( ) -> Float[Tensor, "tile category"]: feat_dim = -1 - cam = ( - ( - feats - * jacrev( - lambda bags: model.forward( - bags.unsqueeze(0), - coords=coords.unsqueeze(0), - mask=None, - ).squeeze(0) - )(feats) - ) - .mean(feat_dim) # type: ignore - .abs() + jac = cast( + Tensor, + jacrev( + lambda bags: model.forward( + bags.unsqueeze(0), + coords=coords.unsqueeze(0), + mask=None, + ).squeeze(0) + )(feats), ) + cam = (feats * jac).mean(feat_dim).abs() cam = torch.softmax(cam, dim=-1) - return cam.permute(-1, -2) @@ -70,21 +68,28 @@ def _attention_rollout_single( device = feats.device - # 1. Forward pass to fill attn_weights in each SelfAttention layer + # --- 1. 
Forward pass to fill attn_weights in each SelfAttention layer --- _ = model( bags=feats.unsqueeze(0), coords=coords.unsqueeze(0), mask=torch.zeros(1, len(feats), dtype=torch.bool, device=device), ) - # 2. Rollout computation + # --- 2. Rollout computation --- attn_rollout: torch.Tensor | None = None - for layer in model.transformer.layers: # type: ignore - attn = getattr(layer[0], "attn_weights", None) # SelfAttention.attn_weights + transformer = getattr(model, "transformer", None) + if transformer is None: + raise RuntimeError("Model does not have a transformer attribute") + for layer in transformer.layers: + attn = getattr(layer, "attn_weights", None) + if attn is None: + first_child = next(iter(layer.children()), None) + if first_child is not None: + attn = getattr(first_child, "attn_weights", None) if attn is None: raise RuntimeError( "SelfAttention.attn_weights not found. " - "Make sure SelfAttention stores them." + "Make sure SelfAttention stores them on the layer or its first child." ) # attn: [heads, seq, seq] @@ -96,10 +101,10 @@ def _attention_rollout_single( if attn_rollout is None: raise RuntimeError("No attention maps collected from transformer layers.") - # 3. Extract CLS → tiles attention + # --- 3. Extract CLS → tiles attention --- cls_attn = attn_rollout[0, 1:] # [tile] - # 4. Normalize for visualization consistency + # --- 4. 
Normalize for visualization consistency --- cls_attn = cls_attn - cls_attn.min() cls_attn = cls_attn / (cls_attn.max().clamp(min=1e-8)) @@ -117,15 +122,18 @@ def _gradcam_single( """ feat_dim = -1 - jac = jacrev( - lambda bags: model.forward( - bags.unsqueeze(0), - coords=coords.unsqueeze(0), - mask=None, - ).squeeze() - )(feats) + jac = cast( + Tensor, + jacrev( + lambda bags: model.forward( + bags.unsqueeze(0), + coords=coords.unsqueeze(0), + mask=None, + ).squeeze() + )(feats), + ) - cam = (feats * jac).mean(feat_dim).abs() # type: ignore # [tile] + cam = (feats * jac).mean(feat_dim).abs() # [tile] return cam @@ -148,17 +156,21 @@ def _vals_to_im( def _show_thumb( - slide, thumb_ax: Axes, attention: Tensor, default_slide_mpp: SlideMPP | None + slide: _SlideLike, + thumb_ax: Axes, + attention: Tensor, + default_slide_mpp: SlideMPP | None, ) -> np.ndarray: mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp) dims_um = np.array(slide.dimensions) * mpp - thumb = slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int)) + thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist()) + thumb = slide.get_thumbnail(thumb_size) thumb_ax.imshow(np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8]) return np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8] def _get_thumb_array( - slide, + slide: _SlideLike, attention: torch.Tensor, default_slide_mpp: SlideMPP | None, ) -> np.ndarray: @@ -168,12 +180,12 @@ def _get_thumb_array( """ mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp) dims_um = np.array(slide.dimensions) * mpp - thumb = np.array(slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int))) + thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist()) + thumb = np.array(slide.get_thumbnail(thumb_size)) thumb_crop = thumb[: attention.shape[0] * 8, : attention.shape[1] * 8] return thumb_crop -@no_type_check # beartype<=0.19.0 breaks here for some reason def _show_class_map( class_ax: Axes, 
top_score_indices: Integer[Tensor, "width height"], @@ -298,13 +310,8 @@ def heatmaps_( raise ValueError( f"Feature file {h5_path} is a slide or patient level feature. Heatmaps are currently supported for tile-level features only." ) - feats = ( - torch.tensor( - h5["feats"][:] # pyright: ignore[reportIndexIssue] - ) - .float() - .to(device) - ) + feats_np = np.asarray(h5["feats"]) + feats = torch.from_numpy(feats_np).float().to(device) coords_info = get_coords(h5) coords_um = torch.from_numpy(coords_info.coords_um).float() stride_um = Microns(get_stride(coords_um)) @@ -322,9 +329,10 @@ def heatmaps_( model = load_model_from_ckpt(checkpoint_path).eval() # TODO: Update version when a newer model logic breaks heatmaps. - if Version(model.stamp_version) < Version("2.4.0"): + stamp_version = str(getattr(model, "stamp_version", "")) + if Version(stamp_version) < Version("2.4.0"): raise ValueError( - f"model has been built with stamp version {model.stamp_version} " + f"model has been built with stamp version {stamp_version} " f"which is incompatible with the current version." 
) @@ -356,7 +364,7 @@ def heatmaps_( with torch.no_grad(): scores = torch.softmax( - model.model.forward( + model.model( feats.unsqueeze(-2), coords=coords_um.unsqueeze(-2), mask=torch.zeros( diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 2caccecb..8ddfb03d 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -121,7 +121,7 @@ def categorical_crossval_( categories_for_export: ( dict[str, list] | list ) = [] # declare upfront to avoid unbound variable warnings - categories: Sequence[GroundTruth] | list | None = [] # type: ignore # declare upfront to avoid unbound variable warnings + categories: Sequence[GroundTruth] | list | None = [] if config.task == "classification": # Determine categories for training (single-target) and for export (supports multi-target) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 8b30ab11..eadb42f8 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -3,13 +3,13 @@ import logging from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import KW_ONLY, dataclass +from io import BytesIO # accept in _BinaryIOLike at runtime from itertools import groupby from pathlib import Path from typing import ( IO, Any, BinaryIO, - Dict, Final, Generic, List, @@ -23,6 +23,9 @@ import numpy as np import pandas as pd import torch + +# Use beartype's typing for PEP-585 deprecation-safe hints +from beartype.typing import Dict from packaging.version import Version from torch import Tensor from torch.utils.data import DataLoader, Dataset @@ -58,7 +61,9 @@ _EncodedTarget: TypeAlias = ( Tensor | dict[str, Tensor] ) # Union of encoded targets or multi-target dict -_BinaryIOLike: TypeAlias = Union[BinaryIO, IO[bytes]] +_BinaryIOLike: TypeAlias = Union[ + BinaryIO, IO[bytes], BytesIO +] # includes io.BytesIO for runtime checks """The ground truth, encoded numerically - classification: one-hot float [C] - regression: float [1] 
@@ -73,7 +78,7 @@ class PatientData(Generic[GroundTruthType]): _ = KW_ONLY ground_truth: GroundTruthType - feature_files: Iterable[FeaturePath | BinaryIO] + feature_files: Iterable[FeaturePath | _BinaryIOLike] def tile_bag_dataloader( @@ -533,9 +538,19 @@ def __getitem__( for bag_file in self.bags[index]: with h5py.File(bag_file, "r") as h5: if "feats" in h5: - arr = h5["feats"][:] # pyright: ignore[reportIndexIssue] # original STAMP files + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + arr = feats_obj[:] # original STAMP files else: - arr = h5["patch_embeddings"][:] # type: ignore # your Kronos files + embeddings_obj = h5["patch_embeddings"] + if not isinstance(embeddings_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'patch_embeddings' to be an HDF5 dataset but got {type(embeddings_obj)}" + ) + arr = embeddings_obj[:] # your Kronos files feats.append(torch.from_numpy(arr)) coords_um.append(torch.from_numpy(get_coords(h5).coords_um)) @@ -569,7 +584,7 @@ class PatientFeatureDataset(Dataset): def __init__( self, - feature_files: Sequence[FeaturePath | BinaryIO], + feature_files: Sequence[FeaturePath | _BinaryIOLike], ground_truths: Tensor, # shape: [num_samples, num_classes] transform: Callable[[Tensor], Tensor] | None, ): @@ -585,7 +600,12 @@ def __len__(self): def __getitem__(self, idx: int): feature_file = self.feature_files[idx] with h5py.File(feature_file, "r") as h5: - feats = torch.from_numpy(h5["feats"][:]) # pyright: ignore[reportIndexIssue] + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + feats = torch.from_numpy(feats_obj[:]) # Accept [V] or [1, V] if feats.ndim == 2 and feats.shape[0] == 1: feats = feats[0] @@ -634,10 +654,15 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: tile_size_px = 
TilePixels(0) return CoordsInfo(coords_um, tile_size_um, tile_size_px) - coords: np.ndarray = feature_h5["coords"][:] # type: ignore - coords_um: np.ndarray | None = None + coords_obj = feature_h5["coords"] + if not isinstance(coords_obj, h5py.Dataset): + raise RuntimeError( + f"{feature_h5.filename}: expected 'coords' to be an HDF5 dataset but got {type(coords_obj)}" + ) + coords: np.ndarray = coords_obj[:] tile_size_um: Microns | None = None tile_size_px: TilePixels | None = None + coords_um: np.ndarray | None = None if (tile_size := feature_h5.attrs.get("tile_size", None)) and feature_h5.attrs.get( "unit", None ) == "um": @@ -672,7 +697,15 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: ) if not tile_size_px and "tile_size_px" in feature_h5.attrs: - tile_size_px = TilePixels(int(feature_h5.attrs["tile_size_px"])) # pyright: ignore[reportArgumentType] + tile_size_px_attr = feature_h5.attrs.get("tile_size_px") + if tile_size_px_attr is not None and isinstance( + tile_size_px_attr, (int, float) + ): + tile_size_px = TilePixels(int(tile_size_px_attr)) + else: + raise RuntimeError( + "Invalid or missing 'tile_size_px' attribute in the feature file." 
+ ) if not tile_size_um or coords_um is None: raise RuntimeError( diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 10272328..d3b29ebd 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -276,7 +276,7 @@ def deploy_categorical_model_( for model_i, model in enumerate(models): predictions = _predict( model=model, - test_dl=test_dl, # pyright: ignore[reportPossiblyUnboundVariable] + test_dl=test_dl, patient_ids=patient_ids, accelerator=accelerator, ) diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index 86003c52..0b6a3885 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -3,11 +3,13 @@ import inspect from abc import ABC from collections.abc import Iterable, Sequence -from typing import Any, Mapping, TypeAlias +from typing import Any, TypeAlias import lightning -import numpy as np import torch + +# Use beartype.typing.Mapping to avoid PEP-585 deprecation warnings in beartype +from beartype.typing import Mapping from jaxtyping import Bool, Float from packaging.version import Version from torch import Tensor, nn, optim @@ -148,6 +150,19 @@ def on_train_batch_end(self, outputs, batch, batch_idx): ) +class _TileLevelMixin: + """Mixin for tile-level models providing shared MIL masking logic.""" + + @staticmethod + def _mask_from_bags(bags: Bags, bag_sizes: BagSizes) -> Bool[Tensor, "batch tile"]: + """Create attention mask for padded tiles in variable-length bags.""" + max_possible_bag_size = bags.size(1) + mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( + 0 + ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) + return mask + + class LitBaseClassifier(Base): """ PyTorch Lightning wrapper for tile level and patient level clasification. 
@@ -199,15 +214,16 @@ def __init__( self.class_weights = category_weights self.valid_auroc = MulticlassAUROC(len(categories)) # Number classes - self.categories = np.array(categories) + self.categories = list(categories) self.hparams.update({"task": "classification"}) -class LitTileClassifier(LitBaseClassifier): +class LitTileClassifier(_TileLevelMixin, LitBaseClassifier): """ - PyTorch Lightning wrapper for the model used in weakly supervised - learning settings, such as Multiple Instance Learning (MIL) for whole-slide images or patch-based data. + PyTorch Lightning wrapper for tile-level MIL classification. + + Used in weakly supervised settings for whole-slide images or patch-based data. """ supported_features = ["tile"] @@ -249,7 +265,6 @@ def _step( ) if step_name == "validation": - # TODO this is a bit ugly, we'd like to have `_step` without special cases self.valid_auroc.update(logits, targets.long().argmax(dim=-1)) self.log( f"{step_name}_auroc", @@ -291,31 +306,19 @@ def predict_step( # adding a mask here will *drastically* and *unbearably* increase memory usage return self.model(bags, coords=coords, mask=None) - def _mask_from_bags( - *, - bags: Bags, - bag_sizes: BagSizes, - ) -> Bool[Tensor, "batch tile"]: - max_possible_bag_size = bags.size(1) - mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( - 0 - ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) - - return mask - class LitSlideClassifier(LitBaseClassifier): - """ - PyTorch Lightning wrapper for MLPClassifier. 
- """ + """PyTorch Lightning wrapper for slide/patient-level classification.""" supported_features = ["slide"] def forward(self, x: Tensor) -> Tensor: return self.model(x) - def _step(self, batch, step_name: str): - feats, targets = batch + def _step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], step_name: str + ) -> Loss: + feats, targets = list(batch) # Works for both tuple and list logits = self.model(feats.float()) loss = nn.functional.cross_entropy( logits, @@ -341,17 +344,25 @@ def _step(self, batch, step_name: str): ) return loss - def training_step(self, batch, batch_idx): + def training_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "training") - def validation_step(self, batch, batch_idx): + def validation_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "validation") - def test_step(self, batch, batch_idx): + def test_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "test") - def predict_step(self, batch, batch_idx): - feats, _ = batch + def predict_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Tensor: + feats, _ = batch if isinstance(batch, tuple) else batch return self.model(feats) @@ -402,9 +413,10 @@ def _compute_loss(y_true: Tensor, y_pred: Tensor) -> Loss: return nn.functional.l1_loss(y_true, y_pred) -class LitTileRegressor(LitBaseRegressor): +class LitTileRegressor(_TileLevelMixin, LitBaseRegressor): """ - PyTorch Lightning wrapper for weakly supervised / MIL regression at tile/patient level. + PyTorch Lightning wrapper for tile-level MIL regression. + Produces a single continuous output per bag (dim_output = 1). 
""" @@ -491,38 +503,22 @@ def predict_step( # keep memory usage low as in classifier return self.model(bags, coords=coords, mask=None) - def _mask_from_bags( - *, - bags: Bags, - bag_sizes: BagSizes, - ) -> Bool[Tensor, "batch tile"]: - max_possible_bag_size = bags.size(1) - mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( - 0 - ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) - - return mask - class LitSlideRegressor(LitBaseRegressor): - """ - PyTorch Lightning wrapper for slide-level or patient-level regression. - Produces a single continuous output per slide (dim_output = 1). - """ + """PyTorch Lightning wrapper for slide/patient-level regression.""" supported_features = ["slide"] def forward(self, feats: Tensor) -> Tensor: - """Forward pass for slide-level features.""" return self.model(feats.float()) def _step( self, *, - batch: tuple[Tensor, Tensor], + batch: tuple[Tensor, Tensor] | list[Tensor], step_name: str, ) -> Loss: - feats, targets = batch + feats, targets = list(batch) # Works for both tuple and list preds = self.model(feats.float(), mask=None) # (B, 1) y = targets.to(preds).float() @@ -539,7 +535,6 @@ def _step( ) if step_name == "validation": - # same metrics as LitTileRegressor p = preds.squeeze(-1) t = y.squeeze(-1) self.log( @@ -552,17 +547,25 @@ def _step( return loss - def training_step(self, batch, batch_idx): + def training_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="training") - def validation_step(self, batch, batch_idx): + def validation_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="validation") - def test_step(self, batch, batch_idx): + def test_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="test") - def predict_step(self, batch, batch_idx): - feats, _ = batch + def 
predict_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Tensor: + feats, _ = batch if isinstance(batch, tuple) else batch return self.model(feats.float()) @@ -707,12 +710,12 @@ def on_train_epoch_end(self): self.hparams.update({"train_pred_median": self.train_pred_median}) -class LitTileSurvival(LitSurvivalBase): +class LitTileSurvival(_TileLevelMixin, LitSurvivalBase): """ - Tile-level or patch-level survival analysis. - Expects dataloader batches like: - (bags, coords, bag_sizes, targets) - where targets is shape (B,2): [:,0]=time, [:,1]=event (1=event, 0=censored). + Tile-level survival analysis with Cox proportional hazards loss. + + Expects batches: (bags, coords, bag_sizes, targets) + where targets.shape = (B, 2): [:,0]=time, [:,1]=event (0=censored, 1=event). """ supported_features = ["tile"] diff --git a/src/stamp/modeling/models/mlp.py b/src/stamp/modeling/models/mlp.py index e4f8881f..e88a77ca 100644 --- a/src/stamp/modeling/models/mlp.py +++ b/src/stamp/modeling/models/mlp.py @@ -29,7 +29,7 @@ def __init__( layers.append(nn.Dropout(dropout)) in_dim = dim_hidden layers.append(nn.Linear(in_dim, dim_output)) - self.mlp = nn.Sequential(*layers) # type: ignore + self.mlp = nn.Sequential(*layers) @jaxtyped(typechecker=beartype) def forward( diff --git a/src/stamp/preprocessing/extractor/ctranspath.py b/src/stamp/preprocessing/extractor/ctranspath.py index 387e7947..d189e279 100644 --- a/src/stamp/preprocessing/extractor/ctranspath.py +++ b/src/stamp/preprocessing/extractor/ctranspath.py @@ -518,7 +518,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1) + self.relative_position_index.view(-1) # pyright: ignore[reportCallIssue] ].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], diff --git a/src/stamp/preprocessing/tiling.py 
b/src/stamp/preprocessing/tiling.py index e09a06fa..ce684ba4 100644 --- a/src/stamp/preprocessing/tiling.py +++ b/src/stamp/preprocessing/tiling.py @@ -461,7 +461,10 @@ def _extract_mpp_from_metadata(slide: openslide.AbstractSlide) -> SlideMPP | Non return None doc = minidom.parseString(xml_path) collection = doc.documentElement - images = collection.getElementsByTagName("Image") # pyright: ignore[reportOptionalMemberAccess] + if collection is None: + _logger.error("Document element is None, unable to extract MPP.") + return None + images = collection.getElementsByTagName("Image") pixels = images[0].getElementsByTagName("Pixels") mpp = float(pixels[0].getAttribute("PhysicalSizeX")) except Exception: diff --git a/src/stamp/statistics/__init__.py b/src/stamp/statistics/__init__.py index bdbef1fa..b3243ecc 100644 --- a/src/stamp/statistics/__init__.py +++ b/src/stamp/statistics/__init__.py @@ -15,7 +15,10 @@ from matplotlib import pyplot as plt from pydantic import BaseModel, ConfigDict, Field -from stamp.statistics.categorical import categorical_aggregated_ +from stamp.statistics.categorical import ( + categorical_aggregated_, + categorical_aggregated_multitarget_, +) from stamp.statistics.prc import ( plot_multiple_decorated_precision_recall_curves, plot_single_decorated_precision_recall_curve, @@ -51,7 +54,7 @@ class StatsConfig(BaseModel): task: Task = Field(default="classification") output_dir: Path pred_csvs: list[Path] - ground_truth_label: PandasLabel | None = None + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None true_class: str | None = None time_label: str | None = None status_label: str | None = None @@ -60,52 +63,65 @@ class StatsConfig(BaseModel): _Inches = NewType("_Inches", float) -def compute_stats_( +def _compute_multitarget_classification_stats( *, - task: Task, output_dir: Path, pred_csvs: Sequence[Path], - ground_truth_label: PandasLabel | None = None, - true_class: str | None = None, - time_label: str | None = None, - 
status_label: str | None = None, + target_labels: Sequence[str], ) -> None: - """Compute and save statistics for the provided task and prediction CSVs. + """Compute statistics and plots for multi-target classification. - This wrapper keeps the external API stable while delegating the detailed - computations and plotting to the submodules under `stamp.statistics.*`. + For each target, creates ROC and PRC curves for each class, + similar to single-target classification. """ - match task: - case "classification": - if true_class is None or ground_truth_label is None: - raise ValueError( - "both true_class and ground_truth_label are required in statistic configuration" - ) - - preds_dfs = [ - _read_table( - p, - usecols=[ground_truth_label, f"{ground_truth_label}_{true_class}"], - dtype={ - ground_truth_label: str, - f"{ground_truth_label}_{true_class}": float, - }, - ) - for p in pred_csvs - ] - - y_trues = [ - np.array(df[ground_truth_label] == true_class) for df in preds_dfs - ] - y_preds = [ - np.array(df[f"{ground_truth_label}_{true_class}"].values) - for df in preds_dfs - ] - n_bootstrap_samples = 1000 - figure_width = _Inches(3.8) - threshold_cmap = None - - roc_curve_figure_aspect_ratio = 1.08 + output_dir.mkdir(parents=True, exist_ok=True) + n_bootstrap_samples = 1000 + figure_width = _Inches(3.8) + roc_curve_figure_aspect_ratio = 1.08 + + # Validate all target labels exist in CSV + first_df = _read_table(pred_csvs[0], nrows=0) + missing_targets = [t for t in target_labels if t not in first_df.columns] + if missing_targets: + raise ValueError( + f"Target labels not found in CSV: {missing_targets}. 
Available columns: {list(first_df.columns)}" + ) + + # Process each target + for target_label in target_labels: + # Load data for this target + preds_dfs = [] + for p in pred_csvs: + df = _read_table(p, dtype=str) + # Only keep rows where this target has ground truth + df_clean = df.dropna(subset=[target_label]) + if len(df_clean) > 0: + preds_dfs.append(df_clean) + + if not preds_dfs: + continue + + # Get unique classes for this target + classes = sorted(preds_dfs[0][target_label].unique()) + + # Create plots for each class in this target + for true_class in classes: + # Extract ground truth and predictions for this class + y_trues = [] + y_preds = [] + + for df in preds_dfs: + prob_col = f"{target_label}_{true_class}" + if prob_col not in df.columns: + continue + + y_trues.append(np.array(df[target_label] == true_class)) + y_preds.append(np.array(df[prob_col].astype(float).values)) + + if not y_trues: + continue + + # Plot ROC curve fig, ax = plt.subplots( figsize=(figure_width, figure_width * roc_curve_figure_aspect_ratio), dpi=300, @@ -116,34 +132,35 @@ def compute_stats_( ax=ax, y_true=y_trues[0], y_score=y_preds[0], - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, - threshold_cmap=threshold_cmap, + threshold_cmap=None, ) else: plot_multiple_decorated_roc_curves( ax=ax, y_trues=y_trues, y_scores=y_preds, - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=None, ) fig.tight_layout() - output_dir.mkdir(parents=True, exist_ok=True) - fig.savefig(output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg") + fig.savefig(output_dir / f"roc-curve_{target_label}={true_class}.svg") plt.close(fig) + # Plot PRC curve fig, ax = plt.subplots( figsize=(figure_width, figure_width * roc_curve_figure_aspect_ratio), dpi=300, ) + if len(preds_dfs) == 1: plot_single_decorated_precision_recall_curve( ax=ax, y_true=y_trues[0], 
y_score=y_preds[0], - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, ) else: @@ -151,24 +168,170 @@ def compute_stats_( ax=ax, y_trues=y_trues, y_scores=y_preds, - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", ) fig.tight_layout() - fig.savefig(output_dir / f"pr-curve_{ground_truth_label}={true_class}.svg") + fig.savefig(output_dir / f"pr-curve_{target_label}={true_class}.svg") plt.close(fig) - categorical_aggregated_( - preds_csvs=pred_csvs, - ground_truth_label=ground_truth_label, - outpath=output_dir, + # Compute aggregated statistics for all targets + categorical_aggregated_multitarget_( + preds_csvs=pred_csvs, + outpath=output_dir, + target_labels=target_labels, + ) + + +def compute_stats_( + *, + task: Task, + output_dir: Path, + pred_csvs: Sequence[Path], + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None, + true_class: str | None = None, + time_label: str | None = None, + status_label: str | None = None, +) -> None: + """Compute and save statistics for the provided task and prediction CSVs. + + This wrapper keeps the external API stable while delegating the detailed + computations and plotting to the submodules under `stamp.statistics.*`. 
+ """ + match task: + case "classification": + # Check if multi-target based on ground_truth_label type + is_multitarget = ( + isinstance(ground_truth_label, (list, tuple)) + and len(ground_truth_label) > 1 ) + if is_multitarget: + # Multi-target classification + if not isinstance(ground_truth_label, (list, tuple)): + raise ValueError( + "ground_truth_label must be a list or tuple for multi-target classification" + ) + _compute_multitarget_classification_stats( + output_dir=output_dir, + pred_csvs=pred_csvs, + target_labels=list(ground_truth_label), + ) + else: + # Single-target classification (original behavior) + if true_class is None or ground_truth_label is None: + raise ValueError( + "both true_class and ground_truth_label are required in statistic configuration" + ) + if not isinstance(ground_truth_label, str): + raise ValueError( + "ground_truth_label must be a string for single-target classification" + ) + + preds_dfs = [ + _read_table( + p, + usecols=[ + ground_truth_label, + f"{ground_truth_label}_{true_class}", + ], + dtype={ + ground_truth_label: str, + f"{ground_truth_label}_{true_class}": float, + }, + ) + for p in pred_csvs + ] + + y_trues = [ + np.array(df[ground_truth_label] == true_class) for df in preds_dfs + ] + y_preds = [ + np.array(df[f"{ground_truth_label}_{true_class}"].values) + for df in preds_dfs + ] + n_bootstrap_samples = 1000 + figure_width = _Inches(3.8) + threshold_cmap = None + + roc_curve_figure_aspect_ratio = 1.08 + fig, ax = plt.subplots( + figsize=( + figure_width, + figure_width * roc_curve_figure_aspect_ratio, + ), + dpi=300, + ) + + if len(preds_dfs) == 1: + plot_single_decorated_roc_curve( + ax=ax, + y_true=y_trues[0], + y_score=y_preds[0], + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=n_bootstrap_samples, + threshold_cmap=threshold_cmap, + ) + else: + plot_multiple_decorated_roc_curves( + ax=ax, + y_trues=y_trues, + y_scores=y_preds, + title=f"{ground_truth_label} = {true_class}", + 
n_bootstrap_samples=None, + ) + + fig.tight_layout() + output_dir.mkdir(parents=True, exist_ok=True) + fig.savefig( + output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg" + ) + plt.close(fig) + + fig, ax = plt.subplots( + figsize=( + figure_width, + figure_width * roc_curve_figure_aspect_ratio, + ), + dpi=300, + ) + if len(preds_dfs) == 1: + plot_single_decorated_precision_recall_curve( + ax=ax, + y_true=y_trues[0], + y_score=y_preds[0], + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=n_bootstrap_samples, + ) + else: + plot_multiple_decorated_precision_recall_curves( + ax=ax, + y_trues=y_trues, + y_scores=y_preds, + title=f"{ground_truth_label} = {true_class}", + ) + + fig.tight_layout() + fig.savefig( + output_dir / f"pr-curve_{ground_truth_label}={true_class}.svg" + ) + plt.close(fig) + + categorical_aggregated_( + preds_csvs=pred_csvs, + ground_truth_label=ground_truth_label, + outpath=output_dir, + ) + case "regression": if ground_truth_label is None: raise ValueError( "no ground_truth_label configuration supplied in statistic" ) + if not isinstance(ground_truth_label, str): + raise ValueError( + "ground_truth_label must be a string for regression (multi-target regression not yet supported)" + ) regression_aggregated_( preds_csvs=pred_csvs, ground_truth_label=ground_truth_label, diff --git a/src/stamp/statistics/categorical.py b/src/stamp/statistics/categorical.py index 30a03a86..9d6c4c12 100755 --- a/src/stamp/statistics/categorical.py +++ b/src/stamp/statistics/categorical.py @@ -62,29 +62,29 @@ def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: # roc_auc stats_df["roc_auc_score"] = [ - metrics.roc_auc_score(y_true == cat, y_pred[:, i]) # pyright: ignore[reportCallIssue,reportArgumentType] + metrics.roc_auc_score(y_true == cat, y_pred[:, i]) for i, cat in enumerate(categories) ] # average_precision stats_df["average_precision_score"] = [ - metrics.average_precision_score(y_true == cat, y_pred[:, i]) # 
pyright: ignore[reportCallIssue,reportArgumentType] + metrics.average_precision_score(y_true == cat, y_pred[:, i]) for i, cat in enumerate(categories) ] # f1 score y_pred_labels = categories[y_pred.argmax(axis=1)] stats_df["f1_score"] = [ - metrics.f1_score(y_true == cat, y_pred_labels == cat) # pyright: ignore[reportCallIssue,reportArgumentType] - for cat in categories + metrics.f1_score(y_true == cat, y_pred_labels == cat) for cat in categories ] # p values p_values = [] for i, cat in enumerate(categories): - pos_scores = y_pred[:, i][y_true == cat] # pyright: ignore[reportCallIssue,reportArgumentType] - neg_scores = y_pred[:, i][y_true != cat] # pyright: ignore[reportCallIssue,reportArgumentType] - p_values.append(st.ttest_ind(pos_scores, neg_scores).pvalue) # pyright: ignore[reportGeneralTypeIssues, reportAttributeAccessIssue] + pos_scores = y_pred[:, i][y_true == cat] + neg_scores = y_pred[:, i][y_true != cat] + _, p_value = st.ttest_ind(pos_scores, neg_scores) + p_values.append(p_value) stats_df["p_value"] = p_values assert set(_score_labels) & set(stats_df.columns) == set(_score_labels) diff --git a/src/stamp/utils/target_file.py b/src/stamp/utils/target_file.py deleted file mode 100644 index 08c30b6b..00000000 --- a/src/stamp/utils/target_file.py +++ /dev/null @@ -1,351 +0,0 @@ -"""Automatically generate target information from clini table - -# The `barspoon-targets 2.0` File Format - -A barspoon target file is a [TOML][1] file with the following entries: - - - A `version` key mapping to a version string `"barspoon-targets "`, where - `` is a [PEP-440 version string][2] compatible with `2.0`. - - A `targets` table, the keys of which are target labels (as found in the - clinical table) and the values specify exactly one of the following: - 1. A categorical target label, marked by the presence of a `categories` - key-value pair. - 2. A target label to quantize, marked by the presence of a `thresholds` - key-value pair. - 3. 
A target format defined in in a later version of barspoon targets. - A target may only ever have one of the fields `categories` or `thresholds`. - A definition of these entries can be found below. - -[1]: https://toml.io "Tom's Obvious Minimal Language" -[2]: https://peps.python.org/pep-0440/ - "PEP 440 - Version Identification and Dependency Specification" - -## Categorical Target Label - -A categorical target is a target table with a key-value pair `categories`. -`categories` contains a list of lists of literal strings. Each list of strings -will be treated as one category, with all literal strings within that list being -treated as one representative for that category. This allows the user to easily -group related classes into one large class (i.e. `"True", "1", "Yes"` could all -be unified into the same category). - -### Category Weights - -It is possible to assign a weight to each category, to e.g. weigh rarer classes -more heavily. The weights are stored in a table `targets.LABEL.class_weights`, -whose keys is the first representative of each category, and the values of which -is the weight of the category as a floating point number. - -## Target Label to Quantize - -If a target has the `thresholds` option key set, it is interpreted as a -continuous target which has to be quantized. `thresholds` has to be a list of -floating point numbers [t_0, t_n], n > 1 containing the thresholds of the bins -to quantize the values into. A categorical target will be quantized into bins - -```asciimath -b_0 = [t_0; t_1], b_1 = (t_1; b_2], ... b_(n-1) = (t_(n-1); t_n] -``` - -The bins will be treated as categories with names -`f"[{t_0:+1.2e};{t_1:+1.2e}]"` for the first bin and -`f"({t_i:+1.2e};{t_(i+1):+1.2e}]"` for all other bins - -To avoid confusion, we recommend to also format the `thresholds` list the same -way. - -The bins can also be weighted. See _Categorical Target Label: Category Weights_ -for details. 
- - > Experience has shown that many labels contain non-negative values with a - > disproportionate amount (more than n_samples/n_bins) of zeroes. We thus - > decided to make the _right_ side of each bin inclusive, as the bin (-A,0] - > then naturally includes those zero values. -""" - -import logging -from pathlib import Path -from typing import ( - Any, - Dict, - List, - NamedTuple, - Optional, - Sequence, - TextIO, - Tuple, -) - -import numpy as np -import numpy.typing as npt -import pandas as pd -import torch -import torch.nn.functional as F -from packaging.specifiers import Specifier - - -def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: - if not isinstance(path, Path): - return pd.read_csv(path, **kwargs) - elif path.suffix == ".xlsx": - return pd.read_excel(path, **kwargs) - elif path.suffix == ".csv": - return pd.read_csv(path, **kwargs) - else: - raise ValueError( - "table to load has to either be an excel (`*.xlsx`) or csv (`*.csv`) file." - ) - - -__all__ = ["build_targets", "decode_targets"] - - -class TargetSpec(NamedTuple): - version: str - targets: Dict[str, Dict[str, Any]] - - -class EncodedTarget(NamedTuple): - categories: List[str] - encoded: torch.Tensor - weight: torch.Tensor - - -def encode_category( - *, - clini_df: pd.DataFrame, - target_label: str, - categories: Sequence[List[str]], - class_weights: Optional[Dict[str, float]] = None, - **ignored, -) -> Tuple[List[str], torch.Tensor, torch.Tensor]: - # Map each category to its index - category_map = {member: idx for idx, cat in enumerate(categories) for member in cat} - - # Map each item to it's category's index, mapping nans to num_classes+1 - # This way we can easily discard the NaN column later - indexes = clini_df[target_label].map(lambda c: category_map.get(c, len(categories))) - indexes = torch.tensor(indexes.values) - - # Discard nan column - one_hot = F.one_hot(indexes, num_classes=len(categories) + 1)[:, :-1] - - # Class weights - if class_weights is not None: - weight = 
torch.tensor([class_weights[c[0]] for c in categories]) - else: - # No class weights given; use normalized inverse frequency - counts = one_hot.sum(dim=0) - weight = (w := (counts.sum() / counts)) / w.sum() - - # Warn user of unused labels - if ignored: - logging.warn(f"ignored labels in target {target_label}: {ignored}") - - return [c[0] for c in categories], one_hot, weight - - -def encode_quantize( - *, - clini_df: pd.DataFrame, - target_label: str, - thresholds: npt.NDArray[np.floating[Any]], - class_weights: Optional[Dict[str, float]] = None, - **ignored, -) -> Tuple[List[str], torch.Tensor, torch.Tensor]: - # Warn user of unused labels - if ignored: - logging.warn(f"ignored labels in target {target_label}: {ignored}") - - n_bins = len(thresholds) - 1 - numeric_vals = torch.tensor(pd.to_numeric(clini_df[target_label]).values).reshape( - -1, 1 - ) - - # Map each value to a class index as follows: - # 1. If the value is NaN or less than the left-most threshold, use class - # index 0 - # 2. If it is between the left-most and the right-most threshold, set it to - # the bin number (starting from 1) - # 3. 
If it is larger than the right-most threshold, set it to N_bins + 1 - bin_index = ( - (numeric_vals > torch.tensor(thresholds).reshape(1, -1)).count_nonzero(1) - # For the first bucket, we have to include the lower threshold - + (numeric_vals.reshape(-1) == thresholds[0]) - ) - # One hot encode and discard nan columns (first and last col) - one_hot = F.one_hot(bin_index, num_classes=n_bins + 2)[:, 1:-1] - - # Class weights - categories = [ - f"[{thresholds[0]:+1.2e};{thresholds[1]:+1.2e}]", - *( - f"({lower:+1.2e};{upper:+1.2e}]" - for lower, upper in zip(thresholds[1:-1], thresholds[2:], strict=True) - ), - ] - - if class_weights is not None: - weight = torch.tensor([class_weights[c] for c in categories]) - else: - # No class weights given; use normalized inverse frequency - counts = one_hot.sum(0) - weight = (w := (np.divide(counts.sum(), counts, where=counts > 0))) / w.sum() - - return categories, one_hot, weight - - -def decode_targets( - encoded: torch.Tensor, - *, - target_labels: Sequence[str], - targets: Dict[str, Any], - version: str = "barspoon-targets 2.0", - **ignored, -) -> List[np.ndarray]: - name, version = version.split(" ") - spec = Specifier("~=2.0") - - if not (name == "barspoon-targets" and spec.contains(version)): - raise ValueError( - f"incompatible target file: expected barspoon-targets{spec}, found `{name} {version}`" - ) - - # Warn user of unused labels - if ignored: - logging.warn(f"ignored parameters: {ignored}") - - decoded_targets = [] - curr_col = 0 - for target_label in target_labels: - info = targets[target_label] - - if (categories := info.get("categories")) is not None: - # Add another column which is one iff all the other values are zero - encoded_target = encoded[:, curr_col : curr_col + len(categories)] - is_none = ~encoded_target.any(dim=1).view(-1, 1) - encoded_target = torch.cat([encoded_target, is_none], dim=1) - - # Decode to class labels - representatives = np.array([c[0] for c in categories] + [None]) - category_index = 
encoded_target.argmax(dim=1) - decoded = representatives[category_index] - decoded_targets.append(decoded) - - curr_col += len(categories) - - elif (thresholds := info.get("thresholds")) is not None: - n_bins = len(thresholds) - 1 - encoded_target = encoded[:, curr_col : curr_col + n_bins] - is_none = ~encoded_target.any(dim=1).view(-1, 1) - encoded_target = torch.cat([encoded_target, is_none], dim=1) - - bin_edges = [-np.inf, *thresholds, np.inf] - representatives = np.array( - [ - f"[{lower:+1.2e};{upper:+1.2e})" - for lower, upper in zip(bin_edges[:-1], bin_edges[1:]) - ] - ) - decoded = representatives[encoded_target.argmax(dim=1)] - - decoded_targets.append(decoded) - - curr_col += n_bins - - else: - raise ValueError(f"cannot decode {target_label}: no target info") - - return decoded_targets - - -def build_targets( - *, - clini_tables: Sequence[Path], - categorical_labels: Sequence[str], - category_min_count: int = 32, - quantize: Sequence[tuple[str, int]] = (), -) -> Dict[str, EncodedTarget]: - clini_df = pd.concat([read_table(c) for c in clini_tables]) - encoded_targets: Dict[str, EncodedTarget] = {} - - # categorical targets - for target_label in categorical_labels: - counts = clini_df[target_label].value_counts() - well_supported = counts[counts >= category_min_count] - - if len(well_supported) <= 1: - continue - - categories = [[str(cat)] for cat in well_supported.index] - - weights = well_supported.sum() / well_supported - weights /= weights.sum() - - representatives, encoded, weight = encode_category( - clini_df=clini_df, - target_label=target_label, - categories=categories, - class_weights=weights.to_dict(), - ) - - encoded_targets[target_label] = EncodedTarget( - categories=representatives, - encoded=encoded, - weight=weight, - ) - - # quantized targets - for target_label, bincount in quantize: - vals = pd.to_numeric(clini_df[target_label]).dropna() - - if vals.empty: - continue - - vals_clamped = vals.replace( - { - -np.inf: vals[vals != 
-np.inf].min(), - np.inf: vals[vals != np.inf].max(), - } - ) - - thresholds = np.array( - [ - -np.inf, - *np.quantile(vals_clamped, q=np.linspace(0, 1, bincount + 1))[1:-1], - np.inf, - ], - dtype=float, - ) - - representatives, encoded, weight = encode_quantize( - clini_df=clini_df, - target_label=target_label, - thresholds=thresholds, - ) - - if encoded.shape[1] <= 1: - continue - - encoded_targets[target_label] = EncodedTarget( - categories=representatives, - encoded=encoded, - weight=weight, - ) - - return encoded_targets - - -if __name__ == "__main__": - encoded = build_targets( - clini_tables=[ - Path( - "/mnt/bulk-neptune/nguyenmin/stamp-dev/experiments/survival_prediction/TCGA-CRC-DX_CLINI.xlsx" - ) - ], - categorical_labels=["BRAF", "KRAS", "NRAS"], - category_min_count=32, - quantize=[], - ) - for name, enc in encoded.items(): - print(name, enc.encoded.shape) diff --git a/tests/random_data.py b/tests/random_data.py index c7c36880..b79c9f42 100644 --- a/tests/random_data.py +++ b/tests/random_data.py @@ -74,13 +74,13 @@ def create_random_dataset( clini_df = pd.DataFrame( patient_to_ground_truth.items(), - columns=["patient", "ground-truth"], # pyright: ignore[reportArgumentType] + columns=["patient", "ground-truth"], ) clini_df.to_csv(clini_path, index=False) slide_df = pd.DataFrame( slide_path_to_patient.items(), - columns=["slide_path", "patient"], # pyright: ignore[reportArgumentType] + columns=["slide_path", "patient"], ) slide_df.to_csv(slide_path, index=False) @@ -130,7 +130,7 @@ def create_random_regression_dataset( # --- Write clini + slide tables --- clini_df = pd.DataFrame(patient_to_target, columns=["patient", "target"]) - clini_df["target"] = clini_df["target"].astype(float) # ✅ ensure numeric dtype + clini_df["target"] = clini_df["target"].astype(float) # ensure numeric dtype clini_df.to_csv(clini_path, index=False) slide_df = pd.DataFrame( diff --git a/tests/test_data.py b/tests/test_data.py index 3c86a931..a60d830e 100644 --- 
a/tests/test_data.py +++ b/tests/test_data.py @@ -3,6 +3,7 @@ from pathlib import Path import h5py +import pandas as pd import pytest import torch from random_data import ( @@ -21,6 +22,7 @@ PatientFeatureDataset, filter_complete_patient_data_, get_coords, + patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, ) from stamp.types import ( @@ -85,6 +87,31 @@ def test_get_cohort_df(tmp_path: Path) -> None: } +def test_patient_to_ground_truth_multi_target(tmp_path: Path) -> None: + """Verify multi-target clini parsing returns dicts and drops rows missing all targets.""" + df = pd.DataFrame( + { + "patient": ["p1", "p2", "p3", "p4"], + "subtype": ["A", None, "B", None], + "grade": ["1", "2", None, None], + } + ) + df.to_csv(tmp_path / "clini.csv", index=False) + + result = patient_to_ground_truth_from_clini_table_( + clini_table_path=tmp_path / "clini.csv", + patient_label="patient", + ground_truth_label=["subtype", "grade"], + ) + + # p4 has both targets missing → dropped + assert "p4" not in result + + assert result["p1"] == {"subtype": "A", "grade": "1"} + assert result["p2"] == {"subtype": None, "grade": "2"} + assert result["p3"] == {"subtype": "B", "grade": None} + + @pytest.mark.parametrize( "feature_file_creator", [make_feature_file, make_old_feature_file], diff --git a/tests/test_deployment.py b/tests/test_deployment.py index 4e1570cc..7d1d6589 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -167,7 +167,7 @@ def test_to_prediction_df(task: str) -> None: ) if task == "classification": preds_df = _to_prediction_df( - categories=list(model.categories), # type: ignore + categories=list(cast(list, model.categories)), patient_to_ground_truth={ PatientId("pat5"): GroundTruth("foo"), PatientId("pat6"): None, @@ -196,13 +196,13 @@ def test_to_prediction_df(task: str) -> None: # Check if no loss / target is given for targets with missing ground truths no_ground_truth = preds_df[preds_df["patient"].isin(["pat6"])] - assert 
no_ground_truth["target"].isna().all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] - assert no_ground_truth["loss"].isna().all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] + assert no_ground_truth["target"].isna().all() + assert no_ground_truth["loss"].isna().all() # Check if loss / target is given for targets with ground truths with_ground_truth = preds_df[preds_df["patient"].isin(["pat5", "pat7"])] - assert (~with_ground_truth["target"].isna()).all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] - assert (~with_ground_truth["loss"].isna()).all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] + assert (~with_ground_truth["target"].isna()).all() + assert (~with_ground_truth["loss"].isna()).all() elif task == "regression": patient_to_ground_truth = {} From f08e33ad705e3d028df75c643d2b1cee00d8ebf6 Mon Sep 17 00:00:00 2001 From: mducducd Date: Tue, 17 Feb 2026 16:16:41 +0000 Subject: [PATCH 05/11] refactor survival training/validation --- src/stamp/modeling/crossval.py | 218 ++++++++++++++++++-------- src/stamp/modeling/data.py | 140 +++++++++++------ src/stamp/modeling/deploy.py | 14 +- src/stamp/modeling/models/__init__.py | 43 +++-- src/stamp/modeling/train.py | 173 ++++++++++++++++++-- src/stamp/types.py | 2 + 6 files changed, 453 insertions(+), 137 deletions(-) diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 8ddfb03d..26196065 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -22,7 +22,7 @@ _to_survival_prediction_df, load_model_from_ckpt, ) -from stamp.modeling.train import setup_model_for_training, train_model_ +from stamp.modeling.train import setup_model_from_dataloaders, train_model_ from stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import ( GroundTruth, @@ -80,19 +80,28 @@ def categorical_crossval_( # Generate the splits, or load them from the splits file if they 
already exist if not splits_file.exists(): - splits = ( - _get_splits( - patient_to_data=patient_to_data, - n_splits=config.n_splits, - spliter=KFold, - ) - if config.task == "regression" - else _get_splits( - patient_to_data=patient_to_data, - n_splits=config.n_splits, - spliter=StratifiedKFold, - ) + # Detect multi-target classification (ground_truth is a dict) + is_multitarget = any( + isinstance(pd.ground_truth, dict) for pd in patient_to_data.values() ) + + # Use KFold for regression or multi-target classification; otherwise StratifiedKFold. + # For survival we want StratifiedKFold so folds are balanced by event status. + spliter = ( + KFold + if (config.task == "regression" or is_multitarget) + else StratifiedKFold + ) + + _logger.info(f"Using {spliter.__name__} for cross-validation splits") + + splits = _get_splits( + patient_to_data=patient_to_data, + n_splits=config.n_splits, + spliter=spliter, + task=config.task, + ) + with open(splits_file, "w") as fp: fp.write(splits.model_dump_json(indent=4)) else: @@ -183,48 +192,93 @@ def categorical_crossval_( # Train the model if not (split_dir / "model.ckpt").exists(): - model, train_dl, valid_dl = setup_model_for_training( - clini_table=config.clini_table, - slide_table=config.slide_table, - feature_dir=config.feature_dir, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - advanced=advanced, - task=config.task, - patient_to_data={ - patient_id: patient_data - for patient_id, patient_data in patient_to_data.items() - if patient_id in split.train_patients - }, - categories=( - categories - if categories is not None - else ( - sorted( - { - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - and not isinstance(patient_data.ground_truth, dict) - } - ) - if not isinstance(config.ground_truth_label, Sequence) - else None + # Build train and test dataloaders directly (pure 2-way k-fold 
split) + train_patient_ids = [ + pid for pid in split.train_patients if pid in patient_to_data + ] + test_patient_ids = [ + pid for pid in split.test_patients if pid in patient_to_data + ] + train_patient_data = [patient_to_data[pid] for pid in train_patient_ids] + test_patient_data = [patient_to_data[pid] for pid in test_patient_ids] + + fold_categories = ( + categories + if categories is not None + else ( + sorted( + { + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + and not isinstance(patient_data.ground_truth, dict) + } ) - ), - train_transform=( - VaryPrecisionTransform(min_fraction_bits=1) - if config.use_vary_precision_transform + if not isinstance(config.ground_truth_label, Sequence) else None - ), + ) + ) + + train_transform = ( + VaryPrecisionTransform(min_fraction_bits=1) + if config.use_vary_precision_transform + else None + ) + + train_dl, train_categories = create_dataloader( + feature_type=feature_type, + task=config.task, + patient_data=train_patient_data, + bag_size=advanced.bag_size, + batch_size=advanced.batch_size, + shuffle=True, + num_workers=advanced.num_workers, + transform=train_transform, + categories=fold_categories, + ) + test_dl, _ = create_dataloader( + feature_type=feature_type, + task=config.task, + patient_data=test_patient_data, + bag_size=None, + batch_size=1, + shuffle=False, + num_workers=advanced.num_workers, + transform=None, + categories=train_categories, + ) + + # Infer feature dimension + batch = next(iter(train_dl)) + if feature_type == "tile": + bags, _, _, _ = batch + dim_feats = bags.shape[-1] + else: + feats, _ = batch + dim_feats = feats.shape[-1] + + model = setup_model_from_dataloaders( + train_dl=train_dl, + valid_dl=test_dl, + task=config.task, + train_categories=train_categories, + dim_feats=dim_feats, + train_patients=train_patient_ids, + valid_patients=test_patient_ids, feature_type=feature_type, + advanced=advanced, + 
ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + clini_table=config.clini_table, + slide_table=config.slide_table, + feature_dir=config.feature_dir, ) model = train_model_( output_dir=split_dir, model=model, train_dl=train_dl, - valid_dl=valid_dl, + valid_dl=test_dl, max_epochs=advanced.max_epochs, patience=advanced.patience, accelerator=advanced.accelerator, @@ -235,9 +289,9 @@ def categorical_crossval_( else: model = load_model_from_ckpt(split_dir / "model.ckpt") - # Deploy on test set + # Deploy on test fold (used as validation/prediction set) if not (split_dir / "patient-preds.csv").exists(): - # Prepare test dataloader + # Prepare validation dataloader for predictions test_patients = [ pid for pid in split.test_patients if pid in patient_to_data ] @@ -262,19 +316,24 @@ def categorical_crossval_( ) if config.task == "survival": - if isinstance(config.ground_truth_label, str): + # Export only when patients have single-target survival labels ("time status"). + # Don't rely on `ground_truth_label` for survival — check the loaded patient data. 
+ if any(isinstance(gt, dict) for gt in patient_to_ground_truth.values()): + _logger.warning( + "Multi-target survival prediction export not yet supported; skipping CSV save" + ) + else: _to_survival_prediction_df( patient_to_ground_truth=cast( - Mapping[PatientId, str | None], patient_to_ground_truth + Mapping[ + PatientId, str | tuple[float | None, int | None] | None + ], + patient_to_ground_truth, ), predictions=cast(Mapping[PatientId, torch.Tensor], predictions), patient_label=config.patient_label, cut_off=getattr(model.hparams, "train_pred_median", None), ).to_csv(split_dir / "patient-preds.csv", index=False) - else: - _logger.warning( - "Multi-target survival prediction export not yet supported; skipping CSV save" - ) elif config.task == "regression": if config.ground_truth_label is None: raise RuntimeError("Grounf truth label is required for regression") @@ -310,27 +369,56 @@ def categorical_crossval_( def _get_splits( - *, patient_to_data: Mapping[PatientId, PatientData[Any]], n_splits: int, spliter + *, + patient_to_data: Mapping[PatientId, PatientData[Any]], + n_splits: int, + spliter, + task: str | None = None, ) -> _Splits: patients = np.array(list(patient_to_data.keys())) - # Extract ground truth for stratification. - # For multi-target (dict), use the first target's value - y_strat = np.array( - [ - next(iter(gt.values())) if isinstance(gt, dict) else gt - for gt in [patient.ground_truth for patient in patient_to_data.values()] - ] - ) + # Build stratification labels depending on the task + gts = [patient.ground_truth for patient in patient_to_data.values()] + + if task == "survival": + # use event status (0/1) for stratification + # Ground-truths are expected to be pre-parsed into (time, status) tuples + # by the data loading pipeline; extract the status element directly. 
+ statuses: list[int] = [] + for gt in gts: + # support multi-target fallback: use first target + val = next(iter(gt.values())) if isinstance(gt, dict) else gt + # If structured (time, status) use second element, otherwise assume val is status + if isinstance(val, (tuple, list)) and len(val) == 2: + status_val = val[1] + else: + status_val = val + statuses.append(int(cast(int, status_val)) if status_val is not None else 0) + + y_strat = np.array(statuses) + elif task == "classification": + # For multi-target (dict), use the first target's value + y_strat = np.array( + [next(iter(gt.values())) if isinstance(gt, dict) else gt for gt in gts] + ) + else: + # regression or unknown: do not stratify (KFold will ignore y) + y_strat = None skf = spliter(n_splits=n_splits, shuffle=True, random_state=0) + + if y_strat is None: + splits_iter = skf.split(patients) + else: + splits_iter = skf.split(patients, y_strat) + splits = _Splits( splits=[ _Split( train_patients=set(patients[train_indices]), test_patients=set(patients[test_indices]), ) - for train_indices, test_indices in skf.split(patients, y_strat) + for train_indices, test_indices in splits_iter ] ) return splits diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index eadb42f8..6ddc00fd 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -121,6 +121,7 @@ def tile_bag_dataloader( bag_size=bag_size, ground_truths=targets, transform=transform, + deterministic=(not shuffle), ) dl = DataLoader( ds, @@ -229,9 +230,20 @@ def _parse_targets( events.append(np.nan) continue - time_str, status_str = gt.split(" ", 1) - times.append(np.nan if time_str.lower() == "nan" else float(time_str)) - events.append(_parse_survival_status(status_str)) + # Accept either structured tuple/list (time, event) or the legacy + # string form "time status". 
+ if isinstance(gt, (tuple, list)) and len(gt) == 2: + t_val, e_val = gt + times.append( + np.nan + if t_val is None or str(t_val).lower() == "nan" + else float(t_val) + ) + events.append(float(e_val) if e_val is not None else np.nan) + else: + time_str, status_str = str(gt).split(" ", 1) + times.append(np.nan if time_str.lower() == "nan" else float(time_str)) + events.append(_parse_survival_status(status_str)) y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) return y, [] @@ -356,22 +368,20 @@ def create_dataloader( if gt is None: continue if isinstance(gt, dict): - # Use first value for multi-target regression - first_val = next(iter(gt.values())) - values.append(float(first_val)) - else: - values.append(float(gt)) + raise ValueError( + "Multi-target regression is not supported; provide a single numeric target per patient" + ) + values.append(float(gt)) labels = torch.tensor(values, dtype=torch.float32).reshape(-1, 1) elif task == "survival": times, events = [], [] for p in patient_data: if isinstance(p.ground_truth, dict): - # Multi-target survival: use first target - val = list(p.ground_truth.values())[0] - t, e = (val or "nan nan").split(" ", 1) - else: - t, e = (p.ground_truth or "nan nan").split(" ", 1) + raise ValueError( + "Multi-target survival is not supported; provide a single survival time/status per patient" + ) + t, e = (p.ground_truth or "nan nan").split(" ", 1) times.append(float(t) if t.lower() != "nan" else np.nan) events.append(_parse_survival_status(e)) @@ -449,6 +459,15 @@ def load_patient_level_data( """ # Load ground truth mapping + # Multi-target ground truths are only supported for classification. 
+ if task is not None and task != "classification": + if isinstance(ground_truth_label, Sequence) and not isinstance( + ground_truth_label, str + ): + raise ValueError( + "Multi-target ground_truth_label is only supported for classification tasks" + ) + if task == "survival" and time_label is not None and status_label is not None: # Survival: use the existing helper patient_to_ground_truth = patient_to_survival_from_clini_table_( @@ -519,6 +538,7 @@ class BagDataset(Dataset[tuple[_Bag, _Coordinates, BagSize, _EncodedTarget]]): # """The ground truth for each bag, one-hot encoded.""" transform: Callable[[Tensor], Tensor] | None + deterministic: bool = False def __post_init__(self) -> None: if len(self.bags) != len(self.ground_truths): @@ -564,7 +584,12 @@ def __getitem__( # Sample a subset, if required if self.bag_size is not None: return ( - *_to_fixed_size_bag(feats, coords=coords_um, bag_size=self.bag_size), + *_to_fixed_size_bag( + feats, + coords=coords_um, + bag_size=self.bag_size, + deterministic=self.deterministic, + ), self.ground_truths[index], ) else: @@ -697,15 +722,7 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: ) if not tile_size_px and "tile_size_px" in feature_h5.attrs: - tile_size_px_attr = feature_h5.attrs.get("tile_size_px") - if tile_size_px_attr is not None and isinstance( - tile_size_px_attr, (int, float) - ): - tile_size_px = TilePixels(int(tile_size_px_attr)) - else: - raise RuntimeError( - "Invalid or missing 'tile_size_px' attribute in the feature file." 
- ) + tile_size_px = TilePixels(int(feature_h5.attrs["tile_size_px"])) # pyright: ignore[reportArgumentType] if not tile_size_um or coords_um is None: raise RuntimeError( @@ -716,7 +733,7 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: def _to_fixed_size_bag( - bag: _Bag, coords: _Coordinates, bag_size: BagSize + bag: _Bag, coords: _Coordinates, bag_size: BagSize, deterministic: bool = False ) -> tuple[_Bag, _Coordinates, BagSize]: """Samples a fixed-size bag of tiles from an arbitrary one. @@ -725,23 +742,47 @@ def _to_fixed_size_bag( """ # get up to bag_size elements n_tiles, _dim_feats = bag.shape - bag_idxs = torch.randperm(n_tiles)[:bag_size] + if n_tiles <= bag_size: + # take all and pad later + bag_idxs = torch.arange(n_tiles, device=bag.device) + else: + if deterministic: + # equidistant indices across the bag + idxs = torch.linspace(0, n_tiles - 1, steps=bag_size, device=bag.device) + bag_idxs = idxs.round().long() + else: + bag_idxs = torch.randperm(n_tiles, device=bag.device)[:bag_size] + bag_samples = bag[bag_idxs] coord_samples = coords[bag_idxs] # zero-pad if we don't have enough samples - zero_padded_bag = torch.cat( - ( - bag_samples, - torch.zeros(bag_size - bag_samples.shape[0], bag_samples.shape[1]), + if bag_samples.shape[0] < bag_size: + zero_padded_bag = torch.cat( + ( + bag_samples, + torch.zeros( + bag_size - bag_samples.shape[0], + bag_samples.shape[1], + device=bag.device, + dtype=bag.dtype, + ), + ) ) - ) - zero_padded_coord = torch.cat( - ( - coord_samples, - torch.zeros(bag_size - coord_samples.shape[0], coord_samples.shape[1]), + zero_padded_coord = torch.cat( + ( + coord_samples, + torch.zeros( + bag_size - coord_samples.shape[0], + coord_samples.shape[1], + device=coords.device, + dtype=coords.dtype, + ), + ) ) - ) + else: + zero_padded_bag = bag_samples + zero_padded_coord = coord_samples return zero_padded_bag, zero_padded_coord, min(bag_size, len(bag)) @@ -822,13 +863,14 @@ def patient_to_survival_from_clini_table_( 
patient_label: PandasLabel, time_label: PandasLabel, status_label: PandasLabel, -) -> dict[PatientId, GroundTruth]: +) -> dict[PatientId, tuple[float | None, int | None]]: """ Loads patients and their survival ground truths (time + event) from a clini table. Returns dict[PatientId, GroundTruth] - Mapping patient_id -> "time status" (e.g. "302 dead", "476 alive"). + Mapping patient_id -> tuple (time, event) where `time` is a numeric follow-up + value and `event` is an integer indicator (1=event/death, 0=censored). """ clini_df = read_table( clini_table_path, @@ -864,7 +906,7 @@ def patient_to_survival_from_clini_table_( # Only drop rows where BOTH time and status are missing clini_df = clini_df.dropna(subset=[time_label, status_label], how="all") - patient_to_ground_truth: dict[PatientId, GroundTruth] = {} + patient_to_ground_truth: dict[PatientId, tuple[float | None, int | None]] = {} for _, row in clini_df.iterrows(): pid = row[patient_label] time_str = row[time_label] @@ -877,10 +919,9 @@ def patient_to_survival_from_clini_table_( # Encode status: keep both dead (event=1) and alive (event=0) status = _parse_survival_status(status_str) - # Encode back to "alive"/"dead" like before - # status = "dead" if status_val == 1 else "alive" - - patient_to_ground_truth[pid] = f"{time_str} {status}" + # Store structured ground-truth as (time, event) to avoid repeated string parsing + time_val = None if pd.isna(time_str) else float(time_str) + patient_to_ground_truth[pid] = (time_val, status) return patient_to_ground_truth @@ -940,7 +981,7 @@ def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: def filter_complete_patient_data_( *, patient_to_ground_truth: Mapping[ - PatientId, GroundTruth | dict[str, GroundTruth] | None + PatientId, GroundTruth | dict[str, GroundTruth] | tuple[float | None, int | None] | None ], slide_to_patient: Mapping[FeaturePath, PatientId], drop_patients_with_missing_ground_truth: bool, @@ -1021,8 +1062,10 @@ def 
_log_patient_slide_feature_inconsistencies( if slides_without_features := { slide for slide in slide_to_patient.keys() if not slide.exists() }: + slides_list = sorted(str(s) for s in slides_without_features) _logger.warning( - f"some feature files could not be found: {slides_without_features}" + "some feature files could not be found: %s", + ", ".join(slides_list), ) @@ -1131,6 +1174,15 @@ def load_patient_data_( raise ValueError( "Ground truth label is required for classification or regression modeling" ) + # Disallow multi-target ground truth for non-classification tasks + if ( + task != "classification" + and isinstance(ground_truth_label, Sequence) + and not isinstance(ground_truth_label, str) + ): + raise ValueError( + "Multi-target ground_truth_label is only supported for classification tasks" + ) patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( clini_table_path=clini_table, ground_truth_label=ground_truth_label, diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index d3b29ebd..c61a2512 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -20,7 +20,13 @@ slide_to_patient_from_slide_table_, ) from stamp.modeling.registry import ModelName, load_model_class -from stamp.types import Category, GroundTruth, PandasLabel, PatientId +from stamp.types import ( + Category, + GroundTruth, + PandasLabel, + PatientId, + SurvivalGroundTruth, +) __all__ = ["deploy_categorical_model_"] @@ -218,7 +224,7 @@ def deploy_categorical_model_( } patient_to_data = filter_complete_patient_data_( patient_to_ground_truth=cast( - Mapping[PatientId, GroundTruth | None], + Mapping[PatientId, GroundTruth | dict[str, GroundTruth] | None], patient_to_ground_truth, ), slide_to_patient=slide_to_patient, @@ -622,7 +628,9 @@ def _to_regression_prediction_df( def _to_survival_prediction_df( *, - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], + patient_to_ground_truth: Mapping[ + PatientId, GroundTruth | 
SurvivalGroundTruth | None + ], predictions: Mapping[PatientId, torch.Tensor], patient_label: PandasLabel, cut_off: float | None = None, diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index 0b6a3885..b5a59b5f 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -6,7 +6,9 @@ from typing import Any, TypeAlias import lightning +import numpy as np import torch +from lifelines.utils import concordance_index as lifelines_cindex # Use beartype.typing.Mapping to avoid PEP-585 deprecation warnings in beartype from beartype.typing import Mapping @@ -653,27 +655,36 @@ def cox_loss( def c_index( scores: torch.Tensor, times: torch.Tensor, events: torch.Tensor ) -> torch.Tensor: - # """ - # Concordance index: proportion of correctly ordered comparable pairs. - # """ - N = len(times) - if N <= 1: + """ + Concordance index using lifelines implementation for consistency with statistics. + + Uses the same convention as stamp.statistics.survival: + - Higher risk scores should correspond to shorter survival (worse outcome). + - We negate scores so lifelines interprets higher values as longer survival. 
+ """ + # Convert to numpy for lifelines + scores_np = scores.detach().cpu().numpy().flatten() + times_np = times.detach().cpu().numpy().flatten() + events_np = events.detach().cpu().numpy().flatten() + + # Filter out NaN values + valid_mask = ~(np.isnan(times_np) | np.isnan(events_np) | np.isnan(scores_np)) + if valid_mask.sum() <= 1: return torch.tensor(float("nan"), device=scores.device) - t_i = times.view(-1, 1).expand(N, N) - t_j = times.view(1, -1).expand(N, N) - e_i = events.view(-1, 1).expand(N, N) + times_np = times_np[valid_mask] + events_np = events_np[valid_mask] + scores_np = scores_np[valid_mask] - mask = (t_i < t_j) & e_i.bool() - if mask.sum() == 0: + # Use lifelines concordance_index with negated risk (same as statistics module) + # lifelines expects: higher predicted value = longer survival + # Cox outputs: higher risk = shorter survival, so we negate + try: + ci = lifelines_cindex(times_np, -scores_np, events_np) + except Exception: return torch.tensor(float("nan"), device=scores.device) - s_i = scores.view(-1, 1).expand(N, N)[mask] - s_j = scores.view(1, -1).expand(N, N)[mask] - - conc = (s_i > s_j).float() - ties = (s_i == s_j).float() * 0.5 - return (conc + ties).sum() / mask.sum() + return torch.tensor(ci, device=scores.device, dtype=scores.dtype) def on_validation_epoch_end(self): if ( diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index fb2bdb3b..7031eed6 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -17,6 +17,7 @@ BagDataset, PatientData, PatientFeatureDataset, + _parse_survival_status, create_dataloader, load_patient_data_, ) @@ -154,11 +155,21 @@ def setup_model_for_training( f"No model specified, defaulting to '{advanced.model_name.value}' for feature type '{feature_type}'" ) + # Prevent selecting `barspoon` for single-target classification + if ( + task == "classification" + and isinstance(ground_truth_label, str) + and advanced.model_name == ModelName.BARSPOON + ): + raise 
ValueError( + "Model 'barspoon' requires multi-target classification. " + "For single-target classification set model_name to 'vit', 'trans_mil', or 'mlp'." + ) + # 2. Instantiate the lightning wrapper (based on provided task, feature type) and model backbone dynamically LitModelClass, ModelClass = load_model_class( task, feature_type, advanced.model_name ) - print(f"Using Lightning wrapper class: {LitModelClass}") # 3. Validate that the chosen model supports the feature type if feature_type not in LitModelClass.supported_features: @@ -220,6 +231,124 @@ def setup_model_for_training( return model, train_dl, valid_dl +def setup_model_from_dataloaders( + *, + train_dl: DataLoader, + valid_dl: DataLoader, + task: Task, + train_categories: Sequence[Category] | Mapping[str, Sequence[Category]], + dim_feats: int, + train_patients: Sequence[PatientId], + valid_patients: Sequence[PatientId], + feature_type: str, + advanced: AdvancedConfig, + # Metadata, has no effect on model training + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + time_label: PandasLabel | None, + status_label: PandasLabel | None, + clini_table: Path, + slide_table: Path | None, + feature_dir: Path, +) -> lightning.LightningModule: + """Creates a model from pre-built dataloaders (no internal split).""" + + _logger.info( + "Training dataloaders: task=%s, feature_type=%s", + task, + feature_type, + ) + + category_weights: torch.Tensor | dict[str, torch.Tensor] | list = [] + if task == "classification": + category_weights = _compute_class_weights_and_check_categories( + train_dl=train_dl, + feature_type=feature_type, + train_categories=train_categories, + ) + + # 1. 
Default to a model if none is specified + if advanced.model_name is None: + advanced.model_name = ModelName.VIT if feature_type == "tile" else ModelName.MLP + _logger.info( + f"No model specified, defaulting to '{advanced.model_name.value}' for feature type '{feature_type}'" + ) + + # Prevent selecting `barspoon` for single-target classification + if ( + task == "classification" + and isinstance(ground_truth_label, str) + and advanced.model_name == ModelName.BARSPOON + ): + raise ValueError( + "Model 'barspoon' requires multi-target classification. " + "For single-target classification set model_name to 'vit', 'trans_mil', or 'mlp'." + ) + + # 2. Instantiate the lightning wrapper (based on provided task, feature type) and model backbone dynamically + LitModelClass, ModelClass = load_model_class( + task, feature_type, advanced.model_name + ) + + # 3. Validate that the chosen model supports the feature type + if feature_type not in LitModelClass.supported_features: + raise ValueError( + f"Model '{advanced.model_name.value}' does not support feature type '{feature_type}'. " + f"Supported types are: {LitModelClass.supported_features}" + ) + elif ( + feature_type in ("slide", "patient") + and advanced.model_name.value.lower() != "mlp" + ): + raise ValueError( + f"Feature type '{feature_type}' only supports MLP backbones. " + f"Got '{advanced.model_name.value}'. Please set model_name='mlp'." + ) + + # 4. Get model-specific hyperparameters + model_specific_params = ( + advanced.model_params.model_dump().get(advanced.model_name.value) or {} + ) + + # 5. Calculate total steps for scheduler + steps_per_epoch = len(train_dl) + total_steps = steps_per_epoch * advanced.max_epochs + + # 6. 
Prepare common parameters + common_params = { + "categories": train_categories, + "category_weights": category_weights, + "dim_input": dim_feats, + "total_steps": total_steps, + "max_lr": advanced.max_lr, + "div_factor": advanced.div_factor, + # Metadata, has no effect on model training + "model_name": advanced.model_name.value, + "ground_truth_label": ground_truth_label, + "time_label": time_label, + "status_label": status_label, + "train_patients": train_patients, + "valid_patients": valid_patients, + "clini_table": clini_table, + "slide_table": slide_table, + "feature_dir": feature_dir, + } + + all_params = {**common_params, **model_specific_params} + + _logger.info( + f"Instantiating model '{advanced.model_name.value}' with parameters: {model_specific_params}" + ) + _logger.info( + "Other params: max_epochs=%s, patience=%s", + advanced.max_epochs, + advanced.patience, + ) + + model = LitModelClass(model_class=ModelClass, **all_params) + + return model + + def setup_dataloaders_for_training( *, patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], @@ -259,6 +388,12 @@ def setup_dataloaders_for_training( "patient_to_data must have a ground truth defined for all targets!" ) + # Multi-target ground truths are only supported for classification + if task != "classification" and any(isinstance(gt, dict) for gt in ground_truths): + raise ValueError( + "Multi-target ground truths are only supported for classification tasks" + ) + if task == "classification": # Handle both single and multi-target cases if ground_truths and isinstance(ground_truths[0], dict): @@ -268,17 +403,37 @@ def setup_dataloaders_for_training( else: stratify = ground_truths elif task == "survival": - # Extract event indicator (status) - handle both single and multi-target - statuses = [] + # Extract event indicator (status). Accept either structured (time,event) + # or legacy string "time status" formats. 
+ statuses: list[int] = [] for gt in ground_truths: if isinstance(gt, dict): - # Multi-target survival: extract from first target - first_key = list(gt.keys())[0] - val = cast(dict, gt)[first_key] - if val: - statuses.append(int(val.split()[1])) + raise ValueError( + "Multi-target survival is not supported; provide a single survival time/status per patient" + ) + if isinstance(gt, (tuple, list)) and len(gt) == 2: + status_val = gt[1] + if status_val is None: + raise ValueError( + "Missing survival status for a patient; cannot stratify" + ) + statuses.append(int(status_val)) else: - statuses.append(int(gt.split()[1])) + parts = str(gt).split() + if len(parts) >= 2: + status_token = parts[1] + elif len(parts) == 1: + status_token = parts[0] + else: + raise ValueError( + "Unrecognized survival ground-truth format for stratification" + ) + parsed_status = _parse_survival_status(status_token) + if parsed_status is None: + raise ValueError( + f"Unrecognized survival status token for stratification: {status_token!r}" + ) + statuses.append(int(parsed_status)) stratify = statuses elif task == "regression": stratify = None diff --git a/src/stamp/types.py b/src/stamp/types.py index c1ff6873..fddb9d4c 100644 --- a/src/stamp/types.py +++ b/src/stamp/types.py @@ -37,6 +37,8 @@ PatientId: TypeAlias = str GroundTruth: TypeAlias = str +# Survival ground-truth is represented as (time, event) +SurvivalGroundTruth: TypeAlias = tuple[float | None, int | None] MultiClassGroundTruth: TypeAlias = tuple[str, ...] 
FeaturePath = NewType("FeaturePath", Path) From 2daf92193c0fd70cc0f0e6285f09bb11497ff5e1 Mon Sep 17 00:00:00 2001 From: mducducd Date: Tue, 17 Feb 2026 16:17:02 +0000 Subject: [PATCH 06/11] refactor survival training/validation --- src/stamp/modeling/data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 6ddc00fd..bb295bcf 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -981,7 +981,8 @@ def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: def filter_complete_patient_data_( *, patient_to_ground_truth: Mapping[ - PatientId, GroundTruth | dict[str, GroundTruth] | tuple[float | None, int | None] | None + PatientId, + GroundTruth | dict[str, GroundTruth] | tuple[float | None, int | None] | None, ], slide_to_patient: Mapping[FeaturePath, PatientId], drop_patients_with_missing_ground_truth: bool, From 8d96c425b8b73493d698a174d83ca35348af0061 Mon Sep 17 00:00:00 2001 From: Minh Duc Nguyen <37109868+mducducd@users.noreply.github.com> Date: Tue, 3 Mar 2026 11:02:18 +0100 Subject: [PATCH 07/11] Remove unused import from survival.py Removed unused import for add_at_risk_counts. 
--- src/stamp/statistics/survival.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index 6ff75b61..87415a7a 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -8,7 +8,6 @@ import numpy as np import pandas as pd from lifelines import KaplanMeierFitter -from lifelines.plotting import add_at_risk_counts from lifelines.statistics import logrank_test from lifelines.utils import concordance_index From cc490df804534bdc1ffda846e48a558cc1fe74cc Mon Sep 17 00:00:00 2001 From: mducducd Date: Wed, 11 Mar 2026 12:26:05 +0000 Subject: [PATCH 08/11] update data.py to latest v2.4.1 --- src/stamp/modeling/data.py | 175 +++++++++++++++++++++++++++---------- 1 file changed, 127 insertions(+), 48 deletions(-) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 868dc63b..c3696c0f 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -1,6 +1,7 @@ """Helper classes to manage pytorch data.""" import logging +from collections import OrderedDict from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import KW_ONLY, dataclass from io import BytesIO # accept in _BinaryIOLike at runtime @@ -130,6 +131,8 @@ def tile_bag_dataloader( num_workers=num_workers, collate_fn=collate_fn, worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, + persistent_workers=(num_workers > 0), + pin_memory=torch.cuda.is_available(), ) return ( @@ -229,9 +232,7 @@ def _parse_targets( times.append(np.nan) events.append(np.nan) continue - - # Accept either structured tuple/list (time, event) or the legacy - # string form "time status". + # Expect a structured tuple/list (time, event). 
if isinstance(gt, (tuple, list)) and len(gt) == 2: t_val, e_val = gt times.append( @@ -241,9 +242,9 @@ def _parse_targets( ) events.append(float(e_val) if e_val is not None else np.nan) else: - time_str, status_str = str(gt).split(" ", 1) - times.append(np.nan if time_str.lower() == "nan" else float(time_str)) - events.append(_parse_survival_status(status_str)) + raise ValueError( + "survival ground truth must be a (time, event) tuple/list" + ) y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) return y, [] @@ -358,10 +359,8 @@ def create_dataloader( if task == "classification": raw = np.array([p.ground_truth for p in patient_data]) - categories_out = categories or list(np.unique(raw)) - labels = torch.tensor( - raw.reshape(-1, 1) == categories_out, dtype=torch.float32 - ) + categories = categories or list(np.unique(raw)) + labels = torch.tensor(raw.reshape(-1, 1) == categories, dtype=torch.float32) elif task == "regression": values: list[float] = [] for gt in (p.ground_truth for p in patient_data): @@ -381,8 +380,28 @@ def create_dataloader( raise ValueError( "Multi-target survival is not supported; provide a single survival time/status per patient" ) - t, e = (p.ground_truth or "nan nan").split(" ", 1) - times.append(float(t) if t.lower() != "nan" else np.nan) + gt = p.ground_truth + # Prefer structured (time, event) tuples. Do NOT call .split + # on the ground-truth value. If not a tuple/list, treat time + # as the whole value and event as unknown. 
+ if isinstance(gt, (tuple, list)) and len(gt) == 2: + t, e = gt + elif gt is None: + t, e = None, None + else: + t, e = str(gt), "nan" + + # Parse time defensively + if t is None: + times.append(np.nan) + elif isinstance(t, str): + try: + times.append(np.nan if t.lower() == "nan" else float(t)) + except Exception: + times.append(np.nan) + else: + times.append(float(t)) + events.append(_parse_survival_status(e)) labels = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) @@ -396,6 +415,8 @@ def create_dataloader( shuffle=shuffle, num_workers=num_workers, worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, + persistent_workers=(num_workers > 0), + pin_memory=torch.cuda.is_available(), ) return dl, categories or [] else: @@ -545,6 +566,19 @@ def __post_init__(self) -> None: raise ValueError( "the number of ground truths has to match the number of bags" ) + # Initialise per-worker HDF5 handle cache here so __getitem__ avoids + # a hasattr() call on every tile read. + self._h5_handle_cache: OrderedDict[FeaturePath | _BinaryIOLike, h5py.File] = ( + OrderedDict() + ) + + def __getstate__(self) -> dict: + # h5py file handles cannot be pickled (required when DataLoader uses + # spawn-based multiprocessing). Drop the cache; each worker reopens + # files lazily on the first __getitem__ access. 
+ state = self.__dict__.copy() + state["_h5_handle_cache"] = OrderedDict() + return state def __len__(self) -> int: return len(self.bags) @@ -556,24 +590,46 @@ def __getitem__( feats = [] coords_um = [] for bag_file in self.bags[index]: - with h5py.File(bag_file, "r") as h5: - if "feats" in h5: - feats_obj = h5["feats"] - if not isinstance(feats_obj, h5py.Dataset): - raise RuntimeError( - f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" - ) - arr = feats_obj[:] # original STAMP files - else: - embeddings_obj = h5["patch_embeddings"] - if not isinstance(embeddings_obj, h5py.Dataset): - raise RuntimeError( - f"expected 'patch_embeddings' to be an HDF5 dataset but got {type(embeddings_obj)}" - ) - arr = embeddings_obj[:] # your Kronos files + if bag_file not in self._h5_handle_cache: + # Limit open handles to avoid reaching OS ulimits + if len(self._h5_handle_cache) >= 128: + _, h = self._h5_handle_cache.popitem(last=False) + h.close() + + try: + # libver='latest' and swmr=True can provide better performance + # on some network/HPC filesystems + self._h5_handle_cache[bag_file] = h5py.File( + bag_file, "r", swmr=True, libver="latest" + ) + except Exception: + # Fallback for older HDF5 files or unconventional storage + self._h5_handle_cache[bag_file] = h5py.File(bag_file, "r") + else: + # Move recently accessed file to end (mark as recently used) + self._h5_handle_cache.move_to_end(bag_file) - feats.append(torch.from_numpy(arr)) - coords_um.append(torch.from_numpy(get_coords(h5).coords_um)) + h5 = self._h5_handle_cache[bag_file] + + if "feats" in h5: + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + arr = feats_obj[ + () + ] # uses [()] instead of [:] for clarity, both read entire dataset + else: + embeddings_obj = h5["patch_embeddings"] + if not isinstance(embeddings_obj, h5py.Dataset): + raise RuntimeError( + f"expected 
'patch_embeddings' to be an HDF5 dataset but got {type(embeddings_obj)}" + ) + arr = embeddings_obj[()] # your Kronos files + + feats.append(torch.from_numpy(arr)) + coords_um.append(torch.from_numpy(get_coords(h5).coords_um)) feats = torch.concat(feats).float() coords_um = torch.concat(coords_um).float() @@ -618,31 +674,53 @@ def __init__( self.feature_files = feature_files self.ground_truths = ground_truths self.transform = transform + # Initialise per-worker HDF5 handle cache eagerly so __getitem__ avoids + # a hasattr() call on every sample read. + self._h5_handle_cache: dict[FeaturePath | _BinaryIOLike, h5py.File] = {} + + def __getstate__(self) -> dict: + # h5py file handles cannot be pickled (required when DataLoader uses + # spawn-based multiprocessing). Drop the cache; each worker reopens + # files lazily on the first __getitem__ access. + state = self.__dict__.copy() + state["_h5_handle_cache"] = {} + return state def __len__(self): return len(self.feature_files) def __getitem__(self, idx: int): feature_file = self.feature_files[idx] - with h5py.File(feature_file, "r") as h5: - feats_obj = h5["feats"] - if not isinstance(feats_obj, h5py.Dataset): - raise RuntimeError( - f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + if feature_file not in self._h5_handle_cache: + if len(self._h5_handle_cache) >= 128: + _, h = self._h5_handle_cache.popitem() + h.close() + try: + self._h5_handle_cache[feature_file] = h5py.File( + feature_file, "r", swmr=True, libver="latest" ) - feats = torch.from_numpy(feats_obj[:]) - # Accept [V] or [1, V] - if feats.ndim == 2 and feats.shape[0] == 1: - feats = feats[0] - elif feats.ndim == 1: - pass - else: - raise RuntimeError( - f"Expected single feature vector (shape [F] or [1, F]), got {feats.shape} in {feature_file}." - "Check that the features are patient-level." 
- ) - if self.transform is not None: - feats = self.transform(feats) + except Exception: + self._h5_handle_cache[feature_file] = h5py.File(feature_file, "r") + + h5 = self._h5_handle_cache[feature_file] + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + feats = torch.from_numpy(feats_obj[()]) + # Accept [V] or [1, V] + if feats.ndim == 2 and feats.shape[0] == 1: + feats = feats[0] + elif feats.ndim == 1: + pass + else: + raise RuntimeError( + f"Expected single feature vector (shape [F] or [1, F]), got {feats.shape} in {feature_file}." + "Check that the features are patient-level." + ) + if self.transform is not None: + feats = self.transform(feats) label = self.ground_truths[idx] return feats, label @@ -1063,7 +1141,8 @@ def _log_patient_slide_feature_inconsistencies( if slides_without_features := { slide for slide in slide_to_patient.keys() if not slide.exists() }: - slides_list = sorted(str(s) for s in slides_without_features) + # Log only the filenames (not full paths) to keep warnings concise. 
+ slides_list = sorted(s.name for s in slides_without_features) _logger.warning( "some feature files could not be found: %s", ", ".join(slides_list), From fb6d114b0a69fafa5697eef9aea47fe8ecbaa9ed Mon Sep 17 00:00:00 2001 From: mducducd Date: Wed, 11 Mar 2026 13:12:58 +0000 Subject: [PATCH 09/11] update tests to latest v2.4.1 --- tests/conftest.py | 10 +++ tests/test_deployment.py | 12 +-- tests/test_mics.py | 161 +++++++++++++++++++++++++++++++++++++ tests/test_train_deploy.py | 84 +++++++++---------- 4 files changed, 220 insertions(+), 47 deletions(-) create mode 100644 tests/test_mics.py diff --git a/tests/conftest.py b/tests/conftest.py index f7f4f005..16dcae5b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,15 @@ +import multiprocessing from stamp.preprocessing import ExtractorName +# Ensure tests use a safe multiprocessing start method to avoid +# fork-from-multi-threaded-process warnings on Linux. +if multiprocessing.get_start_method(allow_none=True) != "spawn": + try: + multiprocessing.set_start_method("spawn") + except RuntimeError: + # start method already set by the test runner/environment + pass + # This lets you choose which extractors to run on pytest. Useful for the # CI pipeline as extractors like MUSK take an eternity on CPU. 
diff --git a/tests/test_deployment.py b/tests/test_deployment.py index 7d1d6589..55939e66 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -221,8 +221,8 @@ def test_to_prediction_df(task: str) -> None: assert preds_df["loss"].isna().all() else: patient_to_ground_truth = { - PatientId("p1"): "10.0 1", - PatientId("p2"): "12.3 0", + PatientId("p1"): (10.0, 1), + PatientId("p2"): (12.3, 0), } predictions = { PatientId("p1"): torch.tensor([0.8]), @@ -304,17 +304,19 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: feature_file = make_old_feature_file( feats=torch.rand(23, dim_feats), coords=torch.rand(23, 2) ) - gt = GroundTruth("foo") + gt = cast(GroundTruth, "foo") elif task == "regression": feature_file = make_old_feature_file( feats=torch.rand(30, dim_feats), coords=torch.rand(30, 2) ) - gt = GroundTruth(42.5) # numeric target wrapped for typing + gt = cast(GroundTruth, 42.5) # numeric target wrapped for typing else: # survival feature_file = make_old_feature_file( feats=torch.rand(40, dim_feats), coords=torch.rand(40, 2) ) - gt = GroundTruth("12 0") # (time, status) + gt = cast( + GroundTruth, (12.0, 0) + ) # (time, status) - use raw tuple (GroundTruth is a str alias) patient_to_data = { PatientId("pat_test"): PatientData( diff --git a/tests/test_mics.py b/tests/test_mics.py new file mode 100644 index 00000000..faacbe0a --- /dev/null +++ b/tests/test_mics.py @@ -0,0 +1,161 @@ +from pathlib import Path + +import h5py +import numpy as np +import torch + +from typing import cast, Sequence + +from stamp.modeling.config import TrainConfig +from stamp.modeling.data import ( + PatientData, + _log_patient_slide_feature_inconsistencies, + create_dataloader, +) +from stamp.modeling.models.__init__ import LitSlideClassifier +from stamp.modeling.models.mlp import MLP +from stamp.modeling.registry import ModelName, load_model_class +from stamp.types import FeaturePath +from stamp.utils.config import StampConfig + + +def _make_h5(path: 
Path, vec: np.ndarray) -> None: + with h5py.File(path, "w") as h5: + h5.create_dataset("feats", data=vec) + + +def test_create_dataloader_infers_categories_patient(tmp_path: Path) -> None: + # create two patient-level feature files + p1 = tmp_path / "p1.h5" + p2 = tmp_path / "p2.h5" + _make_h5(p1, np.zeros((1, 4), dtype=np.float32)) + _make_h5(p2, np.zeros((1, 4), dtype=np.float32)) + + pd1 = PatientData(ground_truth="A", feature_files=[FeaturePath(p1)]) + pd2 = PatientData(ground_truth="B", feature_files=[FeaturePath(p2)]) + + dl, cats = create_dataloader( + feature_type="patient", + task="classification", + patient_data=[pd1, pd2], + bag_size=None, + batch_size=1, + shuffle=False, + num_workers=0, + transform=None, + categories=None, + ) + + assert set(cats) == {"A", "B"} + + +def test_create_dataloader_survival_tuple_and_legacy_string(tmp_path: Path) -> None: + p1 = tmp_path / "s1.h5" + p2 = tmp_path / "s2.h5" + _make_h5(p1, np.zeros((1, 3), dtype=np.float32)) + _make_h5(p2, np.zeros((1, 3), dtype=np.float32)) + + # tuple (time, event) and legacy string (will be parsed defensively) + pd1 = PatientData(ground_truth=(5.0, 1), feature_files=[FeaturePath(p1)]) + pd2 = PatientData(ground_truth="7 0", feature_files=[FeaturePath(p2)]) + + dl, cats = create_dataloader( + feature_type="patient", + task="survival", + patient_data=cast(Sequence[PatientData], [pd1, pd2]), + bag_size=None, + batch_size=2, + shuffle=False, + num_workers=0, + transform=None, + categories=None, + ) + + batch_feats, labels = next(iter(dl)) + # labels: shape (2,2) -> (time, event) + assert labels.shape[0] == 2 + # first sample is (5.0, 1) + assert torch.isclose(labels[0, 0], torch.tensor(5.0)) + assert labels[0, 1] == 1.0 + # legacy string parsing yields an event token parsed to 0 (censored) + assert labels[1, 1] == 0.0 + + +def test_predict_dtype_casting_no_error() -> None: + # Build a minimal LitSlideClassifier with MLP backbone + categories = ["A", "B"] + category_weights = torch.tensor([1.0, 
1.0], dtype=torch.float32) + + model = LitSlideClassifier( + model_class=MLP, + ground_truth_label="y", + categories=categories, + category_weights=category_weights, + dim_input=4, + # provide backbone hyperparams + dim_hidden=8, + num_layers=2, + dropout=0.1, + # Base required args + total_steps=10, + max_lr=1e-3, + div_factor=25.0, + train_patients=[], + valid_patients=[], + ) + + # force model params to half precision to simulate mixed-precision checkpoints + model.model.half() + + feats = torch.rand((1, 4), dtype=torch.float32) + labels = torch.tensor([[0.0, 1.0]], dtype=torch.float32) + + # predict_step should cast feats to model dtype (half) and not raise + out = model.predict_step((feats, labels), 0) + assert isinstance(out, torch.Tensor) + assert out.dtype == next(model.model.parameters()).dtype + + +def test_log_missing_feature_filenames(caplog, tmp_path: Path) -> None: + missing = tmp_path / "missing.h5" + # slide_to_patient maps FeaturePath -> patient id + slide_to_patient = {FeaturePath(missing): "P1"} + patient_to_ground_truth = {"P1": "A"} + + with caplog.at_level("WARNING"): + _log_patient_slide_feature_inconsistencies( + patient_to_ground_truth=patient_to_ground_truth, + slide_to_patient=slide_to_patient, + ) + + assert "some feature files could not be found" in caplog.text + # ensure only filename is logged (not full path) + assert "missing.h5" in caplog.text + + +def test_model_registry_returns_classes() -> None: + LitClass, ModelClass = load_model_class("classification", "patient", ModelName.MLP) + assert callable(LitClass) + assert callable(ModelClass) + + +def test_stampconfig_training_task_default() -> None: + cfg = StampConfig.model_validate( + { + "training": { + "output_dir": "out", + "clini_table": "cl.csv", + "feature_dir": "feats", + "ground_truth_label": "gt", + } + } + ) + assert cfg.training is not None + assert ( + cfg.training.task + == TrainConfig( + output_dir=Path("out"), + clini_table=Path("cl.csv"), + feature_dir=Path("feats"), 
+ ).task + ) diff --git a/tests/test_train_deploy.py b/tests/test_train_deploy.py index ea5547f1..a21af576 100644 --- a/tests/test_train_deploy.py +++ b/tests/test_train_deploy.py @@ -54,20 +54,20 @@ def test_train_deploy_integration( create_random_dataset( dir=tmp_path / "train", n_categories=3, - n_patients=400, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = create_random_dataset( dir=tmp_path / "deploy", categories=categories, - n_patients=50, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) @@ -85,7 +85,7 @@ def test_train_deploy_integration( advanced = AdvancedConfig( # Dataset and -loader parameters - bag_size=500, + bag_size=32, num_workers=min(os.cpu_count() or 1, 16), # Training paramenters batch_size=8, @@ -142,7 +142,7 @@ def test_train_deploy_patient_level_integration( create_random_patient_level_dataset( dir=tmp_path / "train", n_categories=3, - n_patients=400, + n_patients=30, feat_dim=feat_dim, ) ) @@ -150,7 +150,7 @@ def test_train_deploy_patient_level_integration( create_random_patient_level_dataset( dir=tmp_path / "deploy", categories=categories, - n_patients=50, + n_patients=10, feat_dim=feat_dim, ) ) @@ -218,20 +218,20 @@ def test_train_deploy_regression_integration( train_clini_path, train_slide_path, train_feature_dir, _ = ( create_random_regression_dataset( dir=tmp_path / "train", - n_patients=400, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( 
create_random_regression_dataset( dir=tmp_path / "deploy", - n_patients=50, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) @@ -250,9 +250,9 @@ def test_train_deploy_regression_integration( ) advanced = AdvancedConfig( - bag_size=500, + bag_size=32, num_workers=min(os.cpu_count() or 1, 16), - batch_size=1, + batch_size=8, max_epochs=2, patience=1, accelerator="gpu" if torch.cuda.is_available() else "cpu", @@ -297,20 +297,20 @@ def test_train_deploy_survival_integration( train_clini_path, train_slide_path, train_feature_dir, _ = ( create_random_survival_dataset( dir=tmp_path / "train", - n_patients=400, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( create_random_survival_dataset( dir=tmp_path / "deploy", - n_patients=50, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) @@ -329,7 +329,7 @@ def test_train_deploy_survival_integration( ) advanced = AdvancedConfig( - bag_size=500, + bag_size=32, num_workers=min(os.cpu_count() or 1, 16), batch_size=8, max_epochs=2, @@ -385,7 +385,7 @@ def test_train_deploy_patient_level_regression_integration( train_feat_dir.mkdir(parents=True, exist_ok=True) deploy_feat_dir.mkdir(parents=True, exist_ok=True) - n_train, n_deploy = 300, 60 + n_train, n_deploy = 30, 10 train_rows, deploy_rows = [], [] # --- Generate random patient-level features and numeric targets --- @@ -490,14 +490,14 @@ def test_train_deploy_patient_level_survival_integration( train_clini_path, train_slide_path, train_feature_dir, _ = ( 
create_random_patient_level_survival_dataset( dir=tmp_path / "train", - n_patients=300, + n_patients=30, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( create_random_patient_level_survival_dataset( dir=tmp_path / "deploy", - n_patients=60, + n_patients=10, feat_dim=feat_dim, ) ) @@ -565,10 +565,10 @@ def test_train_deploy_multi_target_integration( train_clini_path, train_slide_path, train_feature_dir, _ = ( create_random_multi_target_dataset( dir=tmp_path / "train", - n_patients=400, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, target_labels=target_labels, categories_per_target=categories_per_target, @@ -577,10 +577,10 @@ def test_train_deploy_multi_target_integration( deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( create_random_multi_target_dataset( dir=tmp_path / "deploy", - n_patients=50, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, target_labels=target_labels, categories_per_target=categories_per_target, @@ -601,7 +601,7 @@ def test_train_deploy_multi_target_integration( ) advanced = AdvancedConfig( - bag_size=500, + bag_size=32, num_workers=min(os.cpu_count() or 1, 16), batch_size=8, max_epochs=2, From b4885c5a3502ce8194eb8913ba27045ab8f0cc58 Mon Sep 17 00:00:00 2001 From: mducducd Date: Wed, 11 Mar 2026 13:17:15 +0000 Subject: [PATCH 10/11] remove misc test --- tests/test_mics.py | 161 --------------------------------------------- 1 file changed, 161 deletions(-) delete mode 100644 tests/test_mics.py diff --git a/tests/test_mics.py b/tests/test_mics.py deleted file mode 100644 index faacbe0a..00000000 --- a/tests/test_mics.py +++ /dev/null @@ -1,161 +0,0 @@ -from pathlib import Path - -import h5py -import 
numpy as np -import torch - -from typing import cast, Sequence - -from stamp.modeling.config import TrainConfig -from stamp.modeling.data import ( - PatientData, - _log_patient_slide_feature_inconsistencies, - create_dataloader, -) -from stamp.modeling.models.__init__ import LitSlideClassifier -from stamp.modeling.models.mlp import MLP -from stamp.modeling.registry import ModelName, load_model_class -from stamp.types import FeaturePath -from stamp.utils.config import StampConfig - - -def _make_h5(path: Path, vec: np.ndarray) -> None: - with h5py.File(path, "w") as h5: - h5.create_dataset("feats", data=vec) - - -def test_create_dataloader_infers_categories_patient(tmp_path: Path) -> None: - # create two patient-level feature files - p1 = tmp_path / "p1.h5" - p2 = tmp_path / "p2.h5" - _make_h5(p1, np.zeros((1, 4), dtype=np.float32)) - _make_h5(p2, np.zeros((1, 4), dtype=np.float32)) - - pd1 = PatientData(ground_truth="A", feature_files=[FeaturePath(p1)]) - pd2 = PatientData(ground_truth="B", feature_files=[FeaturePath(p2)]) - - dl, cats = create_dataloader( - feature_type="patient", - task="classification", - patient_data=[pd1, pd2], - bag_size=None, - batch_size=1, - shuffle=False, - num_workers=0, - transform=None, - categories=None, - ) - - assert set(cats) == {"A", "B"} - - -def test_create_dataloader_survival_tuple_and_legacy_string(tmp_path: Path) -> None: - p1 = tmp_path / "s1.h5" - p2 = tmp_path / "s2.h5" - _make_h5(p1, np.zeros((1, 3), dtype=np.float32)) - _make_h5(p2, np.zeros((1, 3), dtype=np.float32)) - - # tuple (time, event) and legacy string (will be parsed defensively) - pd1 = PatientData(ground_truth=(5.0, 1), feature_files=[FeaturePath(p1)]) - pd2 = PatientData(ground_truth="7 0", feature_files=[FeaturePath(p2)]) - - dl, cats = create_dataloader( - feature_type="patient", - task="survival", - patient_data=cast(Sequence[PatientData], [pd1, pd2]), - bag_size=None, - batch_size=2, - shuffle=False, - num_workers=0, - transform=None, - categories=None, - 
) - - batch_feats, labels = next(iter(dl)) - # labels: shape (2,2) -> (time, event) - assert labels.shape[0] == 2 - # first sample is (5.0, 1) - assert torch.isclose(labels[0, 0], torch.tensor(5.0)) - assert labels[0, 1] == 1.0 - # legacy string parsing yields an event token parsed to 0 (censored) - assert labels[1, 1] == 0.0 - - -def test_predict_dtype_casting_no_error() -> None: - # Build a minimal LitSlideClassifier with MLP backbone - categories = ["A", "B"] - category_weights = torch.tensor([1.0, 1.0], dtype=torch.float32) - - model = LitSlideClassifier( - model_class=MLP, - ground_truth_label="y", - categories=categories, - category_weights=category_weights, - dim_input=4, - # provide backbone hyperparams - dim_hidden=8, - num_layers=2, - dropout=0.1, - # Base required args - total_steps=10, - max_lr=1e-3, - div_factor=25.0, - train_patients=[], - valid_patients=[], - ) - - # force model params to half precision to simulate mixed-precision checkpoints - model.model.half() - - feats = torch.rand((1, 4), dtype=torch.float32) - labels = torch.tensor([[0.0, 1.0]], dtype=torch.float32) - - # predict_step should cast feats to model dtype (half) and not raise - out = model.predict_step((feats, labels), 0) - assert isinstance(out, torch.Tensor) - assert out.dtype == next(model.model.parameters()).dtype - - -def test_log_missing_feature_filenames(caplog, tmp_path: Path) -> None: - missing = tmp_path / "missing.h5" - # slide_to_patient maps FeaturePath -> patient id - slide_to_patient = {FeaturePath(missing): "P1"} - patient_to_ground_truth = {"P1": "A"} - - with caplog.at_level("WARNING"): - _log_patient_slide_feature_inconsistencies( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - ) - - assert "some feature files could not be found" in caplog.text - # ensure only filename is logged (not full path) - assert "missing.h5" in caplog.text - - -def test_model_registry_returns_classes() -> None: - LitClass, ModelClass = 
load_model_class("classification", "patient", ModelName.MLP) - assert callable(LitClass) - assert callable(ModelClass) - - -def test_stampconfig_training_task_default() -> None: - cfg = StampConfig.model_validate( - { - "training": { - "output_dir": "out", - "clini_table": "cl.csv", - "feature_dir": "feats", - "ground_truth_label": "gt", - } - } - ) - assert cfg.training is not None - assert ( - cfg.training.task - == TrainConfig( - output_dir=Path("out"), - clini_table=Path("cl.csv"), - feature_dir=Path("feats"), - ).task - ) From 0f61e978ee16345578fbc19207713b325a3a6a40 Mon Sep 17 00:00:00 2001 From: mducducd Date: Thu, 12 Mar 2026 10:49:51 +0000 Subject: [PATCH 11/11] fix(ticon): auto-select runtime device --- src/stamp/preprocessing/extractor/ticon.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/stamp/preprocessing/extractor/ticon.py b/src/stamp/preprocessing/extractor/ticon.py index ab7eb829..5780afc4 100644 --- a/src/stamp/preprocessing/extractor/ticon.py +++ b/src/stamp/preprocessing/extractor/ticon.py @@ -624,7 +624,6 @@ def load_ticon(device: str = "cuda") -> nn.Module: class HOptimusTICON(nn.Module): def __init__(self, device: torch.device): super().__init__() - self.device = device # ---------------------------- # Stage 1: H-OptimUS @@ -689,7 +688,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: [B, 3, 224, 224] (CPU or CUDA) """ - x = x.to(self.device, non_blocking=True) + # Respect the current module device (it may be moved after construction). 
+ device = next(self.parameters()).device + x = x.to(device, non_blocking=True) # H-Optimus_1 emb = self.tile_encoder(x) # [B, 1536] @@ -700,7 +701,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: emb.size(0), 1, 2, - device=self.device, + device=device, dtype=torch.float32, ) @@ -713,7 +714,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out.squeeze(1) # [B, 1536] -def ticon(device: str = "cuda") -> Extractor[nn.Module]: +def ticon(device: str | None = None) -> Extractor[nn.Module]: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" model = HOptimusTICON(torch.device(device)) transform = transforms.Compose(