From 98d6a67d776c5a88b86eba17b9129a25f93f2928 Mon Sep 17 00:00:00 2001 From: mducducd Date: Fri, 13 Feb 2026 14:01:16 +0000 Subject: [PATCH 1/9] add multi-target support; tests, fixes, and docs --- src/stamp/__main__.py | 4 +- src/stamp/config.yaml | 24 +- src/stamp/encoding/encoder/__init__.py | 2 +- src/stamp/encoding/encoder/chief.py | 2 +- src/stamp/encoding/encoder/eagle.py | 2 +- src/stamp/encoding/encoder/gigapath.py | 2 +- src/stamp/encoding/encoder/madeleine.py | 2 +- src/stamp/encoding/encoder/titan.py | 2 +- src/stamp/heatmaps/__init__.py | 8 +- src/stamp/modeling/config.py | 26 +- src/stamp/modeling/crossval.py | 230 ++++---- src/stamp/modeling/data.py | 512 +++++++++++++----- src/stamp/modeling/deploy.py | 321 +++++++++-- src/stamp/modeling/models/__init__.py | 79 ++- src/stamp/modeling/models/barspoon.py | 367 +++++++++++++ src/stamp/modeling/registry.py | 10 + src/stamp/modeling/train.py | 156 +++--- src/stamp/preprocessing/__init__.py | 2 +- .../extractor/chief_ctranspath.py | 2 +- .../preprocessing/extractor/ctranspath.py | 2 +- .../preprocessing/extractor/dinobloom.py | 2 +- src/stamp/preprocessing/tiling.py | 2 +- src/stamp/statistics/__init__.py | 40 +- src/stamp/statistics/survival.py | 12 +- src/stamp/types.py | 1 + src/stamp/{ => utils}/cache.py | 0 src/stamp/{ => utils}/config.py | 0 src/stamp/{ => utils}/seed.py | 0 src/stamp/utils/target_file.py | 351 ++++++++++++ tests/random_data.py | 86 +++ tests/test_cache_tiles.py | 2 +- tests/test_config.py | 2 +- tests/test_crossval.py | 2 +- tests/test_data.py | 2 +- tests/test_deployment.py | 31 +- tests/test_encoders.py | 2 +- tests/test_feature_extractors.py | 2 +- tests/test_model.py | 112 ++++ tests/test_statistics.py | 151 ++++++ tests/test_train_deploy.py | 99 +++- uv.lock | 9 +- 41 files changed, 2210 insertions(+), 453 deletions(-) create mode 100644 src/stamp/modeling/models/barspoon.py rename src/stamp/{ => utils}/cache.py (100%) rename src/stamp/{ => utils}/config.py (100%) rename src/stamp/{ => utils}/seed.py (100%) create mode 100644 src/stamp/utils/target_file.py diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 4ab8416f..ffa98bae 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -6,14 +6,14 @@ import yaml -from stamp.config import StampConfig from stamp.modeling.config import ( AdvancedConfig, MlpModelParams, ModelParams, VitModelParams, ) -from stamp.seed import Seed +from stamp.utils.config import StampConfig +from stamp.utils.seed import Seed STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml") diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 796140a5..4f16dcb2 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -4,7 +4,7 @@ preprocessing: # Extractor to use for feature extractor. Possible options are "ctranspath", # "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom", # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", - # "virchow-full", "musk", "mstar", "plip" + # "virchow-full", "musk", "mstar", "plip", "ticon" # Some of them require requesting access to the respective authors beforehand. extractor: "chief-ctranspath" @@ -73,6 +73,8 @@ crossval: # Name of the column from the clini table to train on. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -130,6 +132,8 @@ training: # Name of the column from the clini table to train on. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -172,6 +176,8 @@ deployment: # Name of the column from the clini to compare predictions to. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -197,6 +203,8 @@ statistics: # Name of the target label. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # A lot of the statistics are computed "one-vs-all", i.e. there needs to be # a positive class to calculate the statistics for. @@ -316,7 +324,7 @@ advanced_config: max_lr: 1e-4 div_factor: 25. # Select a model regardless of task - model_name: "vit" # or mlp, trans_mil + model_name: "vit" # or mlp, trans_mil, barspoon model_params: vit: # Vision Transformer @@ -335,3 +343,15 @@ advanced_config: dim_hidden: 512 num_layers: 2 dropout: 0.25 + + # NOTE: Only the `barspoon` model supports multi-target classification + # (i.e. `ground_truth_label` can be a list of column names). Other + # models expect a single target column. + barspoon: # Encoder-Decoder Transformer for multi-target classification + d_model: 512 + num_encoder_heads: 8 + num_decoder_heads: 8 + num_encoder_layers: 2 + num_decoder_layers: 2 + dim_feedforward: 2048 + positional_encoding: true \ No newline at end of file diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 5827e884..86daa54a 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -12,11 +12,11 @@ from tqdm import tqdm import stamp -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.modeling.data import CoordsInfo, get_coords, read_table from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/chief.py b/src/stamp/encoding/encoder/chief.py index 2ad4b91b..924ceebb 100644 --- a/src/stamp/encoding/encoder/chief.py +++ b/src/stamp/encoding/encoder/chief.py @@ -10,11 +10,11 @@ from numpy import ndarray from tqdm import tqdm -from stamp.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/eagle.py b/src/stamp/encoding/encoder/eagle.py index 03bc833e..9266f315 100644 --- a/src/stamp/encoding/encoder/eagle.py +++ b/src/stamp/encoding/encoder/eagle.py @@ -9,13 +9,13 @@ from torch import Tensor from tqdm import tqdm -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.encoding.encoder.chief import CHIEF from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/gigapath.py b/src/stamp/encoding/encoder/gigapath.py index 9cb3f6f5..4c0a2f6b 100644 --- a/src/stamp/encoding/encoder/gigapath.py +++ b/src/stamp/encoding/encoder/gigapath.py @@ -9,12 +9,12 @@ from gigapath import slide_encoder from tqdm import tqdm -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import PandasLabel, SlideMPP +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/madeleine.py b/src/stamp/encoding/encoder/madeleine.py index 5798a592..a0c74dcd 100644 --- a/src/stamp/encoding/encoder/madeleine.py +++ b/src/stamp/encoding/encoder/madeleine.py @@ -3,10 +3,10 @@ import torch from numpy import ndarray -from stamp.cache import STAMP_CACHE_DIR from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.preprocessing.config import ExtractorName +from stamp.utils.cache import STAMP_CACHE_DIR try: from madeleine.models.factory import create_model_from_pretrained diff --git a/src/stamp/encoding/encoder/titan.py b/src/stamp/encoding/encoder/titan.py index 1012d98f..568254ca 100644 --- a/src/stamp/encoding/encoder/titan.py +++ b/src/stamp/encoding/encoder/titan.py @@ -10,12 +10,12 @@ from tqdm import tqdm from transformers import AutoModel -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, Microns, PandasLabel, SlideMPP +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index 446d85d6..22fb5250 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -70,14 +70,14 @@ def _attention_rollout_single( device = feats.device - # --- 1. Forward pass to fill attn_weights in each SelfAttention layer --- + # 1. Forward pass to fill attn_weights in each SelfAttention layer _ = model( bags=feats.unsqueeze(0), coords=coords.unsqueeze(0), mask=torch.zeros(1, len(feats), dtype=torch.bool, device=device), ) - # --- 2. Rollout computation --- + # 2. Rollout computation attn_rollout: torch.Tensor | None = None for layer in model.transformer.layers: # type: ignore attn = getattr(layer[0], "attn_weights", None) # SelfAttention.attn_weights @@ -96,10 +96,10 @@ def _attention_rollout_single( if attn_rollout is None: raise RuntimeError("No attention maps collected from transformer layers.") - # --- 3. Extract CLS → tiles attention --- + # 3. Extract CLS → tiles attention cls_attn = attn_rollout[0, 1:] # [tile] - # --- 4. Normalize for visualization consistency --- + # 4. Normalize for visualization consistency cls_attn = cls_attn - cls_attn.min() cls_attn = cls_attn / (cls_attn.max().clamp(min=1e-8)) diff --git a/src/stamp/modeling/config.py b/src/stamp/modeling/config.py index 21ce69db..5b9a6bcc 100644 --- a/src/stamp/modeling/config.py +++ b/src/stamp/modeling/config.py @@ -21,7 +21,7 @@ class TrainConfig(BaseModel): ) feature_dir: Path = Field(description="Directory containing feature files") - ground_truth_label: PandasLabel | None = Field( + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = Field( default=None, description="Name of categorical column in clinical table to train on", ) @@ -64,7 +64,7 @@ class DeploymentConfig(BaseModel): slide_table: Path feature_dir: Path - ground_truth_label: PandasLabel | None = None + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None patient_label: PandasLabel = "PATIENT" filename_label: PandasLabel = "FILENAME" @@ -99,8 +99,29 @@ class TransMILModelParams(BaseModel): dim_hidden: int = 512 +class BarspoonParams(BaseModel): + model_config = ConfigDict(extra="forbid") + d_model: int = 512 + num_encoder_heads: int = 8 + num_decoder_heads: int = 8 + num_encoder_layers: int = 2 + num_decoder_layers: int = 2 + dim_feedforward: int = 2048 + positional_encoding: bool = True + # Other hparams + learning_rate: float = 1e-4 + + class LinearModelParams(BaseModel): model_config = ConfigDict(extra="forbid") + num_encoder_heads: int = 8 + num_decoder_heads: int = 8 + num_encoder_layers: int = 2 + num_decoder_layers: int = 2 + dim_feedforward: int = 2048 + positional_encoding: bool = True + # Other hparams + learning_rate: float = 1e-4 class ModelParams(BaseModel): @@ -109,6 +130,7 @@ class ModelParams(BaseModel): trans_mil: TransMILModelParams = Field(default_factory=TransMILModelParams) mlp: MlpModelParams = Field(default_factory=MlpModelParams) linear: LinearModelParams = Field(default_factory=LinearModelParams) + barspoon: BarspoonParams = Field(default_factory=BarspoonParams) class AdvancedConfig(BaseModel): diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 43e76f01..2caccecb 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -1,8 +1,10 @@ import logging +from collections import Counter from collections.abc import Mapping, Sequence -from typing import Any, Final +from typing import Any, cast import numpy as np +import torch from pydantic import BaseModel from sklearn.model_selection import KFold, StratifiedKFold @@ -10,13 +12,8 @@ from stamp.modeling.data import ( PatientData, create_dataloader, - detect_feature_type, - filter_complete_patient_data_, - load_patient_level_data, + load_patient_data_, log_patient_class_summary, - patient_to_ground_truth_from_clini_table_, - patient_to_survival_from_clini_table_, - slide_to_patient_from_slide_table_, ) from stamp.modeling.deploy import ( _predict, @@ -28,7 +25,6 @@ from stamp.modeling.train import setup_model_for_training, train_model_ from stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import ( - FeaturePath, GroundTruth, PatientId, ) @@ -53,67 +49,31 @@ def categorical_crossval_( config: CrossvalConfig, advanced: AdvancedConfig, ) -> None: - feature_type = detect_feature_type(config.feature_dir) + if config.task is None: + raise ValueError( + "task must be set to 'classification' | 'regression' | 'survival'" + ) + + patient_to_data, feature_type = load_patient_data_( + feature_dir=config.feature_dir, + clini_table=config.clini_table, + slide_table=config.slide_table, + task=config.task, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + patient_label=config.patient_label, + filename_label=config.filename_label, + drop_patients_with_missing_ground_truth=True, + ) _logger.info(f"Detected feature type: {feature_type}") - if feature_type in ("tile", "slide"): - if config.slide_table is None: - raise ValueError("A slide table is required for modeling") - if config.task == "survival": - if config.time_label is None or config.status_label is None: - raise ValueError( - "Both time_label and status_label are is required for survival modeling" - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = ( - patient_to_survival_from_clini_table_( - clini_table_path=config.clini_table, - time_label=config.time_label, - status_label=config.status_label, - patient_label=config.patient_label, - ) - ) - else: - if config.ground_truth_label is None: - raise ValueError( - "Ground truth label is required for classification or regression modeling" - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = ( - patient_to_ground_truth_from_clini_table_( - clini_table_path=config.clini_table, - ground_truth_label=config.ground_truth_label, - patient_label=config.patient_label, - ) - ) - slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( - slide_to_patient_from_slide_table_( - slide_table_path=config.slide_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - filename_label=config.filename_label, - ) - ) - patient_to_data: Mapping[PatientId, PatientData] = ( - filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, - ) - ) - elif feature_type == "patient": - patient_to_data: Mapping[PatientId, PatientData] = load_patient_level_data( - task=config.task, - clini_table=config.clini_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = { - pid: pd.ground_truth for pid, pd in patient_to_data.items() - } - else: - raise RuntimeError(f"Unsupported feature type: {feature_type}") + patient_to_ground_truth = { + pid: pd.ground_truth for pid, pd in patient_to_data.items() + } + + if feature_type not in ("tile", "slide", "patient"): + raise ValueError(f"Unknown feature type: {feature_type}") config.output_dir.mkdir(parents=True, exist_ok=True) splits_file = config.output_dir / "splits.json" @@ -158,18 +118,51 @@ def categorical_crossval_( f"{ground_truths_not_in_split}" ) + categories_for_export: ( + dict[str, list] | list + ) = [] # declare upfront to avoid unbound variable warnings + categories: Sequence[GroundTruth] | list | None = [] # type: ignore # declare upfront to avoid unbound variable warnings + if config.task == "classification": - categories = config.categories or sorted( - { - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - } - ) - log_patient_class_summary( - patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, - categories=categories, - ) + # Determine categories for training (single-target) and for export (supports multi-target) + if isinstance(config.ground_truth_label, str): + categories = config.categories or sorted( + { + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + } + ) + log_patient_class_summary( + patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, + categories=categories, + ) + categories_for_export = cast(list, categories) + else: + # Multi-target: build a mapping from target label -> sorted list of categories + categories_accum: dict[str, set[GroundTruth]] = {} + for patient_data in patient_to_data.values(): + gt = patient_data.ground_truth + if isinstance(gt, dict): + for k, v in gt.items(): + if v is not None: + categories_accum.setdefault(k, set()).add(v) + categories_for_export = {k: sorted(v) for k, v in categories_accum.items()} + # Log summary per target + for t, cats in categories_for_export.items(): + ground_truths = [ + pd.ground_truth.get(t) + for pd in patient_to_data.values() + if isinstance(pd.ground_truth, dict) + and pd.ground_truth.get(t) is not None + ] + counter = Counter(ground_truths) + _logger.info( + f"{t} | Total patients: {len(ground_truths)} | " + + " | ".join([f"Class {c}: {counter.get(c, 0)}" for c in cats]) + ) + # For training, categories can remain None (inferred later) + categories = config.categories or None else: categories = [] @@ -206,12 +199,18 @@ def categorical_crossval_( }, categories=( categories - or sorted( - { - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - } + if categories is not None + else ( + sorted( + { + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + and not isinstance(patient_data.ground_truth, dict) + } + ) + if not isinstance(config.ground_truth_label, Sequence) + else None ) ), train_transform=( @@ -263,30 +262,48 @@ def categorical_crossval_( ) if config.task == "survival": - _to_survival_prediction_df( - patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, - patient_label=config.patient_label, - cut_off=getattr(model.hparams, "train_pred_median", None), - ).to_csv(split_dir / "patient-preds.csv", index=False) + if isinstance(config.ground_truth_label, str): + _to_survival_prediction_df( + patient_to_ground_truth=cast( + Mapping[PatientId, str | None], patient_to_ground_truth + ), + predictions=cast(Mapping[PatientId, torch.Tensor], predictions), + patient_label=config.patient_label, + cut_off=getattr(model.hparams, "train_pred_median", None), + ).to_csv(split_dir / "patient-preds.csv", index=False) + else: + _logger.warning( + "Multi-target survival prediction export not yet supported; skipping CSV save" + ) elif config.task == "regression": if config.ground_truth_label is None: raise RuntimeError("Grounf truth label is required for regression") - _to_regression_prediction_df( - patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - ).to_csv(split_dir / "patient-preds.csv", index=False) + if isinstance(config.ground_truth_label, str): + _to_regression_prediction_df( + patient_to_ground_truth=cast( + Mapping[PatientId, str | None], patient_to_ground_truth + ), + predictions=cast(Mapping[PatientId, torch.Tensor], predictions), + patient_label=config.patient_label, + ground_truth_label=config.ground_truth_label, + ).to_csv(split_dir / "patient-preds.csv", index=False) + else: + _logger.warning( + "Multi-target regression prediction export not yet supported; skipping CSV save" + ) else: if config.ground_truth_label is None: raise RuntimeError( "Grounf truth label is required for classification" ) _to_prediction_df( - categories=categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, + predictions=cast( + Mapping[PatientId, torch.Tensor] + | Mapping[PatientId, dict[str, torch.Tensor]], + predictions, + ), patient_label=config.patient_label, ground_truth_label=config.ground_truth_label, ).to_csv(split_dir / "patient-preds.csv", index=False) @@ -296,6 +313,16 @@ def _get_splits( *, patient_to_data: Mapping[PatientId, PatientData[Any]], n_splits: int, spliter ) -> _Splits: patients = np.array(list(patient_to_data.keys())) + + # Extract ground truth for stratification. + # For multi-target (dict), use the first target's value + y_strat = np.array( + [ + next(iter(gt.values())) if isinstance(gt, dict) else gt + for gt in [patient.ground_truth for patient in patient_to_data.values()] + ] + ) + skf = spliter(n_splits=n_splits, shuffle=True, random_state=0) splits = _Splits( splits=[ @@ -303,12 +330,7 @@ def _get_splits( train_patients=set(patients[train_indices]), test_patients=set(patients[test_indices]), ) - for train_indices, test_indices in skf.split( - patients, - np.array( - [patient.ground_truth for patient in patient_to_data.values()] - ), - ) + for train_indices, test_indices in skf.split(patients, y_strat) ] ) return splits diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 6cabec64..8b30ab11 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -5,19 +5,29 @@ from dataclasses import KW_ONLY, dataclass from itertools import groupby from pathlib import Path -from typing import IO, BinaryIO, Counter, Generic, TextIO, TypeAlias, Union, cast +from typing import ( + IO, + Any, + BinaryIO, + Dict, + Final, + Generic, + List, + TextIO, + TypeAlias, + Union, + cast, +) import h5py import numpy as np import pandas as pd import torch -from jaxtyping import Float from packaging.version import Version from torch import Tensor from torch.utils.data import DataLoader, Dataset import stamp -from stamp.seed import Seed from stamp.types import ( Bags, BagSize, @@ -35,6 +45,7 @@ Task, TilePixels, ) +from stamp.utils.seed import Seed _logger = logging.getLogger("stamp") @@ -43,14 +54,17 @@ __copyright__ = "Copyright (C) 2022-2025 Marko van Treeck, Minh Duc Nguyen" __license__ = "MIT" -_Bag: TypeAlias = Float[Tensor, "tile feature"] -_EncodedTarget: TypeAlias = Float[Tensor, "category_is_hot"] | Float[Tensor, "1"] # noqa: F821 +_Bag: TypeAlias = Tensor +_EncodedTarget: TypeAlias = ( + Tensor | dict[str, Tensor] +) # Union of encoded targets or multi-target dict _BinaryIOLike: TypeAlias = Union[BinaryIO, IO[bytes]] """The ground truth, encoded numerically - classification: one-hot float [C] - regression: float [1] +- multi-target: dict[target_name -> one-hot/regression value] """ -_Coordinates: TypeAlias = Float[Tensor, "tile 2"] +_Coordinates: TypeAlias = Tensor @dataclass @@ -64,7 +78,7 @@ class PatientData(Generic[GroundTruthType]): def tile_bag_dataloader( *, - patient_data: Sequence[PatientData[GroundTruth | None]], + patient_data: Sequence[PatientData[GroundTruth | None | dict]], bag_size: int | None, task: Task, categories: Sequence[Category] | None = None, @@ -74,7 +88,7 @@ def tile_bag_dataloader( transform: Callable[[Tensor], Tensor] | None, ) -> tuple[ DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - Sequence[Category], + Sequence[Category] | Mapping[str, Sequence[Category]], ]: """Creates a dataloader from patient data for tile-level (bagged) features. @@ -86,103 +100,139 @@ def tile_bag_dataloader( task='regression': returns float targets """ - if task == "classification": - raw_ground_truths = np.array([patient.ground_truth for patient in patient_data]) - categories = ( - categories if categories is not None else list(np.unique(raw_ground_truths)) - ) - # one_hot = torch.tensor(raw_ground_truths.reshape(-1, 1) == categories) - one_hot = torch.tensor( - raw_ground_truths.reshape(-1, 1) == categories, dtype=torch.float32 - ) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=one_hot, - transform=transform, - ) - cats_out: Sequence[Category] = list(categories) - elif task == "regression": - raw_targets = np.array( - [ - np.nan if p.ground_truth is None else float(p.ground_truth) - for p in patient_data - ], - dtype=np.float32, - ) - y = torch.from_numpy(raw_targets).reshape(-1, 1) + targets, cats_out = _parse_targets( + patient_data=patient_data, + task=task, + categories=categories, + ) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=y, - transform=transform, - ) - cats_out = [] + is_multitarget = isinstance(targets[0], dict) - elif task == "survival": # Not yet support logistic-harzard - times: list[float] = [] - events: list[float] = [] + collate_fn = _collate_multitarget if is_multitarget else _collate_to_tuple - for p in patient_data: - if p.ground_truth is None: - times.append(np.nan) - events.append(np.nan) - continue + ds = BagDataset( + bags=[patient.feature_files for patient in patient_data], + bag_size=bag_size, + ground_truths=targets, + transform=transform, + ) + dl = DataLoader( + ds, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=collate_fn, + worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, + ) - try: - time_str, status_str = p.ground_truth.split(" ", 1) + return ( + cast( + DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], + dl, + ), + cats_out, + ) - # Handle missing values encoded as "nan" - if time_str.lower() == "nan": - times.append(np.nan) - else: - times.append(float(time_str)) - - if status_str.lower() == "nan": - events.append(np.nan) - elif status_str.lower() in {"dead", "event", "1", "Yes", "yes"}: - events.append(1.0) - elif status_str.lower() in {"alive", "censored", "0", "No", "no"}: - events.append(0.0) - else: - events.append(np.nan) # unknown status → mark missing - except Exception: +def _parse_targets( + *, + patient_data: Sequence, + task: Task, + categories: Sequence[Category] | None = None, + target_spec: dict[str, Any] | None = None, + target_label: str | None = None, +) -> tuple[ + Union[torch.Tensor, list[dict[str, torch.Tensor]]], + Sequence[Category] | Mapping[str, Sequence[Category]], +]: + """ + Parse raw GroundTruth (str) into model-ready tensors. + This is the ONLY place task semantics live. + """ + + gts = [p.ground_truth for p in patient_data] + + if task == "classification": + if any(isinstance(gt, dict) for gt in gts if gt is not None): + # infer target names from the first non-None dict + first_dict = next(gt for gt in gts if isinstance(gt, dict)) + target_names = list(first_dict.keys()) + + # infer categories per target (ignore None patients, ignore None values) + categories_out: dict[str, list[str]] = {t: [] for t in target_names} + for gt in gts: + if not isinstance(gt, dict): + continue + for t in target_names: + v = gt.get(t) + if v is not None: + categories_out[t].append(v) + + # make unique + sorted + categories_out = { + t: sorted(set(vals)) for t, vals in categories_out.items() + } + + # encode per patient; if gt missing -> all zeros + encoded: list[dict[str, Tensor]] = [] + for gt in gts: + patient_encoded: dict[str, Tensor] = {} + for t in target_names: + cats = categories_out[t] + if not isinstance(gt, dict) or gt.get(t) is None: + one_hot = torch.zeros(len(cats), dtype=torch.float32) + else: + one_hot = torch.tensor( + [gt[t] == c for c in cats], + dtype=torch.float32, + ) + patient_encoded[t] = one_hot + encoded.append(patient_encoded) + + # IMPORTANT: return categories as mapping, not list-of-target-names + return encoded, categories_out + + # single target + unique = {gt for gt in gts if gt is not None} + if len(unique) >= 2 or categories is not None: + raw = np.array([p.ground_truth for p in patient_data]) + categories = categories or list(sorted(unique)) + labels = torch.tensor( + raw.reshape(-1, 1) == categories, + dtype=torch.float32, + ) + return labels, categories + + raise ValueError( + "Only one unique class found in classification task. " + "This is usually a data or configuration error." + ) + + elif task == "regression": + y = torch.tensor( + [np.nan if gt is None else float(gt) for gt in gts], + dtype=torch.float32, + ).reshape(-1, 1) + return y, [] + + elif task == "survival": + times, events = [], [] + for gt in gts: + if gt is None: times.append(np.nan) events.append(np.nan) + continue - # Final tensor shape: (N, 2) - y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) + time_str, status_str = gt.split(" ", 1) + times.append(np.nan if time_str.lower() == "nan" else float(time_str)) + events.append(_parse_survival_status(status_str)) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=y, - transform=transform, - ) - cats_out: Sequence[Category] = [] # survival has no categories + y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) + return y, [] else: - raise ValueError(f"Unknown task: {task}") - - return ( - cast( - DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - DataLoader( - ds, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - collate_fn=_collate_to_tuple, - worker_init_fn=Seed.get_loader_worker_init() - if Seed._is_set() - else None, - ), - ), - cats_out, - ) + raise ValueError(f"Unsupported task: {task}") def _collate_to_tuple( @@ -210,6 +260,24 @@ def _collate_to_tuple( return (bags, coords, bag_sizes, encoded_targets) +def _collate_multitarget( + items: list[tuple[_Bag, _Coordinates, BagSize, Dict[str, Tensor]]], +) -> tuple[Bags, CoordinatesBatch, BagSizes, Dict[str, Tensor]]: + bags = torch.stack([b for b, _, _, _ in items]) + coords = torch.stack([c for _, c, _, _ in items]) + bag_sizes = torch.tensor([s for _, _, s, _ in items]) + + acc: Dict[str, List[Tensor]] = {} + + for _, _, _, tdict in items: + for k, v in tdict.items(): + acc.setdefault(k, []).append(v) + + targets: Dict[str, Tensor] = {k: torch.stack(v, dim=0) for k, v in acc.items()} + + return bags, coords, bag_sizes, targets + + def patient_feature_dataloader( *, patient_data: Sequence[PatientData[GroundTruth | None]], @@ -237,21 +305,31 @@ def create_dataloader( *, feature_type: str, task: Task, - patient_data: Sequence[PatientData[GroundTruth | None]], + patient_data: Sequence[PatientData[GroundTruth | None | dict]], bag_size: int | None = None, batch_size: int, shuffle: bool, num_workers: int, transform: Callable[[Tensor], Tensor] | None, - categories: Sequence[Category] | None = None, -) -> tuple[DataLoader, Sequence[Category]]: + categories: Sequence[Category] | Mapping[str, Sequence[Category]] | None = None, +) -> tuple[DataLoader, Sequence[Category] | Mapping[str, Sequence[Category]]]: """Unified dataloader for all feature types and tasks.""" if feature_type == "tile": + # For multi-target classification, categories may be a mapping from + # target name to per-target categories. _parse_targets (used inside + # tile_bag_dataloader) only consumes explicit categories for the + # single-target case, so we pass a sequence or None here. + cats_arg: Sequence[Category] | None + if isinstance(categories, Mapping): + cats_arg = None + else: + cats_arg = categories + return tile_bag_dataloader( patient_data=patient_data, bag_size=bag_size, task=task, - categories=categories, + categories=cats_arg, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, @@ -263,21 +341,32 @@ def create_dataloader( if task == "classification": raw = np.array([p.ground_truth for p in patient_data]) - categories = categories or list(np.unique(raw)) - labels = torch.tensor(raw.reshape(-1, 1) == categories, dtype=torch.float32) - elif task == "regression": + categories_out = categories or list(np.unique(raw)) labels = torch.tensor( - [ - float(gt) - for gt in (p.ground_truth for p in patient_data) - if gt is not None - ], - dtype=torch.float32, - ).reshape(-1, 1) + raw.reshape(-1, 1) == categories_out, dtype=torch.float32 + ) + elif task == "regression": + values: list[float] = [] + for gt in (p.ground_truth for p in patient_data): + if gt is None: + continue + if isinstance(gt, dict): + # Use first value for multi-target regression + first_val = next(iter(gt.values())) + values.append(float(first_val)) + else: + values.append(float(gt)) + + labels = torch.tensor(values, dtype=torch.float32).reshape(-1, 1) elif task == "survival": times, events = [], [] for p in patient_data: - t, e = (p.ground_truth or "nan nan").split(" ", 1) + if isinstance(p.ground_truth, dict): + # Multi-target survival: use first target + val = list(p.ground_truth.values())[0] + t, e = (val or "nan nan").split(" ", 1) + else: + t, e = (p.ground_truth or "nan nan").split(" ", 1) times.append(float(t) if t.lower() != "nan" else np.nan) events.append(_parse_survival_status(e)) @@ -340,9 +429,9 @@ def load_patient_level_data( clini_table: Path, feature_dir: Path, patient_label: PandasLabel, - ground_truth_label: PandasLabel | None = None, # <- now optional - time_label: PandasLabel | None = None, # <- for survival - status_label: PandasLabel | None = None, # <- for survival + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None, + time_label: PandasLabel | None = None, + status_label: PandasLabel | None = None, feature_ext: str = ".h5", ) -> dict[PatientId, PatientData]: """ @@ -419,7 +508,7 @@ class BagDataset(Dataset[tuple[_Bag, _Coordinates, BagSize, _EncodedTarget]]): If `bag_size` is None, all the samples will be used. """ - ground_truths: Float[Tensor, "index category_is_hot"] | Float[Tensor, "index 1"] + ground_truths: Tensor | list[dict[str, Tensor]] # ground_truths: Bool[Tensor, "index category_is_hot"] # """The ground truth for each bag, one-hot encoded.""" @@ -529,7 +618,7 @@ def mpp(self) -> SlideMPP: def get_coords(feature_h5: h5py.File) -> CoordsInfo: - # --- NEW: handle missing coords ----multiplex data bypass: no coords found; generated fake coords + # NEW: handle missing coords - multiplex data bypass: no coords found; generated fake coords if "coords" not in feature_h5: feats_obj = feature_h5["patch_embeddings"] @@ -627,33 +716,71 @@ def patient_to_ground_truth_from_clini_table_( *, clini_table_path: Path | TextIO, patient_label: PandasLabel, - ground_truth_label: PandasLabel, -) -> dict[PatientId, GroundTruth]: - """Loads the patients and their ground truths from a clini table.""" + ground_truth_label: PandasLabel | Sequence[PandasLabel], +) -> ( + dict[PatientId, GroundTruth | None] | dict[PatientId, dict[str, GroundTruth | None]] +): + """Loads the patients and their ground truths from a clini table. + + `ground_truth_label` may be either a single column name (str) or a sequence + of column names. In the latter case the returned mapping will contain a + dict mapping column -> value for each patient (supporting multi-target + setups). + """ + # Normalize to list for uniform handling + if isinstance(ground_truth_label, str): + cols = [patient_label, ground_truth_label] + multi = False + target_cols_inner: list[PandasLabel] = [] + else: + cols = [patient_label, *list(ground_truth_label)] + multi = True + target_cols_inner = [c for c in cols if c != patient_label] + clini_df = read_table( clini_table_path, - usecols=[patient_label, ground_truth_label], + usecols=cols, dtype=str, - ).dropna() + ) + + # If multi-target, keep rows where at least one target is present; for + # single target behave like before and drop rows missing the value. + if multi: + clini_df = clini_df.dropna(subset=target_cols_inner, how="all") + else: + clini_df = clini_df.dropna(subset=[ground_truth_label]) + try: - patient_to_ground_truth: Mapping[PatientId, GroundTruth] = clini_df.set_index( - patient_label, verify_integrity=True - )[ground_truth_label].to_dict() + if multi: + # Build mapping patient -> {col: value} + result: dict[PatientId, dict[str, GroundTruth | None]] = {} + for _, row in clini_df.iterrows(): + pid = row[patient_label] + # Convert pandas nan to None and keep strings otherwise + result[pid] = { + col: (None if pd.isna(row[col]) else str(row[col])) + for col in target_cols_inner + } + return result + else: + patient_to_ground_truth: Mapping[PatientId, str] = cast( + Mapping[PatientId, str], + clini_df.set_index(patient_label, verify_integrity=True)[ + cast(PandasLabel, ground_truth_label) + ].to_dict(), + ) + return cast(dict[PatientId, GroundTruth | None], patient_to_ground_truth) except KeyError as e: if patient_label not in clini_df: raise ValueError( f"{patient_label} was not found in clini table " f"(columns in clini table: {clini_df.columns})" ) from e - elif ground_truth_label not in clini_df: + else: raise ValueError( - f"{ground_truth_label} was not found in clini table " + f"One or more ground truth columns were not found in clini table " f"(columns in clini table: {clini_df.columns})" ) from e - else: - raise e from e - - return patient_to_ground_truth def patient_to_survival_from_clini_table_( @@ -667,7 +794,6 @@ def patient_to_survival_from_clini_table_( Loads patients and their survival ground truths (time + event) from a clini table. Returns - ------- dict[PatientId, GroundTruth] Mapping patient_id -> "time status" (e.g. "302 dead", "476 alive"). """ @@ -780,7 +906,9 @@ def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: def filter_complete_patient_data_( *, - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], + patient_to_ground_truth: Mapping[ + PatientId, GroundTruth | dict[str, GroundTruth] | None + ], slide_to_patient: Mapping[FeaturePath, PatientId], drop_patients_with_missing_ground_truth: bool, ) -> Mapping[PatientId, PatientData]: @@ -865,7 +993,7 @@ def _log_patient_slide_feature_inconsistencies( ) -def get_stride(coords: Float[Tensor, "tile 2"]) -> float: +def get_stride(coords: Tensor) -> float: """Gets the minimum step width between any two coordintes.""" xs: Tensor = coords[:, 0].unique(sorted=True) ys: Tensor = coords[:, 1].unique(sorted=True) @@ -927,26 +1055,130 @@ def _parse_survival_status(value) -> int | None: ) +def load_patient_data_( + *, + feature_dir: Path, + clini_table: Path, + slide_table: Path | None, + task: Task, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + time_label: PandasLabel | None, + status_label: PandasLabel | None, + patient_label: PandasLabel, + filename_label: PandasLabel, + drop_patients_with_missing_ground_truth: bool = True, +) -> tuple[Mapping[PatientId, PatientData], str]: + """Load patient data based on feature type (tile, slide, or patient). + + This consolidates the common data loading logic used across train, crossval, and deploy. + + Returns: + (patient_to_data, feature_type) + """ + feature_type = detect_feature_type(feature_dir) + + if feature_type in ("tile", "slide"): + if slide_table is None: + raise ValueError("A slide table is required for tile/slide-level features") + + # Load ground truth based on task + if task == "survival": + if time_label is None or status_label is None: + raise ValueError( + "Both time_label and status_label are required for survival modeling" + ) + patient_to_ground_truth = patient_to_survival_from_clini_table_( + clini_table_path=clini_table, + time_label=time_label, + status_label=status_label, + patient_label=patient_label, + ) + else: + if ground_truth_label is None: + raise ValueError( + "Ground truth label is required for classification or regression modeling" + ) + patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( + clini_table_path=clini_table, + ground_truth_label=ground_truth_label, + patient_label=patient_label, + ) + + # Link slides to patients + slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( + slide_to_patient_from_slide_table_( + slide_table_path=slide_table, + feature_dir=feature_dir, + patient_label=patient_label, + filename_label=filename_label, + ) + ) + + # Filter to complete patient data + patient_to_data = filter_complete_patient_data_( + patient_to_ground_truth=cast( + Mapping[PatientId, GroundTruth | dict[str, GroundTruth] | None], + patient_to_ground_truth, + ), + slide_to_patient=slide_to_patient, + drop_patients_with_missing_ground_truth=drop_patients_with_missing_ground_truth, + ) + elif feature_type == "patient": + patient_to_data = load_patient_level_data( + task=task, + clini_table=clini_table, + feature_dir=feature_dir, + patient_label=patient_label, + ground_truth_label=ground_truth_label, + time_label=time_label, + status_label=status_label, + ) + else: + raise RuntimeError(f"Unknown feature type: {feature_type}") + + return patient_to_data, feature_type + + def log_patient_class_summary( *, patient_to_data: Mapping[PatientId, PatientData], categories: Sequence[Category] | None, - prefix: str = "", ) -> None: + """ + Logs class distribution. + Supports both single-target and multi-target classification. + """ + ground_truths = [ - pd.ground_truth - for pd in patient_to_data.values() - if pd.ground_truth is not None + p.ground_truth for p in patient_to_data.values() if p.ground_truth is not None ] if not ground_truths: - _logger.warning(f"{prefix}No ground truths available to summarize.") + _logger.warning("No ground truths available for summary.") return - cats = categories or sorted(set(ground_truths)) - counter = Counter(ground_truths) + # Multi-target + if isinstance(ground_truths[0], dict): + # Collect per-target values + per_target: dict[str, list] = {} - _logger.info( - f"{prefix}Total patients: {len(ground_truths)} | " - + " | ".join([f"Class {c}: {counter.get(c, 0)}" for c in cats]) - ) + for gt in ground_truths: + for key, value in gt.items(): + per_target.setdefault(key, []).append(value) + + for target_name, values in per_target.items(): + counts = {} + for v in values: + counts[v] = counts.get(v, 0) + 1 + + _logger.info( + f"[Multi-target] Target '{target_name}' distribution: {counts}" + ) + + # Single-target + else: + counts = {} + for gt in ground_truths: + counts[gt] = counts.get(gt, 0) + 1 + + _logger.info(f"Class distribution: {counts}") diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 905c6005..10272328 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd import torch -from jaxtyping import Float +import torch.nn.functional as F from lightning.pytorch.accelerators.accelerator import Accelerator from stamp.modeling.data import ( @@ -20,7 +20,7 @@ slide_to_patient_from_slide_table_, ) from stamp.modeling.registry import ModelName, load_model_class -from stamp.types import GroundTruth, PandasLabel, PatientId +from stamp.types import Category, GroundTruth, PandasLabel, PatientId __all__ = ["deploy_categorical_model_"] @@ -32,6 +32,13 @@ Logit: TypeAlias = float +# Prediction type aliases +PredictionSingle: TypeAlias = torch.Tensor +PredictionMulti: TypeAlias = dict[str, torch.Tensor] +PredictionsType: TypeAlias = Mapping[ + PatientId, Union[PredictionSingle, PredictionMulti] +] + def load_model_from_ckpt(path: Union[str, Path]): ckpt = torch.load(path, map_location="cpu", weights_only=False) @@ -50,7 +57,7 @@ def deploy_categorical_model_( clini_table: Path | None, slide_table: Path | None, feature_dir: Path, - ground_truth_label: PandasLabel | None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, time_label: PandasLabel | None, status_label: PandasLabel | None, patient_label: PandasLabel, @@ -126,7 +133,12 @@ def deploy_categorical_model_( # classification/regression: still use ground_truth_label if ( len( - ground_truth_labels := set(model.ground_truth_label for model in models) + ground_truth_labels := { + tuple(model.ground_truth_label) + if isinstance(model.ground_truth_label, list) + else (model.ground_truth_label,) + for model in models + } ) != 1 ): @@ -145,17 +157,21 @@ def deploy_categorical_model_( f"{ground_truth_label} vs {model_ground_truth_label}" ) - ground_truth_label = ground_truth_label or model_ground_truth_label + ground_truth_label = ground_truth_label or cast( + PandasLabel, model_ground_truth_label + ) output_dir.mkdir(exist_ok=True, parents=True) model_categories = None if task == "classification": # Ensure the categories were the same between all models - category_sets = {tuple(m.categories) for m in models} + category_sets = { + tuple(cast(Sequence[GroundTruth], m.categories)) for m in models + } if len(category_sets) != 1: raise RuntimeError(f"Categories differ between models: {category_sets}") - model_categories = list(models[0].categories) + model_categories = list(cast(Sequence[GroundTruth], models[0].categories)) # Data loading logic if feature_type in ("tile", "slide"): @@ -171,6 +187,15 @@ def deploy_categorical_model_( ) if clini_table is not None: if task == "survival": + if not hasattr(models[0], "time_label") or not isinstance( + models[0].time_label, str + ): + raise AttributeError("Model is missing valid 'time_label' (str).") + if not hasattr(models[0], "status_label") or not isinstance( + models[0].status_label, str + ): + raise AttributeError("Model is missing valid 'status_label' (str).") + patient_to_ground_truth = patient_to_survival_from_clini_table_( clini_table_path=clini_table, patient_label=patient_label, @@ -192,7 +217,10 @@ def deploy_categorical_model_( patient_id: None for patient_id in set(slide_to_patient.values()) } patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, + patient_to_ground_truth=cast( + Mapping[PatientId, GroundTruth | None], + patient_to_ground_truth, + ), slide_to_patient=slide_to_patient, drop_patients_with_missing_ground_truth=False, ) @@ -241,7 +269,10 @@ def deploy_categorical_model_( "regression": _to_regression_prediction_df, "survival": _to_survival_prediction_df, }[task] - all_predictions: list[Mapping[PatientId, Float[torch.Tensor, "category"]]] = [] # noqa: F821 + all_predictions: list[PredictionsType] = [] + categories_for_export: ( + Sequence[Category] | Mapping[str, Sequence[Category]] | None + ) = cast(Sequence[Category] | Mapping[str, Sequence[Category]] | None, None) for model_i, model in enumerate(models): predictions = _predict( model=model, @@ -251,6 +282,26 @@ def deploy_categorical_model_( ) all_predictions.append(predictions) + if isinstance(next(iter(predictions.values())), dict): + # Multi-target case: gather categories across all targets for export (use model categories if available, else infer from GT) + categories_accum: dict[str, set[GroundTruth]] = {} + + for pd_item in patient_to_data.values(): + gt = pd_item.ground_truth + if isinstance(gt, dict): + for k, v in gt.items(): + if v is not None: + categories_accum.setdefault(k, set()).add(v) + + categories_for_export = {k: sorted(v) for k, v in categories_accum.items()} + + else: + # Single-target case: use categories from model if available, else infer from GT + if task == "classification": + categories_for_export = models[0].categories + else: + categories_for_export = [] + # cut-off values from survival ckpt cut_off = ( getattr(model.hparams, "train_pred_median", None) @@ -261,7 +312,7 @@ def deploy_categorical_model_( # Only save individual model files when deploying multiple models (ensemble) if len(models) > 1: df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, patient_label=patient_label, @@ -270,7 +321,7 @@ def deploy_categorical_model_( ).to_csv(output_dir / f"patient-preds-{model_i}.csv", index=False) else: df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, patient_label=patient_label, @@ -279,17 +330,29 @@ def deploy_categorical_model_( ).to_csv(output_dir / "patient-preds.csv", index=False) if task == "classification": - # TODO we probably also want to save the 95% confidence interval in addition to the mean + # compute mean prediction across models (supports single- and multi-target) + mean_preds: dict[PatientId, object] = {} + for pid in patient_ids: + model_preds = cast( + list[torch.Tensor], [preds[pid] for preds in all_predictions] + ) + firstp = model_preds[0] + if isinstance(firstp, dict): + # per-target averaging + mean_preds[pid] = { + t: torch.stack([p[t] for p in model_preds]).mean(dim=0) + for t in firstp.keys() + } + else: + mean_preds[pid] = torch.stack(model_preds).mean(dim=0) + + assert categories_for_export is not None, ( + "categories_for_export must be set before use" + ) df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, - predictions={ - # Mean prediction - patient_id: torch.stack( - [predictions[patient_id] for predictions in all_predictions] - ).mean(dim=0) - for patient_id in patient_ids - }, + predictions=mean_preds, patient_label=patient_label, ground_truth_label=ground_truth_label, ).to_csv(output_dir / "patient-preds_95_confidence_interval.csv", index=False) @@ -301,7 +364,7 @@ def _predict( test_dl: torch.utils.data.DataLoader, patient_ids: Sequence[PatientId], accelerator: str | Accelerator, -) -> Mapping[PatientId, Float[torch.Tensor, "..."]]: +) -> PredictionsType: model = model.eval() torch.set_float32_matmul_precision("medium") @@ -310,8 +373,11 @@ def _predict( getattr(model, "train_patients", []) ) | set(getattr(model, "valid_patients", [])) if overlap := patients_used_for_training & set(patient_ids): - raise ValueError( - f"some of the patients in the validation set were used during training: {overlap}" + _logger.critical( + "DATA LEAKAGE DETECTED: %d patient(s) in deployment set were used " + "during training/validation. Overlapping IDs: %s", + len(overlap), + sorted(overlap), ) trainer = lightning.Trainer( @@ -320,51 +386,190 @@ def _predict( logger=False, ) - raw_preds = torch.concat(cast(list[torch.Tensor], trainer.predict(model, test_dl))) + outs = trainer.predict(model, test_dl) + + if not outs: + return {} + + first = outs[0] + + # Multi-target case: each element of outs is a dict[target_label -> tensor] + if isinstance(first, dict): + per_target_lists: dict[str, list[torch.Tensor]] = {} + for out in outs: + if not isinstance(out, dict): + raise RuntimeError("Mixed prediction output types from model") + for k, v in out.items(): + per_target_lists.setdefault(k, []).append(v) + + per_target_tensors: dict[str, torch.Tensor] = { + k: torch.cat(vlist, dim=0) for k, vlist in per_target_lists.items() + } + + if getattr(model.hparams, "task", None) == "classification": + for k in list(per_target_tensors.keys()): + per_target_tensors[k] = torch.softmax(per_target_tensors[k], dim=1) + + # build per-patient dicts + num_preds = next(iter(per_target_tensors.values())).shape[0] + predictions: dict[PatientId, dict[str, torch.Tensor]] = {} + for i, pid in enumerate(patient_ids[:num_preds]): + predictions[pid] = { + k: per_target_tensors[k][i] for k in per_target_tensors.keys() + } + + return predictions + + # Single-target case: each element of outs is a tensor + outs_single = cast(list[torch.Tensor], outs) + + raw_preds = torch.cat(outs_single, dim=0) if getattr(model.hparams, "task", None) == "classification": - predictions = torch.softmax(raw_preds, dim=1) + raw_preds = torch.softmax(raw_preds, dim=1) elif getattr(model.hparams, "task", None) == "survival": - predictions = raw_preds.squeeze(-1) # (N,) risk scores - else: # regression - predictions = raw_preds + raw_preds = raw_preds.squeeze(-1) + + result: dict[PatientId, torch.Tensor] = { + pid: raw_preds[i] for i, pid in enumerate(patient_ids) + } - return dict(zip(patient_ids, predictions, strict=True)) + return result def _to_prediction_df( *, - categories: Sequence[GroundTruth], - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], - predictions: Mapping[PatientId, torch.Tensor], + categories: Sequence[GroundTruth] | Mapping[str, Sequence[GroundTruth]], + patient_to_ground_truth: Mapping[PatientId, GroundTruth | None] + | Mapping[PatientId, dict[str, GroundTruth | None]], + predictions: Mapping[PatientId, torch.Tensor] + | Mapping[PatientId, dict[str, torch.Tensor]], patient_label: PandasLabel, - ground_truth_label: PandasLabel, + ground_truth_label: PandasLabel | Sequence[PandasLabel], **kwargs, ) -> pd.DataFrame: - """Compiles deployment results into a DataFrame.""" - return pd.DataFrame( - [ - { - patient_label: patient_id, - ground_truth_label: patient_to_ground_truth.get(patient_id), - "pred": categories[int(prediction.argmax())], - **{ - f"{ground_truth_label}_{category}": prediction[i_cat].item() - for i_cat, category in enumerate(categories) - }, - "loss": ( - torch.nn.functional.cross_entropy( - prediction.reshape(1, -1), - torch.tensor(np.where(np.array(categories) == ground_truth)[0]), - ).item() - if (ground_truth := patient_to_ground_truth.get(patient_id)) - is not None - else None - ), - } - for patient_id, prediction in predictions.items() - ] - ).sort_values(by="loss") + """Compiles deployment results into a DataFrame. + + Supports single-target and multi-target classification. + - Single-target: `predictions` maps patient -> tensor and `categories` is a sequence. + - Multi-target: `predictions` maps patient -> dict[target_label -> tensor] and + `categories` is a mapping from target_label -> sequence of category names. + """ + first_pred = next(iter(predictions.values())) + + # Multi-target predictions: dict per patient + if isinstance(first_pred, dict): + # determine target labels + target_labels = list(cast(dict, first_pred).keys()) + + # prepare categories mapping + if isinstance(categories, dict): + cats_map = categories + else: + # try infer categories list ordering: assume categories is a sequence-of-sequences + cats_map = {} + if isinstance(categories, Sequence): + try: + for i, t in enumerate(target_labels): + cats_map[t] = list( + cast(Sequence[Sequence[GroundTruth]], categories)[i] + ) + except Exception: + cats_map = {} + + # infer missing category lists from ground truth + if any(t not in cats_map for t in target_labels): + inferred: dict[str, set] = {t: set() for t in target_labels} + for pid, gt in patient_to_ground_truth.items(): + if isinstance(gt, dict): + for t in target_labels: + val = gt.get(t) + if val is not None: + inferred[t].add(val) + for t in target_labels: + if t not in cats_map: + cats_map[t] = sorted(inferred.get(t, [])) + + rows = [] + for pid, pred_dict in predictions.items(): + row: dict = {patient_label: pid} + gt_entry = patient_to_ground_truth.get(pid) + # ground truths per target + for t in target_labels: + if isinstance(gt_entry, dict): + row[t] = gt_entry.get(t) + else: + row[t] = gt_entry + + total_loss = 0.0 + has_loss = False + for t in target_labels: + tensor = cast(dict[str, torch.Tensor], pred_dict)[t] + probs = tensor.detach().cpu() + cats: Sequence[GroundTruth] = cast( + Sequence[GroundTruth], + cats_map.get(t, []), + ) + if probs.numel() == 1: + row[f"pred_{t}"] = float(probs.item()) + else: + pred_idx = int(probs.argmax().item()) + row[f"pred_{t}"] = ( + cats[pred_idx] if pred_idx < len(cats) else pred_idx + ) + for i_cat, cat in enumerate(cats): + if i_cat < probs.shape[0]: + row[f"{t}_{cat}"] = float(probs[i_cat].item()) + else: + row[f"{t}_{cat}"] = None + + if isinstance(gt_entry, dict) and (gt := gt_entry.get(t)) is not None: + try: + target_index = int(np.where(np.array(cats) == gt)[0][0]) + loss = torch.nn.functional.cross_entropy( + probs.reshape(1, -1), torch.tensor([target_index]) + ).item() + total_loss += loss + has_loss = True + except Exception: + pass + + row["loss"] = total_loss if has_loss else None + rows.append(row) + + return pd.DataFrame(rows) + + # Single-target (original behaviour) + if not all(isinstance(p, torch.Tensor) for p in predictions.values()): + raise TypeError("Single-target block received multi-target dict predictions.") + + predictions = cast(Mapping[PatientId, torch.Tensor], predictions) + + rows = [] + for pid, prediction in predictions.items(): + gt = patient_to_ground_truth.get(pid) + cats = cast(Sequence[GroundTruth], categories) + pred_idx = int(prediction.argmax()) + row = { + patient_label: pid, + ground_truth_label: gt, + "pred": cats[pred_idx], + **{ + f"{ground_truth_label}_{category}": float(prediction[i_cat].item()) + for i_cat, category in enumerate(cats) + }, + "loss": ( + torch.nn.functional.cross_entropy( + prediction.reshape(1, -1), + torch.tensor(np.where(np.array(cats) == gt)[0]), + ).item() + if gt is not None + else None + ), + } + rows.append(row) + + return pd.DataFrame(rows).sort_values(by="loss") def _to_regression_prediction_df( @@ -383,8 +588,6 @@ def _to_regression_prediction_df( - pred (float) - loss (per-sample L1 loss if GT available, else None) """ - import torch.nn.functional as F - return pd.DataFrame( [ { diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index 59a0a3aa..86003c52 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -3,7 +3,7 @@ import inspect from abc import ABC from collections.abc import Iterable, Sequence -from typing import Any, TypeAlias +from typing import Any, Mapping, TypeAlias import lightning import numpy as np @@ -14,6 +14,11 @@ from torchmetrics.classification import MulticlassAUROC import stamp +from stamp.modeling.models.barspoon import ( + EncDecTransformer, + LitMilClassificationMixin, + TargetLabel, +) from stamp.modeling.models.cox import neg_partial_log_likelihood from stamp.types import ( Bags, @@ -818,3 +823,75 @@ class LitPatientSurvival(LitSlideSurvival): """ supported_features = ["patient"] + + +class LitEncDecTransformer(LitMilClassificationMixin): + def __init__( + self, + *, + dim_input: int, + category_weights: Mapping[TargetLabel, torch.Tensor], + model_class: type[nn.Module] | None = None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + categories: Mapping[str, Sequence[Category]], + # Model parameters + d_model: int = 512, + num_encoder_heads: int = 8, + num_decoder_heads: int = 8, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + dim_feedforward: int = 2048, + positional_encoding: bool = True, + # Other hparams + learning_rate: float = 1e-4, + **hparams: Any, + ) -> None: + weights_dict: dict[TargetLabel, torch.Tensor] = dict(category_weights) + super().__init__( + weights=weights_dict, + learning_rate=learning_rate, + ) + _ = hparams # so we don't get unused parameter warnings + + self.model = EncDecTransformer( + d_features=dim_input, + target_n_outs={t: len(w) for t, w in category_weights.items()}, + d_model=d_model, + num_encoder_heads=num_encoder_heads, + num_decoder_heads=num_decoder_heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dim_feedforward=dim_feedforward, + positional_encoding=positional_encoding, + ) + + self.hparams["supported_features"] = "tile" + self.hparams.update({"task": "classification"}) + # ---- Normalize categories into strict mapping[str, list[str]] ---- + if not isinstance(categories, Mapping): + raise ValueError( + "Multi-target classification requires categories as Mapping[str, Sequence[str]]." + ) + + normalized_categories: dict[str, list[str]] = { + str(k): list(v) for k, v in categories.items() + } + + # Sanity check: head size must match category size + for t, w in category_weights.items(): + if t not in normalized_categories: + raise ValueError(f"Missing categories for target '{t}'") + if len(normalized_categories[t]) != len(w): + raise ValueError( + f"Category mismatch for target '{t}': " + f"{len(normalized_categories[t])} categories " + f"but head has {len(w)} outputs." + ) + + self.ground_truth_label = ground_truth_label + self.categories = normalized_categories + + self.save_hyperparameters() + + def forward(self, *args): + return self.model(*args) diff --git a/src/stamp/modeling/models/barspoon.py b/src/stamp/modeling/models/barspoon.py new file mode 100644 index 00000000..f841bb3d --- /dev/null +++ b/src/stamp/modeling/models/barspoon.py @@ -0,0 +1,367 @@ +""" +Port from https://github.com/KatherLab/barspoon-transformer +""" + +import re +from typing import Any, TypeAlias + +import lightning +import torch +import torch.nn.functional as F +import torchmetrics +from packaging.version import Version +from torch import nn +from torchmetrics.classification import MulticlassAUROC +from torchmetrics.utilities.data import dim_zero_cat + +import stamp +from stamp.types import Bags, BagSizes, CoordinatesBatch + +__all__ = [ + "EncDecTransformer", + "LitMilClassificationMixin", + "SafeMulticlassAUROC", +] + + +TargetLabel: TypeAlias = str + + +class EncDecTransformer(nn.Module): + """An encoder decoder architecture for multilabel classification tasks + + This architecture is a modified version of the one found in [Attention Is + All You Need][1]: First, we project the features into a lower-dimensional + feature space, to prevent the transformer architecture's complexity from + exploding for high-dimensional features. We add sinusodial [positional + encodings][1]. We then encode these projected input tokens using a + transformer encoder stack. Next, we decode these tokens using a set of + class tokens, one per output label. Finally, we forward each of the decoded + tokens through a fully connected layer to get a label-wise prediction. + + PE1 + | + +--+ v +---+ + t1 ->|FC|--+-->| |--+ + . +--+ | E | | + . | x | | + . +--+ | m | | + tn ->|FC|--+-->| |--+ + +--+ ^ +---+ | + | | + PEn v + +---+ +---+ + c1 ---------------->| |-->|FC1|--> s1 + . | D | +---+ . + . | x | . + . | l | +---+ . + ck ---------------->| |-->|FCk|--> sk + +---+ +---+ + + We opted for this architecture instead of a more traditional [Vision + Transformer][2] to improve performance for multi-label predictions with many + labels. Our experiments have shown that adding too many class tokens to a + vision transformer decreases its performance, as the same weights have to + both process the tiles' information and the class token's processing. Using + an encoder-decoder architecture alleviates these issues, as the data-flow of + the class tokens is completely independent of the encoding of the tiles. + Furthermore, analysis has shown that there is almost no interaction between + the different classes in the decoder. While this points to the decoder + being more powerful than needed in practice, this also means that each + label's prediction is mostly independent of the others. As a consequence, + noisy labels will not negatively impact the accuracy of non-noisy ones. + + In our experiments so far we did not see any improvement by adding + positional encodings. We tried + + 1. [Sinusodal encodings][1] + 2. Adding absolute positions to the feature vector, scaled down so the + maximum value in the training dataset is 1. + + Since neither reduced performance and the author percieves the first one to + be more elegant (as the magnitude of the positional encodings is bounded), + we opted to keep the positional encoding regardless in the hopes of it + improving performance on future tasks. + + The architecture _differs_ from the one descibed in [Attention Is All You + Need][1] as follows: + + 1. There is an initial projection stage to reduce the dimension of the + feature vectors and allow us to use the transformer with arbitrary + features. + 2. Instead of the language translation task described in [Attention Is All + You Need][1], where the tokens of the words translated so far are used + to predict the next word in the sequence, we use a set of fixed, learned + class tokens in conjunction with equally as many independent fully + connected layers to predict multiple labels at once. + + [1]: https://arxiv.org/abs/1706.03762 "Attention Is All You Need" + [2]: https://arxiv.org/abs/2010.11929 + "An Image is Worth 16x16 Words: + Transformers for Image Recognition at Scale" + """ + + def __init__( + self, + d_features: int, + target_n_outs: dict[str, int], + *, + d_model: int = 512, + num_encoder_heads: int = 8, + num_decoder_heads: int = 8, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + dim_feedforward: int = 2048, + positional_encoding: bool = True, + ) -> None: + super().__init__() + + self.projector = nn.Sequential(nn.Linear(d_features, d_model), nn.ReLU()) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=num_encoder_heads, + dim_feedforward=dim_feedforward, + batch_first=True, + norm_first=True, + ) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layer, num_layers=num_encoder_layers, enable_nested_tensor=False + ) + + self.target_labels = target_n_outs.keys() + + # One class token per output label + self.class_tokens = nn.ParameterDict( + { + sanitize(target_label): torch.rand(d_model) + for target_label in target_n_outs + } + ) + + decoder_layer = nn.TransformerDecoderLayer( + d_model=d_model, + nhead=num_decoder_heads, + dim_feedforward=dim_feedforward, + batch_first=True, + norm_first=True, + ) + self.transformer_decoder = nn.TransformerDecoder( + decoder_layer, num_layers=num_decoder_layers + ) + + self.heads = nn.ModuleDict( + { + sanitize(target_label): nn.Linear( + in_features=d_model, out_features=n_out + ) + for target_label, n_out in target_n_outs.items() + } + ) + + self.positional_encoding = positional_encoding + + def forward( + self, + tile_tokens: torch.Tensor, + tile_positions: torch.Tensor, + ) -> dict[str, torch.Tensor]: + batch_size, _, _ = tile_tokens.shape + + tile_tokens = self.projector(tile_tokens) # shape: [bs, seq_len, d_model] + + if self.positional_encoding: + # Add positional encodings + d_model = tile_tokens.size(-1) + x = tile_positions.unsqueeze(-1) / 100_000 ** ( + torch.arange(d_model // 4).type_as(tile_positions) / d_model + ) + positional_encodings = torch.cat( + [ + torch.sin(x).flatten(start_dim=-2), + torch.cos(x).flatten(start_dim=-2), + ], + dim=-1, + ) + tile_tokens = tile_tokens + positional_encodings + + tile_tokens = self.transformer_encoder(tile_tokens) + + class_tokens = torch.stack( + [self.class_tokens[sanitize(t)] for t in self.target_labels] + ).expand(batch_size, -1, -1) + class_tokens = self.transformer_decoder(tgt=class_tokens, memory=tile_tokens) + + # Apply the corresponding head to each class token + logits = { + target_label: self.heads[sanitize(target_label)](class_token) + for target_label, class_token in zip( + self.target_labels, + class_tokens.permute(1, 0, 2), # Permute to [target, batch, d_model] + strict=True, + ) + } + + return logits + + +class LitMilClassificationMixin(lightning.LightningModule): + """Makes a module into a multilabel, multiclass Lightning one""" + + supported_features = ["tile"] + + def __init__( + self, + *, + weights: dict[TargetLabel, torch.Tensor], + # Other hparams + learning_rate: float = 1e-4, + stamp_version: Version = Version(stamp.__version__), + **hparams: Any, + ) -> None: + super().__init__() + _ = hparams # So we don't get unused parameter warnings + + # Check if version is compatible. + if stamp_version < Version("2.4.0"): + # Update this as we change our model in incompatible ways! + raise ValueError( + f"model has been built with stamp version {stamp_version} " + f"which is incompatible with the current version." + ) + elif stamp_version > Version(stamp.__version__): + # Let's be strict with models "from the future", + # better fail deadly than have broken results. + raise ValueError( + "model has been built with a stamp version newer than the installed one " + f"({stamp_version} > {stamp.__version__}). " + "Please upgrade stamp to a compatible version." + ) + + self.hparams.update({"task": "classification"}) + + self.learning_rate = learning_rate + + target_aurocs = torchmetrics.MetricCollection( + { + sanitize(target_label): SafeMulticlassAUROC(num_classes=len(weight)) + for target_label, weight in weights.items() + } + ) + for step_name in ["train", "validation", "test"]: + setattr( + self, + f"{step_name}_target_aurocs", + target_aurocs.clone(prefix=f"{step_name}_"), + ) + + self.weights = weights + + self.save_hyperparameters() + + def step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, dict[str, torch.Tensor]], + step_name=None, + ): + """Process a batch with structure (feats, coords, bag_sizes, targets). + + Args: + batch: Tuple of (feats, coords, bag_sizes, targets) where: + - feats: bag features [batch, bag_size, feature_dim] + - coords: tile coordinates [batch, bag_size, 2] + - bag_sizes: number of tiles per bag [batch] + - targets: dict mapping target names to one-hot encoded tensors [batch, num_classes] + step_name: Optional step name for logging ('train', 'validation', 'test'). + """ + feats: Bags + coords: CoordinatesBatch + bag_sizes: BagSizes + targets: dict[str, torch.Tensor] + feats, coords, bag_sizes, targets = batch + logits = self(feats, coords) + + # Calculate the cross entropy loss for each target, then sum them + loss = sum( + F.cross_entropy( + (logit := logits[target_label]), + targets[target_label].type_as(logit), + weight=weight.type_as(logit), + ) + for target_label, weight in self.weights.items() + ) + + if step_name: + self.log( + f"{step_name}_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + # Update target-wise metrics + for target_label in self.weights: + target_auroc = getattr(self, f"{step_name}_target_aurocs")[ + sanitize(target_label) + ] + is_na = (targets[target_label] == 0).all(dim=1) + target_auroc.update( + logits[target_label][~is_na], + targets[target_label][~is_na].argmax(dim=1), + ) + self.log( + f"{step_name}_{target_label}_auroc", + target_auroc, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + + return loss + + def training_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="train") + + def validation_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="validation") + + def test_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="test") + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + if len(batch) == 2: + feats, positions = batch + else: + feats, positions, _, _ = batch + + logits = self(feats, positions) + + softmaxed = { + target_label: torch.softmax(x, 1) for target_label, x in logits.items() + } + return softmaxed + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + +def sanitize(x: str) -> str: + return re.sub(r"[^A-Za-z0-9_]", "_", x) + + +class SafeMulticlassAUROC(MulticlassAUROC): + """A Multiclass AUROC that doesn't blow up when no targets are given""" + + def compute(self) -> torch.Tensor: + # Add faux entry if there are none so far + if len(self.preds) == 0: + self.update(torch.zeros(1, self.num_classes), torch.zeros(1).long()) + elif len(dim_zero_cat(self.preds)) == 0: + self.update( + torch.zeros(1, self.num_classes).type_as(self.preds[0]), + torch.zeros(1).long().type_as(self.target[0]), + ) + return super().compute() diff --git a/src/stamp/modeling/registry.py b/src/stamp/modeling/registry.py index 2205af22..14011084 100644 --- a/src/stamp/modeling/registry.py +++ b/src/stamp/modeling/registry.py @@ -1,6 +1,7 @@ from enum import StrEnum from stamp.modeling.models import ( + LitEncDecTransformer, LitPatientClassifier, LitPatientRegressor, LitPatientSurvival, @@ -21,6 +22,7 @@ class ModelName(StrEnum): MLP = "mlp" TRANS_MIL = "trans_mil" LINEAR = "linear" + BARSPOON = "barspoon" # Map (feature_type, task) → correct Lightning wrapper class @@ -34,6 +36,7 @@ class ModelName(StrEnum): ("patient", "classification"): LitPatientClassifier, ("patient", "regression"): LitPatientRegressor, ("patient", "survival"): LitPatientSurvival, + # ("tile", "multiclass"): LitEncDecTransformer, } @@ -54,6 +57,13 @@ def load_model_class(task: Task, feature_type: str, model_name: ModelName): case ModelName.MLP: from stamp.modeling.models.mlp import MLP as ModelClass + case ModelName.BARSPOON: + from stamp.modeling.models.barspoon import ( + EncDecTransformer as ModelClass, + ) + + LitModelClass = LitEncDecTransformer + case ModelName.LINEAR: from stamp.modeling.models.mlp import ( Linear as ModelClass, diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index b855e2a0..fb2bdb3b 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -2,7 +2,7 @@ import shutil from collections.abc import Callable, Mapping, Sequence from pathlib import Path -from typing import cast +from typing import Any, cast import lightning import torch @@ -18,13 +18,7 @@ PatientData, PatientFeatureDataset, create_dataloader, - detect_feature_type, - filter_complete_patient_data_, - load_patient_level_data, - log_patient_class_summary, - patient_to_ground_truth_from_clini_table_, - patient_to_survival_from_clini_table_, - slide_to_patient_from_slide_table_, + load_patient_data_, ) from stamp.modeling.registry import ModelName, load_model_class from stamp.modeling.transforms import VaryPrecisionTransform @@ -53,66 +47,25 @@ def train_categorical_model_( advanced: AdvancedConfig, ) -> None: """Trains a model based on the feature type.""" - feature_type = detect_feature_type(config.feature_dir) - _logger.info(f"Detected feature type: {feature_type}") - - if feature_type in ("tile", "slide"): - if config.slide_table is None: - raise ValueError("A slide table is required for modeling") - if config.task == "survival": - if config.time_label is None or config.status_label is None: - raise ValueError( - "Both time_label and status_label is required for survival modeling" - ) - patient_to_ground_truth = patient_to_survival_from_clini_table_( - clini_table_path=config.clini_table, - time_label=config.time_label, - status_label=config.status_label, - patient_label=config.patient_label, - ) - else: - if config.ground_truth_label is None: - raise ValueError( - "Ground truth label is required for tile-level modeling" - ) - patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( - clini_table_path=config.clini_table, - ground_truth_label=config.ground_truth_label, - patient_label=config.patient_label, - ) - slide_to_patient = slide_to_patient_from_slide_table_( - slide_table_path=config.slide_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - filename_label=config.filename_label, - ) - patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, - ) - elif feature_type == "patient": - # Patient-level: ignore slide_table - if config.slide_table is not None: - _logger.warning("slide_table is ignored for patient-level features.") - - patient_to_data = load_patient_level_data( - task=config.task, - clini_table=config.clini_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - ) - else: - raise RuntimeError(f"Unknown feature type: {feature_type}") - if config.task is None: raise ValueError( "task must be set to 'classification' | 'regression' | 'survival'" ) + patient_to_data, feature_type = load_patient_data_( + feature_dir=config.feature_dir, + clini_table=config.clini_table, + slide_table=config.slide_table, + task=config.task, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + patient_label=config.patient_label, + filename_label=config.filename_label, + drop_patients_with_missing_ground_truth=True, + ) + _logger.info(f"Detected feature type: {feature_type}") + # Train the model (the rest of the logic is unchanged) model, train_dl, valid_dl = setup_model_for_training( patient_to_data=patient_to_data, @@ -145,14 +98,14 @@ def train_categorical_model_( def setup_model_for_training( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], task: Task, categories: Sequence[Category] | None, train_transform: Callable[[torch.Tensor], torch.Tensor] | None, feature_type: str, advanced: AdvancedConfig, # Metadata, has no effect on model training - ground_truth_label: PandasLabel | None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, time_label: PandasLabel | None, status_label: PandasLabel | None, clini_table: Path, @@ -193,10 +146,6 @@ def setup_model_for_training( feature_type=feature_type, train_categories=train_categories, ) - log_patient_class_summary( - patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, - categories=categories, - ) # 1. Default to a model if none is specified if advanced.model_name is None: @@ -209,6 +158,7 @@ def setup_model_for_training( LitModelClass, ModelClass = load_model_class( task, feature_type, advanced.model_name ) + print(f"Using Lightning wrapper class: {LitModelClass}") # 3. Validate that the chosen model supports the feature type if feature_type not in LitModelClass.supported_features: @@ -272,7 +222,7 @@ def setup_model_for_training( def setup_dataloaders_for_training( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], task: Task, categories: Sequence[Category] | None, bag_size: int, @@ -283,7 +233,7 @@ def setup_dataloaders_for_training( ) -> tuple[ DataLoader, DataLoader, - Sequence[Category], + Sequence[Category] | Mapping[str, Sequence[Category]], int, Sequence[PatientId], Sequence[PatientId], @@ -310,10 +260,25 @@ def setup_dataloaders_for_training( ) if task == "classification": - stratify = ground_truths + # Handle both single and multi-target cases + if ground_truths and isinstance(ground_truths[0], dict): + # Multi-target: use first target for stratification + first_key = list(ground_truths[0].keys())[0] + stratify = [cast(dict, gt)[first_key] for gt in ground_truths] + else: + stratify = ground_truths elif task == "survival": - # Extract event indicator (status) - statuses = [int(gt.split()[1]) for gt in ground_truths] + # Extract event indicator (status) - handle both single and multi-target + statuses = [] + for gt in ground_truths: + if isinstance(gt, dict): + # Multi-target survival: extract from first target + first_key = list(gt.keys())[0] + val = cast(dict, gt)[first_key] + if val: + statuses.append(int(val.split()[1])) + else: + statuses.append(int(gt.split()[1])) stratify = statuses elif task == "regression": stratify = None @@ -321,7 +286,10 @@ def setup_dataloaders_for_training( train_patients, valid_patients = cast( tuple[Sequence[PatientId], Sequence[PatientId]], train_test_split( - list(patient_to_data), stratify=stratify, shuffle=True, random_state=0 + list(patient_to_data), + stratify=cast(Any, stratify), + shuffle=True, + random_state=0, ), ) @@ -441,28 +409,50 @@ def _compute_class_weights_and_check_categories( *, train_dl: DataLoader, feature_type: str, - train_categories: Sequence[str], -) -> torch.Tensor: + train_categories: Sequence[str] | Mapping[str, Sequence[str]], +) -> torch.Tensor | dict[str, torch.Tensor]: """ Computes class weights and checks for category issues. Logs warnings if there are too few or underpopulated categories. Returns normalized category weights as a torch.Tensor. """ if feature_type == "tile": - category_counts = cast(BagDataset, train_dl.dataset).ground_truths.sum(dim=0) + dataset = cast(BagDataset, train_dl.dataset) + + if isinstance(dataset.ground_truths, list): + # Multi-target case: compute weights per target head + weights_per_target: dict[str, torch.Tensor] = {} + + target_keys = dataset.ground_truths[0].keys() + + for key in target_keys: + stacked = torch.stack([gt[key] for gt in dataset.ground_truths], dim=0) + counts = stacked.sum(dim=0) + w = counts.sum() / counts + weights_per_target[key] = w / w.sum() + + return weights_per_target + else: + category_counts = dataset.ground_truths.sum(dim=0) else: - category_counts = cast( - PatientFeatureDataset, train_dl.dataset - ).ground_truths.sum(dim=0) + dataset = cast(PatientFeatureDataset, train_dl.dataset) + category_counts = dataset.ground_truths.sum(dim=0) cat_ratio_reciprocal = category_counts.sum() / category_counts category_weights = cat_ratio_reciprocal / cat_ratio_reciprocal.sum() if len(train_categories) <= 1: raise ValueError(f"not enough categories to train on: {train_categories}") - elif any(category_counts < 16): + elif (category_counts < 16).any(): + category_counts_list = ( + category_counts.tolist() + if category_counts.dim() > 0 + else [category_counts.item()] + ) underpopulated_categories = { category: int(count) - for category, count in zip(train_categories, category_counts, strict=True) + for category, count in zip( + train_categories, category_counts_list, strict=True + ) if count < 16 } _logger.warning( diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index a1844526..84cb48e9 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -16,7 +16,6 @@ from tqdm import tqdm import stamp -from stamp.cache import get_processing_code_hash from stamp.preprocessing.config import ExtractorName from stamp.preprocessing.extractor import Extractor from stamp.preprocessing.tiling import ( @@ -32,6 +31,7 @@ SlidePixels, TilePixels, ) +from stamp.utils.cache import get_processing_code_hash __author__ = "Marko van Treeck" __copyright__ = "Copyright (C) 2022-2024 Marko van Treeck" diff --git a/src/stamp/preprocessing/extractor/chief_ctranspath.py b/src/stamp/preprocessing/extractor/chief_ctranspath.py index 2d2e6b9b..03f5bba6 100644 --- a/src/stamp/preprocessing/extractor/chief_ctranspath.py +++ b/src/stamp/preprocessing/extractor/chief_ctranspath.py @@ -1,6 +1,6 @@ from pathlib import Path -from stamp.cache import STAMP_CACHE_DIR, file_digest +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest try: import gdown diff --git a/src/stamp/preprocessing/extractor/ctranspath.py b/src/stamp/preprocessing/extractor/ctranspath.py index ba9a277c..387e7947 100644 --- a/src/stamp/preprocessing/extractor/ctranspath.py +++ b/src/stamp/preprocessing/extractor/ctranspath.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Optional, TypeVar, cast -from stamp.cache import STAMP_CACHE_DIR, file_digest +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest try: import gdown diff --git a/src/stamp/preprocessing/extractor/dinobloom.py b/src/stamp/preprocessing/extractor/dinobloom.py index fb7713ce..54b79c53 100644 --- a/src/stamp/preprocessing/extractor/dinobloom.py +++ b/src/stamp/preprocessing/extractor/dinobloom.py @@ -8,9 +8,9 @@ from torch import nn from torchvision import transforms -from stamp.cache import STAMP_CACHE_DIR from stamp.preprocessing.config import ExtractorName from stamp.preprocessing.extractor import Extractor +from stamp.utils.cache import STAMP_CACHE_DIR __author__ = "Marko van Treeck" __copyright__ = "Copyright (C) 2022-2025 Marko van Treeck" diff --git a/src/stamp/preprocessing/tiling.py b/src/stamp/preprocessing/tiling.py index 82a3efba..e09a06fa 100644 --- a/src/stamp/preprocessing/tiling.py +++ b/src/stamp/preprocessing/tiling.py @@ -461,7 +461,7 @@ def _extract_mpp_from_metadata(slide: openslide.AbstractSlide) -> SlideMPP | Non return None doc = minidom.parseString(xml_path) collection = doc.documentElement - images = collection.getElementsByTagName("Image") + images = collection.getElementsByTagName("Image") # pyright: ignore[reportOptionalMemberAccess] pixels = images[0].getElementsByTagName("Pixels") mpp = float(pixels[0].getAttribute("PhysicalSizeX")) except Exception: diff --git a/src/stamp/statistics/__init__.py b/src/stamp/statistics/__init__.py index ec09e1e0..bdbef1fa 100644 --- a/src/stamp/statistics/__init__.py +++ b/src/stamp/statistics/__init__.py @@ -1,3 +1,11 @@ +"""Statistics utilities (wrappers) for classification, regression and survival. + +This module provides a small, stable wrapper `compute_stats_` that dispatches +to the task-specific statistic implementations found in the submodules. +""" + +from __future__ import annotations + from collections.abc import Sequence from pathlib import Path from typing import NewType @@ -17,23 +25,25 @@ plot_multiple_decorated_roc_curves, plot_single_decorated_roc_curve, ) -from stamp.statistics.survival import ( - _plot_km, - _survival_stats_for_csv, -) +from stamp.statistics.survival import _plot_km, _survival_stats_for_csv from stamp.types import PandasLabel, Task +__all__ = ["StatsConfig", "compute_stats_"] + + __author__ = "Marko van Treeck, Minh Duc Nguyen" __copyright__ = "Copyright (C) 2022-2024 Marko van Treeck, Minh Duc Nguyen" __license__ = "MIT" def _read_table(file: Path, **kwargs) -> pd.DataFrame: - """Loads a dataframe from a file.""" + """Load a dataframe from CSV or XLSX file path. + + This small helper centralizes file IO formatting and keeps callers simple. + """ if isinstance(file, Path) and file.suffix == ".xlsx": return pd.read_excel(file, **kwargs) - else: - return pd.read_csv(file, **kwargs) + return pd.read_csv(file, **kwargs) class StatsConfig(BaseModel): @@ -60,6 +70,11 @@ def compute_stats_( time_label: str | None = None, status_label: str | None = None, ) -> None: + """Compute and save statistics for the provided task and prediction CSVs. + + This wrapper keeps the external API stable while delegating the detailed + computations and plotting to the submodules under `stamp.statistics.*`. + """ match task: case "classification": if true_class is None or ground_truth_label is None: @@ -105,7 +120,6 @@ def compute_stats_( n_bootstrap_samples=n_bootstrap_samples, threshold_cmap=threshold_cmap, ) - else: plot_multiple_decorated_roc_curves( ax=ax, @@ -116,9 +130,7 @@ def compute_stats_( ) fig.tight_layout() - if not output_dir.exists(): - output_dir.mkdir(parents=True, exist_ok=True) - + output_dir.mkdir(parents=True, exist_ok=True) fig.savefig(output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg") plt.close(fig) @@ -134,7 +146,6 @@ def compute_stats_( title=f"{ground_truth_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, ) - else: plot_multiple_decorated_precision_recall_curves( ax=ax, @@ -203,12 +214,7 @@ def compute_stats_( cut_off=cut_off, ) - # ------------------------------------------------------------------ # # Save individual and aggregated CSVs - # ------------------------------------------------------------------ # stats_df = pd.DataFrame(per_fold).transpose() stats_df.index.name = "fold_name" # label the index column stats_df.to_csv(output_dir / "survival-stats_individual.csv", index=True) - - # agg_df = _aggregate_with_ci(stats_df) - # agg_df.to_csv(output_dir / "survival-stats_aggregated.csv", index=True) diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index 063793cf..7c298a54 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -46,7 +46,7 @@ def _survival_stats_for_csv( if risk_label is None: risk_label = "pred_score" - # --- Clean NaNs and invalid events before computing stats --- + # Clean NaNs and invalid events before computing stats df = df.dropna(subset=[time_label, status_label, risk_label]).copy() df = df[df[status_label].isin([0, 1])] if len(df) == 0: @@ -56,10 +56,10 @@ def _survival_stats_for_csv( event = np.asarray(df[status_label], dtype=int) risk = np.asarray(df[risk_label], dtype=float) - # --- Concordance index --- + # Concordance index c_index, n_pairs = _cindex(time, event, risk) - # --- Log-rank test (median split) --- + # Log-rank test (median split) median_risk = float(cut_off) if cut_off is not None else float(np.nanmedian(risk)) low_mask = risk <= median_risk high_mask = risk > median_risk @@ -101,7 +101,7 @@ def _plot_km( if risk_label is None: risk_label = "pred_score" - # --- Clean NaNs and invalid entries --- + # Clean NaNs and invalid entries df = df.replace(["NaN", "nan", "None", "Inf", "inf"], np.nan) df = df.dropna(subset=[time_label, status_label, risk_label]).copy() df = df[df[status_label].isin([0, 1])] @@ -113,7 +113,7 @@ def _plot_km( event = np.asarray(df[status_label], dtype=int) risk = np.asarray(df[risk_label], dtype=float) - # --- split groups --- + # split groups median_risk = float(cut_off) if cut_off is not None else np.nanmedian(risk) low_mask = risk <= median_risk high_mask = risk > median_risk @@ -138,7 +138,7 @@ def _plot_km( add_at_risk_counts(kmf_low, kmf_high, ax=ax) - # --- log-rank and c-index --- + # log-rank and c-index res = logrank_test( low_df[time_label], high_df[time_label], diff --git a/src/stamp/types.py b/src/stamp/types.py index f1f571cc..c1ff6873 100644 --- a/src/stamp/types.py +++ b/src/stamp/types.py @@ -37,6 +37,7 @@ PatientId: TypeAlias = str GroundTruth: TypeAlias = str +MultiClassGroundTruth: TypeAlias = tuple[str, ...] FeaturePath = NewType("FeaturePath", Path) Category: TypeAlias = str diff --git a/src/stamp/cache.py b/src/stamp/utils/cache.py similarity index 100% rename from src/stamp/cache.py rename to src/stamp/utils/cache.py diff --git a/src/stamp/config.py b/src/stamp/utils/config.py similarity index 100% rename from src/stamp/config.py rename to src/stamp/utils/config.py diff --git a/src/stamp/seed.py b/src/stamp/utils/seed.py similarity index 100% rename from src/stamp/seed.py rename to src/stamp/utils/seed.py diff --git a/src/stamp/utils/target_file.py b/src/stamp/utils/target_file.py new file mode 100644 index 00000000..08c30b6b --- /dev/null +++ b/src/stamp/utils/target_file.py @@ -0,0 +1,351 @@ +"""Automatically generate target information from clini table + +# The `barspoon-targets 2.0` File Format + +A barspoon target file is a [TOML][1] file with the following entries: + + - A `version` key mapping to a version string `"barspoon-targets "`, where + `` is a [PEP-440 version string][2] compatible with `2.0`. + - A `targets` table, the keys of which are target labels (as found in the + clinical table) and the values specify exactly one of the following: + 1. A categorical target label, marked by the presence of a `categories` + key-value pair. + 2. A target label to quantize, marked by the presence of a `thresholds` + key-value pair. + 3. A target format defined in in a later version of barspoon targets. + A target may only ever have one of the fields `categories` or `thresholds`. + A definition of these entries can be found below. + +[1]: https://toml.io "Tom's Obvious Minimal Language" +[2]: https://peps.python.org/pep-0440/ + "PEP 440 - Version Identification and Dependency Specification" + +## Categorical Target Label + +A categorical target is a target table with a key-value pair `categories`. +`categories` contains a list of lists of literal strings. Each list of strings +will be treated as one category, with all literal strings within that list being +treated as one representative for that category. This allows the user to easily +group related classes into one large class (i.e. `"True", "1", "Yes"` could all +be unified into the same category). + +### Category Weights + +It is possible to assign a weight to each category, to e.g. weigh rarer classes +more heavily. The weights are stored in a table `targets.LABEL.class_weights`, +whose keys is the first representative of each category, and the values of which +is the weight of the category as a floating point number. + +## Target Label to Quantize + +If a target has the `thresholds` option key set, it is interpreted as a +continuous target which has to be quantized. `thresholds` has to be a list of +floating point numbers [t_0, t_n], n > 1 containing the thresholds of the bins +to quantize the values into. A categorical target will be quantized into bins + +```asciimath +b_0 = [t_0; t_1], b_1 = (t_1; b_2], ... b_(n-1) = (t_(n-1); t_n] +``` + +The bins will be treated as categories with names +`f"[{t_0:+1.2e};{t_1:+1.2e}]"` for the first bin and +`f"({t_i:+1.2e};{t_(i+1):+1.2e}]"` for all other bins + +To avoid confusion, we recommend to also format the `thresholds` list the same +way. + +The bins can also be weighted. See _Categorical Target Label: Category Weights_ +for details. + + > Experience has shown that many labels contain non-negative values with a + > disproportionate amount (more than n_samples/n_bins) of zeroes. We thus + > decided to make the _right_ side of each bin inclusive, as the bin (-A,0] + > then naturally includes those zero values. +""" + +import logging +from pathlib import Path +from typing import ( + Any, + Dict, + List, + NamedTuple, + Optional, + Sequence, + TextIO, + Tuple, +) + +import numpy as np +import numpy.typing as npt +import pandas as pd +import torch +import torch.nn.functional as F +from packaging.specifiers import Specifier + + +def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: + if not isinstance(path, Path): + return pd.read_csv(path, **kwargs) + elif path.suffix == ".xlsx": + return pd.read_excel(path, **kwargs) + elif path.suffix == ".csv": + return pd.read_csv(path, **kwargs) + else: + raise ValueError( + "table to load has to either be an excel (`*.xlsx`) or csv (`*.csv`) file." + ) + + +__all__ = ["build_targets", "decode_targets"] + + +class TargetSpec(NamedTuple): + version: str + targets: Dict[str, Dict[str, Any]] + + +class EncodedTarget(NamedTuple): + categories: List[str] + encoded: torch.Tensor + weight: torch.Tensor + + +def encode_category( + *, + clini_df: pd.DataFrame, + target_label: str, + categories: Sequence[List[str]], + class_weights: Optional[Dict[str, float]] = None, + **ignored, +) -> Tuple[List[str], torch.Tensor, torch.Tensor]: + # Map each category to its index + category_map = {member: idx for idx, cat in enumerate(categories) for member in cat} + + # Map each item to it's category's index, mapping nans to num_classes+1 + # This way we can easily discard the NaN column later + indexes = clini_df[target_label].map(lambda c: category_map.get(c, len(categories))) + indexes = torch.tensor(indexes.values) + + # Discard nan column + one_hot = F.one_hot(indexes, num_classes=len(categories) + 1)[:, :-1] + + # Class weights + if class_weights is not None: + weight = torch.tensor([class_weights[c[0]] for c in categories]) + else: + # No class weights given; use normalized inverse frequency + counts = one_hot.sum(dim=0) + weight = (w := (counts.sum() / counts)) / w.sum() + + # Warn user of unused labels + if ignored: + logging.warn(f"ignored labels in target {target_label}: {ignored}") + + return [c[0] for c in categories], one_hot, weight + + +def encode_quantize( + *, + clini_df: pd.DataFrame, + target_label: str, + thresholds: npt.NDArray[np.floating[Any]], + class_weights: Optional[Dict[str, float]] = None, + **ignored, +) -> Tuple[List[str], torch.Tensor, torch.Tensor]: + # Warn user of unused labels + if ignored: + logging.warn(f"ignored labels in target {target_label}: {ignored}") + + n_bins = len(thresholds) - 1 + numeric_vals = torch.tensor(pd.to_numeric(clini_df[target_label]).values).reshape( + -1, 1 + ) + + # Map each value to a class index as follows: + # 1. If the value is NaN or less than the left-most threshold, use class + # index 0 + # 2. If it is between the left-most and the right-most threshold, set it to + # the bin number (starting from 1) + # 3. If it is larger than the right-most threshold, set it to N_bins + 1 + bin_index = ( + (numeric_vals > torch.tensor(thresholds).reshape(1, -1)).count_nonzero(1) + # For the first bucket, we have to include the lower threshold + + (numeric_vals.reshape(-1) == thresholds[0]) + ) + # One hot encode and discard nan columns (first and last col) + one_hot = F.one_hot(bin_index, num_classes=n_bins + 2)[:, 1:-1] + + # Class weights + categories = [ + f"[{thresholds[0]:+1.2e};{thresholds[1]:+1.2e}]", + *( + f"({lower:+1.2e};{upper:+1.2e}]" + for lower, upper in zip(thresholds[1:-1], thresholds[2:], strict=True) + ), + ] + + if class_weights is not None: + weight = torch.tensor([class_weights[c] for c in categories]) + else: + # No class weights given; use normalized inverse frequency + counts = one_hot.sum(0) + weight = (w := (np.divide(counts.sum(), counts, where=counts > 0))) / w.sum() + + return categories, one_hot, weight + + +def decode_targets( + encoded: torch.Tensor, + *, + target_labels: Sequence[str], + targets: Dict[str, Any], + version: str = "barspoon-targets 2.0", + **ignored, +) -> List[np.ndarray]: + name, version = version.split(" ") + spec = Specifier("~=2.0") + + if not (name == "barspoon-targets" and spec.contains(version)): + raise ValueError( + f"incompatible target file: expected barspoon-targets{spec}, found `{name} {version}`" + ) + + # Warn user of unused labels + if ignored: + logging.warn(f"ignored parameters: {ignored}") + + decoded_targets = [] + curr_col = 0 + for target_label in target_labels: + info = targets[target_label] + + if (categories := info.get("categories")) is not None: + # Add another column which is one iff all the other values are zero + encoded_target = encoded[:, curr_col : curr_col + len(categories)] + is_none = ~encoded_target.any(dim=1).view(-1, 1) + encoded_target = torch.cat([encoded_target, is_none], dim=1) + + # Decode to class labels + representatives = np.array([c[0] for c in categories] + [None]) + category_index = encoded_target.argmax(dim=1) + decoded = representatives[category_index] + decoded_targets.append(decoded) + + curr_col += len(categories) + + elif (thresholds := info.get("thresholds")) is not None: + n_bins = len(thresholds) - 1 + encoded_target = encoded[:, curr_col : curr_col + n_bins] + is_none = ~encoded_target.any(dim=1).view(-1, 1) + encoded_target = torch.cat([encoded_target, is_none], dim=1) + + bin_edges = [-np.inf, *thresholds, np.inf] + representatives = np.array( + [ + f"[{lower:+1.2e};{upper:+1.2e})" + for lower, upper in zip(bin_edges[:-1], bin_edges[1:]) + ] + ) + decoded = representatives[encoded_target.argmax(dim=1)] + + decoded_targets.append(decoded) + + curr_col += n_bins + + else: + raise ValueError(f"cannot decode {target_label}: no target info") + + return decoded_targets + + +def build_targets( + *, + clini_tables: Sequence[Path], + categorical_labels: Sequence[str], + category_min_count: int = 32, + quantize: Sequence[tuple[str, int]] = (), +) -> Dict[str, EncodedTarget]: + clini_df = pd.concat([read_table(c) for c in clini_tables]) + encoded_targets: Dict[str, EncodedTarget] = {} + + # categorical targets + for target_label in categorical_labels: + counts = clini_df[target_label].value_counts() + well_supported = counts[counts >= category_min_count] + + if len(well_supported) <= 1: + continue + + categories = [[str(cat)] for cat in well_supported.index] + + weights = well_supported.sum() / well_supported + weights /= weights.sum() + + representatives, encoded, weight = encode_category( + clini_df=clini_df, + target_label=target_label, + categories=categories, + class_weights=weights.to_dict(), + ) + + encoded_targets[target_label] = EncodedTarget( + categories=representatives, + encoded=encoded, + weight=weight, + ) + + # quantized targets + for target_label, bincount in quantize: + vals = pd.to_numeric(clini_df[target_label]).dropna() + + if vals.empty: + continue + + vals_clamped = vals.replace( + { + -np.inf: vals[vals != -np.inf].min(), + np.inf: vals[vals != np.inf].max(), + } + ) + + thresholds = np.array( + [ + -np.inf, + *np.quantile(vals_clamped, q=np.linspace(0, 1, bincount + 1))[1:-1], + np.inf, + ], + dtype=float, + ) + + representatives, encoded, weight = encode_quantize( + clini_df=clini_df, + target_label=target_label, + thresholds=thresholds, + ) + + if encoded.shape[1] <= 1: + continue + + encoded_targets[target_label] = EncodedTarget( + categories=representatives, + encoded=encoded, + weight=weight, + ) + + return encoded_targets + + +if __name__ == "__main__": + encoded = build_targets( + clini_tables=[ + Path( + "/mnt/bulk-neptune/nguyenmin/stamp-dev/experiments/survival_prediction/TCGA-CRC-DX_CLINI.xlsx" + ) + ], + categorical_labels=["BRAF", "KRAS", "NRAS"], + category_min_count=32, + quantize=[], + ) + for name, enc in encoded.items(): + print(name, enc.encoded.shape) diff --git a/tests/random_data.py b/tests/random_data.py index bd95d1bc..c7c36880 100644 --- a/tests/random_data.py +++ b/tests/random_data.py @@ -254,6 +254,92 @@ def create_random_patient_level_dataset( return clini_path, slide_path, feat_dir, categories +def create_random_multi_target_dataset( + *, + dir: Path, + n_patients: int, + max_slides_per_patient: int, + min_tiles_per_slide: int, + max_tiles_per_slide: int, + feat_dim: int, + target_labels: Sequence[str], + categories_per_target: Sequence[Sequence[str]], + extractor_name: str = "random-test-generator", + min_slides_per_patient: int = 1, +) -> tuple[Path, Path, Path, Sequence[Sequence[str]]]: + """ + Create a random multi-target tile-level dataset. + + Args: + dir: Directory to create dataset in + n_patients: Number of patients + max_slides_per_patient: Maximum slides per patient + min_tiles_per_slide: Minimum tiles per slide + max_tiles_per_slide: Maximum tiles per slide + feat_dim: Feature dimension + target_labels: Names of the target columns (e.g., ["subtype", "grade"]) + categories_per_target: Categories for each target (e.g., [["A", "B"], ["1", "2", "3"]]) + extractor_name: Name of the extractor + min_slides_per_patient: Minimum slides per patient + + Returns: + Tuple of (clini_path, slide_path, feat_dir, categories_per_target) + """ + if len(target_labels) != len(categories_per_target): + raise ValueError( + "target_labels and categories_per_target must have same length" + ) + + slide_path_to_patient: Mapping[Path, PatientId] = {} + patient_to_ground_truths: Mapping[PatientId, dict[str, str]] = {} + clini_path = dir / "clini.csv" + slide_path = dir / "slide.csv" + + feat_dir = dir / "feats" + feat_dir.mkdir() + + for _ in range(n_patients): + # Random patient ID + patient_id = random_string(16) + + # Generate ground truths for each target + ground_truths = {} + for target_label, categories in zip(target_labels, categories_per_target): + ground_truths[target_label] = random.choice(categories) + + patient_to_ground_truths[patient_id] = ground_truths + + # Generate some slides + for _ in range(random.randint(min_slides_per_patient, max_slides_per_patient)): + slide_path_to_patient[ + create_random_feature_file( + tmp_path=feat_dir, + min_tiles=min_tiles_per_slide, + max_tiles=max_tiles_per_slide, + feat_dim=feat_dim, + extractor_name=extractor_name, + ).relative_to(feat_dir) + ] = patient_id + + # Create clinical table with multiple target columns + clini_data = [] + for patient_id, ground_truths in patient_to_ground_truths.items(): + row = {"patient": patient_id} + row.update(ground_truths) + clini_data.append(row) + + clini_df = pd.DataFrame(clini_data) + clini_df.to_csv(clini_path, index=False) + + slide_df = pd.DataFrame( + slide_path_to_patient.items(), + columns=["slide_path", "patient"], + ) + slide_df.to_csv(slide_path, index=False) + + return clini_path, slide_path, feat_dir, categories_per_target + + def create_random_feature_file( *, tmp_path: Path, diff --git a/tests/test_cache_tiles.py b/tests/test_cache_tiles.py index a665e92c..d9f1411d 100644 --- a/tests/test_cache_tiles.py +++ b/tests/test_cache_tiles.py @@ -6,10 +6,10 @@ import numpy as np import pytest -from stamp.cache import download_file from stamp.preprocessing import Microns, TilePixels from stamp.preprocessing.tiling import _Tile, tiles_with_cache from stamp.types import ImageExtension, SlidePixels +from stamp.utils.cache import download_file def _get_tiles_and_images( diff --git a/tests/test_config.py b/tests/test_config.py index 15b5dd80..fbb53a6c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,6 @@ # %% from pathlib import Path -from stamp.config import StampConfig from stamp.heatmaps.config import HeatmapConfig from stamp.modeling.config import ( AdvancedConfig, @@ -21,6 +20,7 @@ TilePixels, ) from stamp.statistics import StatsConfig +from stamp.utils.config import StampConfig def test_config_parsing() -> None: diff --git a/tests/test_crossval.py b/tests/test_crossval.py index 184a5c23..5e53f50c 100644 --- a/tests/test_crossval.py +++ b/tests/test_crossval.py @@ -13,7 +13,7 @@ VitModelParams, ) from stamp.modeling.crossval import categorical_crossval_ -from stamp.seed import Seed +from stamp.utils.seed import Seed @pytest.mark.slow diff --git a/tests/test_data.py b/tests/test_data.py index 564d6426..3c86a931 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -23,7 +23,6 @@ get_coords, slide_to_patient_from_slide_table_, ) -from stamp.seed import Seed from stamp.types import ( BagSize, FeaturePath, @@ -33,6 +32,7 @@ SlideMPP, TilePixels, ) +from stamp.utils.seed import Seed @pytest.mark.filterwarnings("ignore:some patients have no associated slides") diff --git a/tests/test_deployment.py b/tests/test_deployment.py index de20ea12..4e1570cc 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import cast import numpy as np import pytest @@ -24,8 +25,8 @@ ) from stamp.modeling.models.mlp import MLP from stamp.modeling.models.vision_tranformer import VisionTransformer -from stamp.seed import Seed from stamp.types import GroundTruth, PatientId, Task +from stamp.utils.seed import Seed def test_predict_patient_level( @@ -83,7 +84,8 @@ def test_predict_patient_level( assert len(predictions) == len(patient_to_data) for pid in patient_ids: - assert predictions[pid].shape == torch.Size([3]), "expected one score per class" + pred = cast(torch.Tensor, predictions[pid]) + assert pred.shape == torch.Size([3]), "expected one score per class" # Check if scores are consistent between runs and different for different patients more_patient_ids = [PatientId(f"pat{i}") for i in range(8, 11)] @@ -124,11 +126,13 @@ def test_predict_patient_level( assert len(more_predictions) == len(all_patient_ids) # Different patients should give different results assert not torch.allclose( - more_predictions[more_patient_ids[0]], more_predictions[more_patient_ids[1]] + cast(torch.Tensor, more_predictions[more_patient_ids[0]]), + cast(torch.Tensor, more_predictions[more_patient_ids[1]]), ), "different inputs should give different results" # The same patient should yield the same result assert torch.allclose( - predictions[patient_ids[0]], more_predictions[patient_ids[0]] + cast(torch.Tensor, predictions[patient_ids[0]]), + cast(torch.Tensor, more_predictions[patient_ids[0]]), ), "the same inputs should repeatedly yield the same results" @@ -295,7 +299,7 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: div_factor=25.0, ) - # ---- Build tile-level feature file so batch = (bags, coords, bag_sizes, gt) + # Build tile-level feature file so batch = (bags, coords, bag_sizes, gt) if task == "classification": feature_file = make_old_feature_file( feats=torch.rand(23, dim_feats), coords=torch.rand(23, 2) @@ -319,7 +323,7 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: ) } - # ---- Use tile_bag_dataloader for ALL tasks (so batch has 4 elements) + # Use tile_bag_dataloader for ALL tasks (so batch has 4 elements) test_dl, _ = tile_bag_dataloader( task=task, # "classification" | "regression" | "survival" patient_data=list(patient_to_data.values()), @@ -341,12 +345,17 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: assert len(predictions) == 1 pred = list(predictions.values())[0] if task == "classification": - assert pred.shape == torch.Size([len(categories)]) + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.shape == torch.Size([len(categories)]) elif task == "regression": - assert pred.shape == torch.Size([1]) + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.shape == torch.Size([1]) else: # survival # Cox model → scalar log-risk, KM → vector or matrix - assert pred.ndim in (0, 1, 2), f"unexpected survival output shape: {pred.shape}" + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.ndim in (0, 1, 2), ( + f"unexpected survival output shape: {pred_tensor.shape}" + ) # Repeatability predictions2 = _predict( @@ -356,4 +365,6 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: accelerator="cpu", ) for pid in predictions: - assert torch.allclose(predictions[pid], predictions2[pid]) + assert torch.allclose( + cast(torch.Tensor, predictions[pid]), cast(torch.Tensor, predictions2[pid]) + ) diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 3edef575..28d7c2c1 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -10,13 +10,13 @@ from huggingface_hub.errors import GatedRepoError from random_data import create_random_dataset, create_random_feature_file, random_string -from stamp.cache import download_file from stamp.encoding import ( EncoderName, init_patient_encoder_, init_slide_encoder_, ) from stamp.preprocessing.config import ExtractorName +from stamp.utils.cache import download_file # Contains an accepted input patch-level feature encoder # TODO: Make a class for each extractor instead of a function. This class diff --git a/tests/test_feature_extractors.py b/tests/test_feature_extractors.py index 699c10f6..6323ee8d 100644 --- a/tests/test_feature_extractors.py +++ b/tests/test_feature_extractors.py @@ -7,8 +7,8 @@ import torch from huggingface_hub.errors import GatedRepoError -from stamp.cache import download_file from stamp.preprocessing import ExtractorName, Microns, TilePixels, extract_ +from stamp.utils.cache import download_file @pytest.mark.slow diff --git a/tests/test_model.py b/tests/test_model.py index 1aa6d80a..f74e0a99 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,6 @@ import torch +from stamp.modeling.models.barspoon import EncDecTransformer from stamp.modeling.models.mlp import MLP from stamp.modeling.models.trans_mil import TransMIL from stamp.modeling.models.vision_tranformer import VisionTransformer @@ -162,3 +163,114 @@ def test_trans_mil_inference_reproducibility( ) assert logits1.allclose(logits2) + + +def test_enc_dec_transformer_dims( + batch_size: int = 6, + n_tiles: int = 75, + input_dim: int = 456, + d_model: int = 128, +) -> None: + target_n_outs = {"subtype": 3, "grade": 4} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=256, + positional_encoding=True, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + assert set(logits.keys()) == set(target_n_outs.keys()) + for target_label, n_out in target_n_outs.items(): + assert logits[target_label].shape == (batch_size, n_out) + + +def test_enc_dec_transformer_single_target( + batch_size: int = 4, + n_tiles: int = 50, + input_dim: int = 256, + d_model: int = 64, +) -> None: + target_n_outs = {"label": 5} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=1, + num_decoder_layers=1, + dim_feedforward=128, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + assert list(logits.keys()) == ["label"] + assert logits["label"].shape == (batch_size, 5) + + +def test_enc_dec_transformer_no_positional_encoding( + batch_size: int = 4, + n_tiles: int = 30, + input_dim: int = 128, + d_model: int = 64, +) -> None: + target_n_outs = {"a": 2, "b": 3} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=1, + num_decoder_layers=1, + dim_feedforward=128, + positional_encoding=False, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + for target_label, n_out in target_n_outs.items(): + assert logits[target_label].shape == (batch_size, n_out) + + +def test_enc_dec_transformer_inference_reproducibility( + batch_size: int = 5, + n_tiles: int = 40, + input_dim: int = 200, + d_model: int = 64, +) -> None: + target_n_outs = {"subtype": 3, "grade": 4} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=128, + ) + model = model.eval() + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + + with torch.inference_mode(): + logits1 = model.forward(bags, coords) + logits2 = model.forward(bags, coords) + + for target_label in target_n_outs: + assert logits1[target_label].allclose(logits2[target_label]) diff --git a/tests/test_statistics.py b/tests/test_statistics.py index 790b98ab..b600d7e7 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -2,6 +2,7 @@ from pathlib import Path import numpy as np +import pandas as pd import torch from random_data import random_patient_preds, random_string @@ -47,3 +48,153 @@ def test_statistics_integration( def test_statistics_integration_for_multiple_patient_preds(tmp_path: Path) -> None: return test_statistics_integration(tmp_path=tmp_path, n_patient_preds=5) + + +def test_statistics_survival_integration( + *, + tmp_path: Path, + n_folds: int = 1, + n_patients: int = 200, +) -> None: + """Check that survival statistics run without crashing.""" + random.seed(0) + np.random.seed(0) + + for fold_i in range(n_folds): + times = np.random.uniform(30, 2000, size=n_patients) + statuses = np.random.choice([0, 1], size=n_patients, p=[0.3, 0.7]) + risks = np.random.randn(n_patients) + df = pd.DataFrame( + { + "patient": [random_string(8) for _ in range(n_patients)], + "day": times, + "status": statuses, + "pred_score": risks, + } + ) + df.to_csv(tmp_path / f"survival-preds-{fold_i}.csv", index=False) + + compute_stats_( + task="survival", + output_dir=tmp_path / "output", + pred_csvs=[tmp_path / f"survival-preds-{i}.csv" for i in range(n_folds)], + time_label="day", + status_label="status", + ) + + assert (tmp_path / "output" / "survival-stats_individual.csv").is_file() + + +def test_statistics_survival_integration_multiple_folds(tmp_path: Path) -> None: + return test_statistics_survival_integration(tmp_path=tmp_path, n_folds=5) + + +def test_statistics_regression_integration( + *, + tmp_path: Path, + n_folds: int = 1, + n_patients: int = 200, +) -> None: + """Check that regression statistics run without crashing.""" + random.seed(0) + np.random.seed(0) + + for fold_i in range(n_folds): + y_true = np.random.uniform(0, 100, size=n_patients) + y_pred = y_true + np.random.randn(n_patients) * 10 # noisy predictions + df = pd.DataFrame( + { + "patient": [random_string(8) for _ in range(n_patients)], + "target": y_true, + "pred": y_pred, + } + ) + df.to_csv(tmp_path / f"regression-preds-{fold_i}.csv", index=False) + + compute_stats_( + task="regression", + output_dir=tmp_path / "output", + pred_csvs=[tmp_path / f"regression-preds-{i}.csv" for i in range(n_folds)], + ground_truth_label="target", + ) + + assert (tmp_path / "output" / "target_regression-stats_individual.csv").is_file() + assert (tmp_path / "output" / "target_regression-stats_aggregated.csv").is_file() + + +def test_statistics_regression_integration_multiple_folds(tmp_path: Path) -> None: + return test_statistics_regression_integration(tmp_path=tmp_path, n_folds=5) + + +def test_statistics_multi_target_classification_integration( + *, + tmp_path: Path, + n_patient_preds: int = 1, +) -> None: + """Check that multi-target classification statistics run without crashing. + + Multi-target predictions produce separate ground-truth columns per target. + We run compute_stats_ once per target, as the statistics pipeline handles + one target at a time. + """ + random.seed(0) + np.random.seed(0) + torch.random.manual_seed(0) + + categories_per_target = {"subtype": ["A", "B"], "grade": ["1", "2", "3"]} + + for pred_i in range(n_patient_preds): + n_patients = random.randint(100, 500) + data: dict[str, list] = { + "patient": [random_string(8) for _ in range(n_patients)], + } + + for target_label, cats in categories_per_target.items(): + data[target_label] = [random.choice(cats) for _ in range(n_patients)] + probs = torch.softmax(torch.rand(len(cats), n_patients), dim=0) + for j, cat in enumerate(cats): + data[f"{target_label}_{cat}"] = probs[j].tolist() + + pd.DataFrame(data).to_csv( + tmp_path / f"multi-target-preds-{pred_i}.csv", index=False + ) + + # Run statistics per target (as the pipeline would do) + for target_label, cats in categories_per_target.items(): + true_class = cats[0] + compute_stats_( + task="classification", + output_dir=tmp_path / "output" / target_label, + pred_csvs=[ + tmp_path / f"multi-target-preds-{i}.csv" for i in range(n_patient_preds) + ], + ground_truth_label=target_label, + true_class=true_class, + ) + + assert ( + tmp_path + / "output" + / target_label + / f"{target_label}_categorical-stats_aggregated.csv" + ).is_file() + assert ( + tmp_path + / "output" + / target_label + / f"roc-curve_{target_label}={true_class}.svg" + ).is_file() + assert ( + tmp_path + / "output" + / target_label + / f"pr-curve_{target_label}={true_class}.svg" + ).is_file() + + +def test_statistics_multi_target_classification_multiple_preds( + tmp_path: Path, +) -> None: + return test_statistics_multi_target_classification_integration( + tmp_path=tmp_path, n_patient_preds=3 + ) diff --git a/tests/test_train_deploy.py b/tests/test_train_deploy.py index 0180d171..ea5547f1 100644 --- a/tests/test_train_deploy.py +++ b/tests/test_train_deploy.py @@ -8,6 +8,7 @@ import torch from random_data import ( create_random_dataset, + create_random_multi_target_dataset, create_random_patient_level_dataset, create_random_patient_level_survival_dataset, create_random_regression_dataset, @@ -22,8 +23,9 @@ VitModelParams, ) from stamp.modeling.deploy import deploy_categorical_model_ +from stamp.modeling.registry import ModelName from stamp.modeling.train import train_categorical_model_ -from stamp.seed import Seed +from stamp.utils.seed import Seed @pytest.mark.slow @@ -123,6 +125,9 @@ def test_train_deploy_integration( pytest.param(False, True, id="use vary_precision_transform"), ], ) +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_integration( *, tmp_path: Path, @@ -356,6 +361,9 @@ def test_train_deploy_survival_integration( @pytest.mark.slow +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_regression_integration( *, tmp_path: Path, @@ -465,6 +473,9 @@ def test_train_deploy_patient_level_regression_integration( @pytest.mark.slow +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_survival_integration( *, tmp_path: Path, @@ -531,3 +542,89 @@ def test_train_deploy_patient_level_survival_integration( accelerator="gpu" if torch.cuda.is_available() else "cpu", num_workers=min(os.cpu_count() or 1, 16), ) + + +@pytest.mark.slow +@pytest.mark.filterwarnings("ignore:No positive samples in targets") +def test_train_deploy_multi_target_integration( + *, + tmp_path: Path, + feat_dim: int = 25, +) -> None: + """Integration test: train + deploy a multi-target tile-level classification model.""" + Seed.set(42) + + (tmp_path / "train").mkdir() + (tmp_path / "deploy").mkdir() + + # Define multi-target setup: subtype (2 categories) and grade (3 categories) + target_labels = ["subtype", "grade"] + categories_per_target = [["A", "B"], ["1", "2", "3"]] + + # Create random multi-target tile-level dataset + train_clini_path, train_slide_path, train_feature_dir, _ = ( + create_random_multi_target_dataset( + dir=tmp_path / "train", + n_patients=400, + max_slides_per_patient=3, + min_tiles_per_slide=20, + max_tiles_per_slide=600, + feat_dim=feat_dim, + target_labels=target_labels, + categories_per_target=categories_per_target, + ) + ) + deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( + create_random_multi_target_dataset( + dir=tmp_path / "deploy", + n_patients=50, + max_slides_per_patient=3, + min_tiles_per_slide=20, + max_tiles_per_slide=600, + feat_dim=feat_dim, + target_labels=target_labels, + categories_per_target=categories_per_target, + ) + ) + + # Build config objects + config = TrainConfig( + task="classification", + clini_table=train_clini_path, + slide_table=train_slide_path, + feature_dir=train_feature_dir, + output_dir=tmp_path / "train_output", + patient_label="patient", + ground_truth_label=target_labels, + filename_label="slide_path", + categories=[cat for cats in categories_per_target for cat in cats], + ) + + advanced = AdvancedConfig( + bag_size=500, + num_workers=min(os.cpu_count() or 1, 16), + batch_size=8, + max_epochs=2, + patience=1, + accelerator="gpu" if torch.cuda.is_available() else "cpu", + model_params=ModelParams(), + model_name=ModelName.BARSPOON, + ) + + # Train + deploy multi-target model + train_categorical_model_(config=config, advanced=advanced) + + deploy_categorical_model_( + output_dir=tmp_path / "deploy_output", + checkpoint_paths=[tmp_path / "train_output" / "model.ckpt"], + clini_table=deploy_clini_path, + slide_table=deploy_slide_path, + feature_dir=deploy_feature_dir, + patient_label="patient", + ground_truth_label=target_labels, + time_label=None, + status_label=None, + filename_label="slide_path", + accelerator="gpu" if torch.cuda.is_available() else "cpu", + num_workers=min(os.cpu_count() or 1, 16), + ) diff --git a/uv.lock b/uv.lock index c4015d9f..96b4b73a 100644 --- a/uv.lock +++ b/uv.lock @@ -3699,13 +3699,14 @@ wheels = [ [[package]] name = "stamp" -version = "2.3.0" +version = "2.4.0" source = { editable = "." } dependencies = [ { name = "beartype" }, { name = "einops" }, { name = "h5py" }, { name = "jaxtyping" }, + { name = "lifelines" }, { name = "lightning" }, { name = "matplotlib" }, { name = "numpy" }, @@ -3807,7 +3808,6 @@ gigapath = [ { name = "fvcore" }, { name = "gigapath" }, { name = "iopath" }, - { name = "lifelines" }, { name = "monai" }, { name = "scikit-image" }, { name = "scikit-survival" }, @@ -3828,7 +3828,6 @@ gpu = [ { name = "huggingface-hub" }, { name = "iopath" }, { name = "jinja2" }, - { name = "lifelines" }, { name = "madeleine" }, { name = "mamba-ssm" }, { name = "monai" }, @@ -3920,7 +3919,7 @@ requires-dist = [ { name = "iopath", marker = "extra == 'gigapath'" }, { name = "jaxtyping", specifier = ">=0.3.2" }, { name = "jinja2", marker = "extra == 'cobra'", specifier = ">=3.1.4" }, - { name = "lifelines", marker = "extra == 'gigapath'" }, + { name = "lifelines", specifier = ">=0.28.0" }, { name = "lightning", specifier = ">=2.5.2" }, { name = "madeleine", marker = "extra == 'madeleine'", git = "https://github.com/mahmoodlab/MADELEINE.git?rev=de7c85acc2bdad352e6df8eee5694f8b6f288012" }, { name = "mamba-ssm", marker = "extra == 'cobra'", specifier = ">=2.2.6.post3" }, @@ -4747,4 +4746,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/24/fd/725b8e73ac2a50e78a4534ac43c6addf5c1c2d65380dd48a9169cc6739a9/yarl-1.20.1-cp313-cp313t-win32.whl", hash = "sha256:b121ff6a7cbd4abc28985b6028235491941b9fe8fe226e6fdc539c977ea1739d", size = 86591, upload-time = "2025-06-10T00:45:25.793Z" }, { url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" }, { url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" }, -] \ No newline at end of file +] From d8cc268200732f851c36ab59e4cbb1bae6f5c782 Mon Sep 17 00:00:00 2001 From: mducducd Date: Fri, 13 Feb 2026 14:16:16 +0000 Subject: [PATCH 2/9] add multi-target support --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3069dddf..9a0ed68b 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ STAMP is an **end‑to‑end, weakly‑supervised deep‑learning pipeline** tha * 🎓 **Beginner‑friendly & expert‑ready**: Zero‑code CLI and YAML config for routine use; optional code‑level customization for advanced research. * 🧩 **Model‑rich**: Out‑of‑the‑box support for **+20 foundation models** at [tile level](getting-started.md#feature-extraction) (e.g., *Virchow‑v2*, *UNI‑v2*) and [slide level](getting-started.md#slide-level-encoding) (e.g., *TITAN*, *COBRA*). * 🔬 **Weakly‑supervised**: End‑to‑end MIL with Transformer aggregation for training, cross‑validation and external deployment; no pixel‑level labels required. -* 🧮 **Multi-task learning**: Unified framework for **classification**, **regression**, and **cox-based survival analysis**. +* 🧮 **Multi-task learning**: Unified framework for **classification**, **multi-target classification**, **regression**, and **cox-based survival analysis**. * 📊 **Stats & results**: Built‑in metrics and patient‑level predictions, ready for analysis and reporting. * 🖼️ **Explainable**: Generates heatmaps and top‑tile exports out‑of‑the‑box for transparent model auditing and publication‑ready figures. * 🤝 **Collaborative by design**: Clinicians drive hypothesis & interpretation while engineers handle compute; STAMP’s modular CLI mirrors real‑world workflows and tracks every step for full reproducibility. From f7fa2c599a2e34ddd8cf67ae764e887995ec80f7 Mon Sep 17 00:00:00 2001 From: mducducd Date: Mon, 16 Feb 2026 14:12:38 +0000 Subject: [PATCH 3/9] add multi-target statistics --- src/stamp/statistics/categorical.py | 82 +++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/src/stamp/statistics/categorical.py b/src/stamp/statistics/categorical.py index 2b5c859e..30a03a86 100755 --- a/src/stamp/statistics/categorical.py +++ b/src/stamp/statistics/categorical.py @@ -21,6 +21,30 @@ ] +def _detect_targets_from_columns(columns: Sequence[str]) -> list[str]: + """Detect target columns from CSV column names. + + Assumes multi-target format where each target has: + - A ground truth column (target name) + - A prediction column (pred_{target}) + - Probability columns ({target}_{class1}, {target}_{class2}, ...) + + Returns: + List of target names detected. + """ + # Convert to list to handle pandas Index + columns = list(columns) + targets = [] + for col in columns: + # Look for columns that start with "pred_" + if col.startswith("pred_"): + target_name = col[5:] # Remove "pred_" prefix + # Verify the target column exists + if target_name in columns: + targets.append(target_name) + return sorted(targets) + + def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: """Calculates some stats for categorical prediction tables. @@ -110,3 +134,61 @@ def categorical_aggregated_( preds_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_individual.csv") stats_df = _aggregate_categorical_stats(preds_df.reset_index()) stats_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_aggregated.csv") + + +def categorical_aggregated_multitarget_( + *, + preds_csvs: Sequence[Path], + outpath: Path, + target_labels: Sequence[str], +) -> None: + """Calculate statistics for multi-target categorical deployments. + + Args: + preds_csvs: CSV files containing predictions. + outpath: Path to save the results to. + target_labels: List of target labels to compute statistics for. + + This will apply `_categorical` to each target in the multi-target setup, + calculate statistics per target, and save both individual and aggregated results. + """ + outpath.mkdir(parents=True, exist_ok=True) + + all_target_stats = {} + + for target_label in target_labels: + # Process each target separately + preds_dfs = {} + for p in preds_csvs: + df = pd.read_csv(p, dtype=str) + # Drop rows where this target's ground truth is missing + df_clean = df.dropna(subset=[target_label]) + if len(df_clean) > 0: + preds_dfs[Path(p).parent.name] = _categorical(df_clean, target_label) + + if not preds_dfs: + continue + + # Concatenate and save individual stats for this target + preds_df = pd.concat(preds_dfs).sort_index() + preds_df.to_csv(outpath / f"{target_label}_categorical-stats_individual.csv") + + # Aggregate stats for this target + stats_df = _aggregate_categorical_stats(preds_df.reset_index()) + stats_df.to_csv(outpath / f"{target_label}_categorical-stats_aggregated.csv") + + # Store for summary + all_target_stats[target_label] = stats_df + + # Create a combined summary across all targets + if all_target_stats: + summary_dfs = [] + for target_name, stats_df in all_target_stats.items(): + stats_copy = stats_df.copy() + stats_copy.index = pd.MultiIndex.from_product( + [[target_name], stats_copy.index], names=["target", "class"] + ) + summary_dfs.append(stats_copy) + + combined_summary = pd.concat(summary_dfs) + combined_summary.to_csv(outpath / "multitarget_categorical-stats_summary.csv") From 747143f65b57edd50df8725b63f78b089f6a5a3f Mon Sep 17 00:00:00 2001 From: mducducd Date: Mon, 16 Feb 2026 15:28:08 +0000 Subject: [PATCH 4/9] refactor --- src/stamp/encoding/__init__.py | 2 +- src/stamp/encoding/encoder/__init__.py | 4 +- src/stamp/heatmaps/__init__.py | 100 ++--- src/stamp/modeling/crossval.py | 2 +- src/stamp/modeling/data.py | 53 ++- src/stamp/modeling/deploy.py | 2 +- src/stamp/modeling/models/__init__.py | 125 ++++--- src/stamp/modeling/models/mlp.py | 2 +- .../preprocessing/extractor/ctranspath.py | 2 +- src/stamp/preprocessing/tiling.py | 5 +- src/stamp/statistics/__init__.py | 271 +++++++++++--- src/stamp/statistics/categorical.py | 14 +- src/stamp/utils/target_file.py | 351 ------------------ tests/random_data.py | 6 +- tests/test_data.py | 27 ++ tests/test_deployment.py | 10 +- 16 files changed, 432 insertions(+), 544 deletions(-) delete mode 100644 src/stamp/utils/target_file.py diff --git a/src/stamp/encoding/__init__.py b/src/stamp/encoding/__init__.py index 9cb873bb..02d46594 100644 --- a/src/stamp/encoding/__init__.py +++ b/src/stamp/encoding/__init__.py @@ -73,7 +73,7 @@ def init_slide_encoder_( selected_encoder = encoder case _ as unreachable: - assert_never(unreachable) # type: ignore + assert_never(unreachable) selected_encoder.encode_slides_( output_dir=output_dir, diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 86daa54a..4720ef9b 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from pathlib import Path from tempfile import NamedTemporaryFile +from typing import cast import h5py import numpy as np @@ -183,7 +184,8 @@ def _read_h5( elif not h5_path.endswith(".h5"): raise ValueError(f"File is not of type .h5: {os.path.basename(h5_path)}") with h5py.File(h5_path, "r") as f: - feats: Tensor = torch.tensor(f["feats"][:], dtype=self.precision) # type: ignore + feats_ds = cast(h5py.Dataset, f["feats"]) + feats: Tensor = torch.tensor(feats_ds[:], dtype=self.precision) coords: CoordsInfo = get_coords(f) extractor: str = f.attrs.get("extractor", "") if extractor == "": diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index 22fb5250..fb704fe6 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -5,7 +5,7 @@ import logging from collections.abc import Collection, Iterable from pathlib import Path -from typing import cast, no_type_check +from typing import cast import h5py import matplotlib.pyplot as plt @@ -19,7 +19,7 @@ from packaging.version import Version from PIL import Image from torch import Tensor -from torch.func import jacrev # pyright: ignore[reportPrivateImportUsage] +from torch.func import jacrev from stamp.modeling.data import get_coords, get_stride from stamp.modeling.deploy import load_model_from_ckpt @@ -29,6 +29,8 @@ _logger = logging.getLogger("stamp") +_SlideLike = openslide.OpenSlide | openslide.ImageSlide + def _gradcam_per_category( model: torch.nn.Module, @@ -37,23 +39,19 @@ def _gradcam_per_category( ) -> Float[Tensor, "tile category"]: feat_dim = -1 - cam = ( - ( - feats - * jacrev( - lambda bags: model.forward( - bags.unsqueeze(0), - coords=coords.unsqueeze(0), - mask=None, - ).squeeze(0) - )(feats) - ) - .mean(feat_dim) # type: ignore - .abs() + jac = cast( + Tensor, + jacrev( + lambda bags: model.forward( + bags.unsqueeze(0), + coords=coords.unsqueeze(0), + mask=None, + ).squeeze(0) + )(feats), ) + cam = (feats * jac).mean(feat_dim).abs() cam = torch.softmax(cam, dim=-1) - return cam.permute(-1, -2) @@ -70,21 +68,28 @@ def _attention_rollout_single( device = feats.device - # 1. Forward pass to fill attn_weights in each SelfAttention layer + # --- 1. Forward pass to fill attn_weights in each SelfAttention layer --- _ = model( bags=feats.unsqueeze(0), coords=coords.unsqueeze(0), mask=torch.zeros(1, len(feats), dtype=torch.bool, device=device), ) - # 2. Rollout computation + # --- 2. Rollout computation --- attn_rollout: torch.Tensor | None = None - for layer in model.transformer.layers: # type: ignore - attn = getattr(layer[0], "attn_weights", None) # SelfAttention.attn_weights + transformer = getattr(model, "transformer", None) + if transformer is None: + raise RuntimeError("Model does not have a transformer attribute") + for layer in transformer.layers: + attn = getattr(layer, "attn_weights", None) + if attn is None: + first_child = next(iter(layer.children()), None) + if first_child is not None: + attn = getattr(first_child, "attn_weights", None) if attn is None: raise RuntimeError( "SelfAttention.attn_weights not found. " - "Make sure SelfAttention stores them." + "Make sure SelfAttention stores them on the layer or its first child." ) # attn: [heads, seq, seq] @@ -96,10 +101,10 @@ def _attention_rollout_single( if attn_rollout is None: raise RuntimeError("No attention maps collected from transformer layers.") - # 3. Extract CLS → tiles attention + # --- 3. Extract CLS → tiles attention --- cls_attn = attn_rollout[0, 1:] # [tile] - # 4. Normalize for visualization consistency + # --- 4. Normalize for visualization consistency --- cls_attn = cls_attn - cls_attn.min() cls_attn = cls_attn / (cls_attn.max().clamp(min=1e-8)) @@ -117,15 +122,18 @@ def _gradcam_single( """ feat_dim = -1 - jac = jacrev( - lambda bags: model.forward( - bags.unsqueeze(0), - coords=coords.unsqueeze(0), - mask=None, - ).squeeze() - )(feats) + jac = cast( + Tensor, + jacrev( + lambda bags: model.forward( + bags.unsqueeze(0), + coords=coords.unsqueeze(0), + mask=None, + ).squeeze() + )(feats), + ) - cam = (feats * jac).mean(feat_dim).abs() # type: ignore # [tile] + cam = (feats * jac).mean(feat_dim).abs() # [tile] return cam @@ -148,17 +156,21 @@ def _vals_to_im( def _show_thumb( - slide, thumb_ax: Axes, attention: Tensor, default_slide_mpp: SlideMPP | None + slide: _SlideLike, + thumb_ax: Axes, + attention: Tensor, + default_slide_mpp: SlideMPP | None, ) -> np.ndarray: mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp) dims_um = np.array(slide.dimensions) * mpp - thumb = slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int)) + thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist()) + thumb = slide.get_thumbnail(thumb_size) thumb_ax.imshow(np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8]) return np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8] def _get_thumb_array( - slide, + slide: _SlideLike, attention: torch.Tensor, default_slide_mpp: SlideMPP | None, ) -> np.ndarray: @@ -168,12 +180,12 @@ def _get_thumb_array( """ mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp) dims_um = np.array(slide.dimensions) * mpp - thumb = np.array(slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int))) + thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist()) + thumb = np.array(slide.get_thumbnail(thumb_size)) thumb_crop = thumb[: attention.shape[0] * 8, : attention.shape[1] * 8] return thumb_crop -@no_type_check # beartype<=0.19.0 breaks here for some reason def _show_class_map( class_ax: Axes, top_score_indices: Integer[Tensor, "width height"], @@ -298,13 +310,8 @@ def heatmaps_( raise ValueError( f"Feature file {h5_path} is a slide or patient level feature. Heatmaps are currently supported for tile-level features only." ) - feats = ( - torch.tensor( - h5["feats"][:] # pyright: ignore[reportIndexIssue] - ) - .float() - .to(device) - ) + feats_np = np.asarray(h5["feats"]) + feats = torch.from_numpy(feats_np).float().to(device) coords_info = get_coords(h5) coords_um = torch.from_numpy(coords_info.coords_um).float() stride_um = Microns(get_stride(coords_um)) @@ -322,9 +329,10 @@ def heatmaps_( model = load_model_from_ckpt(checkpoint_path).eval() # TODO: Update version when a newer model logic breaks heatmaps. - if Version(model.stamp_version) < Version("2.4.0"): + stamp_version = str(getattr(model, "stamp_version", "")) + if Version(stamp_version) < Version("2.4.0"): raise ValueError( - f"model has been built with stamp version {model.stamp_version} " + f"model has been built with stamp version {stamp_version} " f"which is incompatible with the current version." ) @@ -356,7 +364,7 @@ def heatmaps_( with torch.no_grad(): scores = torch.softmax( - model.model.forward( + model.model( feats.unsqueeze(-2), coords=coords_um.unsqueeze(-2), mask=torch.zeros( diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 2caccecb..8ddfb03d 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -121,7 +121,7 @@ def categorical_crossval_( categories_for_export: ( dict[str, list] | list ) = [] # declare upfront to avoid unbound variable warnings - categories: Sequence[GroundTruth] | list | None = [] # type: ignore # declare upfront to avoid unbound variable warnings + categories: Sequence[GroundTruth] | list | None = [] if config.task == "classification": # Determine categories for training (single-target) and for export (supports multi-target) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 8b30ab11..eadb42f8 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -3,13 +3,13 @@ import logging from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import KW_ONLY, dataclass +from io import BytesIO # accept in _BinaryIOLike at runtime from itertools import groupby from pathlib import Path from typing import ( IO, Any, BinaryIO, - Dict, Final, Generic, List, @@ -23,6 +23,9 @@ import numpy as np import pandas as pd import torch + +# Use beartype's typing for PEP-585 deprecation-safe hints +from beartype.typing import Dict from packaging.version import Version from torch import Tensor from torch.utils.data import DataLoader, Dataset @@ -58,7 +61,9 @@ _EncodedTarget: TypeAlias = ( Tensor | dict[str, Tensor] ) # Union of encoded targets or multi-target dict -_BinaryIOLike: TypeAlias = Union[BinaryIO, IO[bytes]] +_BinaryIOLike: TypeAlias = Union[ + BinaryIO, IO[bytes], BytesIO +] # includes io.BytesIO for runtime checks """The ground truth, encoded numerically - classification: one-hot float [C] - regression: float [1] @@ -73,7 +78,7 @@ class PatientData(Generic[GroundTruthType]): _ = KW_ONLY ground_truth: GroundTruthType - feature_files: Iterable[FeaturePath | BinaryIO] + feature_files: Iterable[FeaturePath | _BinaryIOLike] def tile_bag_dataloader( @@ -533,9 +538,19 @@ def __getitem__( for bag_file in self.bags[index]: with h5py.File(bag_file, "r") as h5: if "feats" in h5: - arr = h5["feats"][:] # pyright: ignore[reportIndexIssue] # original STAMP files + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + arr = feats_obj[:] # original STAMP files else: - arr = h5["patch_embeddings"][:] # type: ignore # your Kronos files + embeddings_obj = h5["patch_embeddings"] + if not isinstance(embeddings_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'patch_embeddings' to be an HDF5 dataset but got {type(embeddings_obj)}" + ) + arr = embeddings_obj[:] # your Kronos files feats.append(torch.from_numpy(arr)) coords_um.append(torch.from_numpy(get_coords(h5).coords_um)) @@ -569,7 +584,7 @@ class PatientFeatureDataset(Dataset): def __init__( self, - feature_files: Sequence[FeaturePath | BinaryIO], + feature_files: Sequence[FeaturePath | _BinaryIOLike], ground_truths: Tensor, # shape: [num_samples, num_classes] transform: Callable[[Tensor], Tensor] | None, ): @@ -585,7 +600,12 @@ def __len__(self): def __getitem__(self, idx: int): feature_file = self.feature_files[idx] with h5py.File(feature_file, "r") as h5: - feats = torch.from_numpy(h5["feats"][:]) # pyright: ignore[reportIndexIssue] + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + feats = torch.from_numpy(feats_obj[:]) # Accept [V] or [1, V] if feats.ndim == 2 and feats.shape[0] == 1: feats = feats[0] @@ -634,10 +654,15 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: tile_size_px = TilePixels(0) return CoordsInfo(coords_um, tile_size_um, tile_size_px) - coords: np.ndarray = feature_h5["coords"][:] # type: ignore - coords_um: np.ndarray | None = None + coords_obj = feature_h5["coords"] + if not isinstance(coords_obj, h5py.Dataset): + raise RuntimeError( + f"{feature_h5.filename}: expected 'coords' to be an HDF5 dataset but got {type(coords_obj)}" + ) + coords: np.ndarray = coords_obj[:] tile_size_um: Microns | None = None tile_size_px: TilePixels | None = None + coords_um: np.ndarray | None = None if (tile_size := feature_h5.attrs.get("tile_size", None)) and feature_h5.attrs.get( "unit", None ) == "um": @@ -672,7 +697,15 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: ) if not tile_size_px and "tile_size_px" in feature_h5.attrs: - tile_size_px = TilePixels(int(feature_h5.attrs["tile_size_px"])) # pyright: ignore[reportArgumentType] + tile_size_px_attr = feature_h5.attrs.get("tile_size_px") + if tile_size_px_attr is not None and isinstance( + tile_size_px_attr, (int, float) + ): + tile_size_px = TilePixels(int(tile_size_px_attr)) + else: + raise RuntimeError( + "Invalid or missing 'tile_size_px' attribute in the feature file." + ) if not tile_size_um or coords_um is None: raise RuntimeError( diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 10272328..d3b29ebd 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -276,7 +276,7 @@ def deploy_categorical_model_( for model_i, model in enumerate(models): predictions = _predict( model=model, - test_dl=test_dl, # pyright: ignore[reportPossiblyUnboundVariable] + test_dl=test_dl, patient_ids=patient_ids, accelerator=accelerator, ) diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index 86003c52..0b6a3885 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -3,11 +3,13 @@ import inspect from abc import ABC from collections.abc import Iterable, Sequence -from typing import Any, Mapping, TypeAlias +from typing import Any, TypeAlias import lightning -import numpy as np import torch + +# Use beartype.typing.Mapping to avoid PEP-585 deprecation warnings in beartype +from beartype.typing import Mapping from jaxtyping import Bool, Float from packaging.version import Version from torch import Tensor, nn, optim @@ -148,6 +150,19 @@ def on_train_batch_end(self, outputs, batch, batch_idx): ) +class _TileLevelMixin: + """Mixin for tile-level models providing shared MIL masking logic.""" + + @staticmethod + def _mask_from_bags(bags: Bags, bag_sizes: BagSizes) -> Bool[Tensor, "batch tile"]: + """Create attention mask for padded tiles in variable-length bags.""" + max_possible_bag_size = bags.size(1) + mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( + 0 + ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) + return mask + + class LitBaseClassifier(Base): """ PyTorch Lightning wrapper for tile level and patient level clasification. @@ -199,15 +214,16 @@ def __init__( self.class_weights = category_weights self.valid_auroc = MulticlassAUROC(len(categories)) # Number classes - self.categories = np.array(categories) + self.categories = list(categories) self.hparams.update({"task": "classification"}) -class LitTileClassifier(LitBaseClassifier): +class LitTileClassifier(_TileLevelMixin, LitBaseClassifier): """ - PyTorch Lightning wrapper for the model used in weakly supervised - learning settings, such as Multiple Instance Learning (MIL) for whole-slide images or patch-based data. + PyTorch Lightning wrapper for tile-level MIL classification. + + Used in weakly supervised settings for whole-slide images or patch-based data. """ supported_features = ["tile"] @@ -249,7 +265,6 @@ def _step( ) if step_name == "validation": - # TODO this is a bit ugly, we'd like to have `_step` without special cases self.valid_auroc.update(logits, targets.long().argmax(dim=-1)) self.log( f"{step_name}_auroc", @@ -291,31 +306,19 @@ def predict_step( # adding a mask here will *drastically* and *unbearably* increase memory usage return self.model(bags, coords=coords, mask=None) - def _mask_from_bags( - *, - bags: Bags, - bag_sizes: BagSizes, - ) -> Bool[Tensor, "batch tile"]: - max_possible_bag_size = bags.size(1) - mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( - 0 - ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) - - return mask - class LitSlideClassifier(LitBaseClassifier): - """ - PyTorch Lightning wrapper for MLPClassifier. - """ + """PyTorch Lightning wrapper for slide/patient-level classification.""" supported_features = ["slide"] def forward(self, x: Tensor) -> Tensor: return self.model(x) - def _step(self, batch, step_name: str): - feats, targets = batch + def _step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], step_name: str + ) -> Loss: + feats, targets = list(batch) # Works for both tuple and list logits = self.model(feats.float()) loss = nn.functional.cross_entropy( logits, @@ -341,17 +344,25 @@ def _step(self, batch, step_name: str): ) return loss - def training_step(self, batch, batch_idx): + def training_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "training") - def validation_step(self, batch, batch_idx): + def validation_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "validation") - def test_step(self, batch, batch_idx): + def test_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "test") - def predict_step(self, batch, batch_idx): - feats, _ = batch + def predict_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Tensor: + feats, _ = batch if isinstance(batch, tuple) else batch return self.model(feats) @@ -402,9 +413,10 @@ def _compute_loss(y_true: Tensor, y_pred: Tensor) -> Loss: return nn.functional.l1_loss(y_true, y_pred) -class LitTileRegressor(LitBaseRegressor): +class LitTileRegressor(_TileLevelMixin, LitBaseRegressor): """ - PyTorch Lightning wrapper for weakly supervised / MIL regression at tile/patient level. + PyTorch Lightning wrapper for tile-level MIL regression. + Produces a single continuous output per bag (dim_output = 1). """ @@ -491,38 +503,22 @@ def predict_step( # keep memory usage low as in classifier return self.model(bags, coords=coords, mask=None) - def _mask_from_bags( - *, - bags: Bags, - bag_sizes: BagSizes, - ) -> Bool[Tensor, "batch tile"]: - max_possible_bag_size = bags.size(1) - mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( - 0 - ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) - - return mask - class LitSlideRegressor(LitBaseRegressor): - """ - PyTorch Lightning wrapper for slide-level or patient-level regression. - Produces a single continuous output per slide (dim_output = 1). - """ + """PyTorch Lightning wrapper for slide/patient-level regression.""" supported_features = ["slide"] def forward(self, feats: Tensor) -> Tensor: - """Forward pass for slide-level features.""" return self.model(feats.float()) def _step( self, *, - batch: tuple[Tensor, Tensor], + batch: tuple[Tensor, Tensor] | list[Tensor], step_name: str, ) -> Loss: - feats, targets = batch + feats, targets = list(batch) # Works for both tuple and list preds = self.model(feats.float(), mask=None) # (B, 1) y = targets.to(preds).float() @@ -539,7 +535,6 @@ def _step( ) if step_name == "validation": - # same metrics as LitTileRegressor p = preds.squeeze(-1) t = y.squeeze(-1) self.log( @@ -552,17 +547,25 @@ def _step( return loss - def training_step(self, batch, batch_idx): + def training_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="training") - def validation_step(self, batch, batch_idx): + def validation_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="validation") - def test_step(self, batch, batch_idx): + def test_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="test") - def predict_step(self, batch, batch_idx): - feats, _ = batch + def predict_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Tensor: + feats, _ = batch if isinstance(batch, tuple) else batch return self.model(feats.float()) @@ -707,12 +710,12 @@ def on_train_epoch_end(self): self.hparams.update({"train_pred_median": self.train_pred_median}) -class LitTileSurvival(LitSurvivalBase): +class LitTileSurvival(_TileLevelMixin, LitSurvivalBase): """ - Tile-level or patch-level survival analysis. - Expects dataloader batches like: - (bags, coords, bag_sizes, targets) - where targets is shape (B,2): [:,0]=time, [:,1]=event (1=event, 0=censored). + Tile-level survival analysis with Cox proportional hazards loss. + + Expects batches: (bags, coords, bag_sizes, targets) + where targets.shape = (B, 2): [:,0]=time, [:,1]=event (0=censored, 1=event). """ supported_features = ["tile"] diff --git a/src/stamp/modeling/models/mlp.py b/src/stamp/modeling/models/mlp.py index e4f8881f..e88a77ca 100644 --- a/src/stamp/modeling/models/mlp.py +++ b/src/stamp/modeling/models/mlp.py @@ -29,7 +29,7 @@ def __init__( layers.append(nn.Dropout(dropout)) in_dim = dim_hidden layers.append(nn.Linear(in_dim, dim_output)) - self.mlp = nn.Sequential(*layers) # type: ignore + self.mlp = nn.Sequential(*layers) @jaxtyped(typechecker=beartype) def forward( diff --git a/src/stamp/preprocessing/extractor/ctranspath.py b/src/stamp/preprocessing/extractor/ctranspath.py index 387e7947..d189e279 100644 --- a/src/stamp/preprocessing/extractor/ctranspath.py +++ b/src/stamp/preprocessing/extractor/ctranspath.py @@ -518,7 +518,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1) + self.relative_position_index.view(-1) # pyright: ignore[reportCallIssue] ].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], diff --git a/src/stamp/preprocessing/tiling.py b/src/stamp/preprocessing/tiling.py index e09a06fa..ce684ba4 100644 --- a/src/stamp/preprocessing/tiling.py +++ b/src/stamp/preprocessing/tiling.py @@ -461,7 +461,10 @@ def _extract_mpp_from_metadata(slide: openslide.AbstractSlide) -> SlideMPP | Non return None doc = minidom.parseString(xml_path) collection = doc.documentElement - images = collection.getElementsByTagName("Image") # pyright: ignore[reportOptionalMemberAccess] + if collection is None: + _logger.error("Document element is None, unable to extract MPP.") + return None + images = collection.getElementsByTagName("Image") pixels = images[0].getElementsByTagName("Pixels") mpp = float(pixels[0].getAttribute("PhysicalSizeX")) except Exception: diff --git a/src/stamp/statistics/__init__.py b/src/stamp/statistics/__init__.py index bdbef1fa..b3243ecc 100644 --- a/src/stamp/statistics/__init__.py +++ b/src/stamp/statistics/__init__.py @@ -15,7 +15,10 @@ from matplotlib import pyplot as plt from pydantic import BaseModel, ConfigDict, Field -from stamp.statistics.categorical import categorical_aggregated_ +from stamp.statistics.categorical import ( + categorical_aggregated_, + categorical_aggregated_multitarget_, +) from stamp.statistics.prc import ( plot_multiple_decorated_precision_recall_curves, plot_single_decorated_precision_recall_curve, @@ -51,7 +54,7 @@ class StatsConfig(BaseModel): task: Task = Field(default="classification") output_dir: Path pred_csvs: list[Path] - ground_truth_label: PandasLabel | None = None + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None true_class: str | None = None time_label: str | None = None status_label: str | None = None @@ -60,52 +63,65 @@ class StatsConfig(BaseModel): _Inches = NewType("_Inches", float) -def compute_stats_( +def _compute_multitarget_classification_stats( *, - task: Task, output_dir: Path, pred_csvs: Sequence[Path], - ground_truth_label: PandasLabel | None = None, - true_class: str | None = None, - time_label: str | None = None, - status_label: str | None = None, + target_labels: Sequence[str], ) -> None: - """Compute and save statistics for the provided task and prediction CSVs. + """Compute statistics and plots for multi-target classification. - This wrapper keeps the external API stable while delegating the detailed - computations and plotting to the submodules under `stamp.statistics.*`. + For each target, creates ROC and PRC curves for each class, + similar to single-target classification. """ - match task: - case "classification": - if true_class is None or ground_truth_label is None: - raise ValueError( - "both true_class and ground_truth_label are required in statistic configuration" - ) - - preds_dfs = [ - _read_table( - p, - usecols=[ground_truth_label, f"{ground_truth_label}_{true_class}"], - dtype={ - ground_truth_label: str, - f"{ground_truth_label}_{true_class}": float, - }, - ) - for p in pred_csvs - ] - - y_trues = [ - np.array(df[ground_truth_label] == true_class) for df in preds_dfs - ] - y_preds = [ - np.array(df[f"{ground_truth_label}_{true_class}"].values) - for df in preds_dfs - ] - n_bootstrap_samples = 1000 - figure_width = _Inches(3.8) - threshold_cmap = None - - roc_curve_figure_aspect_ratio = 1.08 + output_dir.mkdir(parents=True, exist_ok=True) + n_bootstrap_samples = 1000 + figure_width = _Inches(3.8) + roc_curve_figure_aspect_ratio = 1.08 + + # Validate all target labels exist in CSV + first_df = _read_table(pred_csvs[0], nrows=0) + missing_targets = [t for t in target_labels if t not in first_df.columns] + if missing_targets: + raise ValueError( + f"Target labels not found in CSV: {missing_targets}. Available columns: {list(first_df.columns)}" + ) + + # Process each target + for target_label in target_labels: + # Load data for this target + preds_dfs = [] + for p in pred_csvs: + df = _read_table(p, dtype=str) + # Only keep rows where this target has ground truth + df_clean = df.dropna(subset=[target_label]) + if len(df_clean) > 0: + preds_dfs.append(df_clean) + + if not preds_dfs: + continue + + # Get unique classes for this target + classes = sorted(preds_dfs[0][target_label].unique()) + + # Create plots for each class in this target + for true_class in classes: + # Extract ground truth and predictions for this class + y_trues = [] + y_preds = [] + + for df in preds_dfs: + prob_col = f"{target_label}_{true_class}" + if prob_col not in df.columns: + continue + + y_trues.append(np.array(df[target_label] == true_class)) + y_preds.append(np.array(df[prob_col].astype(float).values)) + + if not y_trues: + continue + + # Plot ROC curve fig, ax = plt.subplots( figsize=(figure_width, figure_width * roc_curve_figure_aspect_ratio), dpi=300, @@ -116,34 +132,35 @@ def compute_stats_( ax=ax, y_true=y_trues[0], y_score=y_preds[0], - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, - threshold_cmap=threshold_cmap, + threshold_cmap=None, ) else: plot_multiple_decorated_roc_curves( ax=ax, y_trues=y_trues, y_scores=y_preds, - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=None, ) fig.tight_layout() - output_dir.mkdir(parents=True, exist_ok=True) - fig.savefig(output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg") + fig.savefig(output_dir / f"roc-curve_{target_label}={true_class}.svg") plt.close(fig) + # Plot PRC curve fig, ax = plt.subplots( figsize=(figure_width, figure_width * roc_curve_figure_aspect_ratio), dpi=300, ) + if len(preds_dfs) == 1: plot_single_decorated_precision_recall_curve( ax=ax, y_true=y_trues[0], y_score=y_preds[0], - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, ) else: @@ -151,24 +168,170 @@ def compute_stats_( ax=ax, y_trues=y_trues, y_scores=y_preds, - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", ) fig.tight_layout() - fig.savefig(output_dir / f"pr-curve_{ground_truth_label}={true_class}.svg") + fig.savefig(output_dir / f"pr-curve_{target_label}={true_class}.svg") plt.close(fig) - categorical_aggregated_( - preds_csvs=pred_csvs, - ground_truth_label=ground_truth_label, - outpath=output_dir, + # Compute aggregated statistics for all targets + categorical_aggregated_multitarget_( + preds_csvs=pred_csvs, + outpath=output_dir, + target_labels=target_labels, + ) + + +def compute_stats_( + *, + task: Task, + output_dir: Path, + pred_csvs: Sequence[Path], + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None, + true_class: str | None = None, + time_label: str | None = None, + status_label: str | None = None, +) -> None: + """Compute and save statistics for the provided task and prediction CSVs. + + This wrapper keeps the external API stable while delegating the detailed + computations and plotting to the submodules under `stamp.statistics.*`. + """ + match task: + case "classification": + # Check if multi-target based on ground_truth_label type + is_multitarget = ( + isinstance(ground_truth_label, (list, tuple)) + and len(ground_truth_label) > 1 ) + if is_multitarget: + # Multi-target classification + if not isinstance(ground_truth_label, (list, tuple)): + raise ValueError( + "ground_truth_label must be a list or tuple for multi-target classification" + ) + _compute_multitarget_classification_stats( + output_dir=output_dir, + pred_csvs=pred_csvs, + target_labels=list(ground_truth_label), + ) + else: + # Single-target classification (original behavior) + if true_class is None or ground_truth_label is None: + raise ValueError( + "both true_class and ground_truth_label are required in statistic configuration" + ) + if not isinstance(ground_truth_label, str): + raise ValueError( + "ground_truth_label must be a string for single-target classification" + ) + + preds_dfs = [ + _read_table( + p, + usecols=[ + ground_truth_label, + f"{ground_truth_label}_{true_class}", + ], + dtype={ + ground_truth_label: str, + f"{ground_truth_label}_{true_class}": float, + }, + ) + for p in pred_csvs + ] + + y_trues = [ + np.array(df[ground_truth_label] == true_class) for df in preds_dfs + ] + y_preds = [ + np.array(df[f"{ground_truth_label}_{true_class}"].values) + for df in preds_dfs + ] + n_bootstrap_samples = 1000 + figure_width = _Inches(3.8) + threshold_cmap = None + + roc_curve_figure_aspect_ratio = 1.08 + fig, ax = plt.subplots( + figsize=( + figure_width, + figure_width * roc_curve_figure_aspect_ratio, + ), + dpi=300, + ) + + if len(preds_dfs) == 1: + plot_single_decorated_roc_curve( + ax=ax, + y_true=y_trues[0], + y_score=y_preds[0], + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=n_bootstrap_samples, + threshold_cmap=threshold_cmap, + ) + else: + plot_multiple_decorated_roc_curves( + ax=ax, + y_trues=y_trues, + y_scores=y_preds, + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=None, + ) + + fig.tight_layout() + output_dir.mkdir(parents=True, exist_ok=True) + fig.savefig( + output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg" + ) + plt.close(fig) + + fig, ax = plt.subplots( + figsize=( + figure_width, + figure_width * roc_curve_figure_aspect_ratio, + ), + dpi=300, + ) + if len(preds_dfs) == 1: + plot_single_decorated_precision_recall_curve( + ax=ax, + y_true=y_trues[0], + y_score=y_preds[0], + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=n_bootstrap_samples, + ) + else: + plot_multiple_decorated_precision_recall_curves( + ax=ax, + y_trues=y_trues, + y_scores=y_preds, + title=f"{ground_truth_label} = {true_class}", + ) + + fig.tight_layout() + fig.savefig( + output_dir / f"pr-curve_{ground_truth_label}={true_class}.svg" + ) + plt.close(fig) + + categorical_aggregated_( + preds_csvs=pred_csvs, + ground_truth_label=ground_truth_label, + outpath=output_dir, + ) + case "regression": if ground_truth_label is None: raise ValueError( "no ground_truth_label configuration supplied in statistic" ) + if not isinstance(ground_truth_label, str): + raise ValueError( + "ground_truth_label must be a string for regression (multi-target regression not yet supported)" + ) regression_aggregated_( preds_csvs=pred_csvs, ground_truth_label=ground_truth_label, diff --git a/src/stamp/statistics/categorical.py b/src/stamp/statistics/categorical.py index 30a03a86..9d6c4c12 100755 --- a/src/stamp/statistics/categorical.py +++ b/src/stamp/statistics/categorical.py @@ -62,29 +62,29 @@ def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: # roc_auc stats_df["roc_auc_score"] = [ - metrics.roc_auc_score(y_true == cat, y_pred[:, i]) # pyright: ignore[reportCallIssue,reportArgumentType] + metrics.roc_auc_score(y_true == cat, y_pred[:, i]) for i, cat in enumerate(categories) ] # average_precision stats_df["average_precision_score"] = [ - metrics.average_precision_score(y_true == cat, y_pred[:, i]) # pyright: ignore[reportCallIssue,reportArgumentType] + metrics.average_precision_score(y_true == cat, y_pred[:, i]) for i, cat in enumerate(categories) ] # f1 score y_pred_labels = categories[y_pred.argmax(axis=1)] stats_df["f1_score"] = [ - metrics.f1_score(y_true == cat, y_pred_labels == cat) # pyright: ignore[reportCallIssue,reportArgumentType] - for cat in categories + metrics.f1_score(y_true == cat, y_pred_labels == cat) for cat in categories ] # p values p_values = [] for i, cat in enumerate(categories): - pos_scores = y_pred[:, i][y_true == cat] # pyright: ignore[reportCallIssue,reportArgumentType] - neg_scores = y_pred[:, i][y_true != cat] # pyright: ignore[reportCallIssue,reportArgumentType] - p_values.append(st.ttest_ind(pos_scores, neg_scores).pvalue) # pyright: ignore[reportGeneralTypeIssues, reportAttributeAccessIssue] + pos_scores = y_pred[:, i][y_true == cat] + neg_scores = y_pred[:, i][y_true != cat] + _, p_value = st.ttest_ind(pos_scores, neg_scores) + p_values.append(p_value) stats_df["p_value"] = p_values assert set(_score_labels) & set(stats_df.columns) == set(_score_labels) diff --git a/src/stamp/utils/target_file.py b/src/stamp/utils/target_file.py deleted file mode 100644 index 08c30b6b..00000000 --- a/src/stamp/utils/target_file.py +++ /dev/null @@ -1,351 +0,0 @@ -"""Automatically generate target information from clini table - -# The `barspoon-targets 2.0` File Format - -A barspoon target file is a [TOML][1] file with the following entries: - - - A `version` key mapping to a version string `"barspoon-targets "`, where - `` is a [PEP-440 version string][2] compatible with `2.0`. - - A `targets` table, the keys of which are target labels (as found in the - clinical table) and the values specify exactly one of the following: - 1. A categorical target label, marked by the presence of a `categories` - key-value pair. - 2. A target label to quantize, marked by the presence of a `thresholds` - key-value pair. - 3. A target format defined in in a later version of barspoon targets. - A target may only ever have one of the fields `categories` or `thresholds`. - A definition of these entries can be found below. - -[1]: https://toml.io "Tom's Obvious Minimal Language" -[2]: https://peps.python.org/pep-0440/ - "PEP 440 - Version Identification and Dependency Specification" - -## Categorical Target Label - -A categorical target is a target table with a key-value pair `categories`. -`categories` contains a list of lists of literal strings. Each list of strings -will be treated as one category, with all literal strings within that list being -treated as one representative for that category. This allows the user to easily -group related classes into one large class (i.e. `"True", "1", "Yes"` could all -be unified into the same category). - -### Category Weights - -It is possible to assign a weight to each category, to e.g. weigh rarer classes -more heavily. The weights are stored in a table `targets.LABEL.class_weights`, -whose keys is the first representative of each category, and the values of which -is the weight of the category as a floating point number. - -## Target Label to Quantize - -If a target has the `thresholds` option key set, it is interpreted as a -continuous target which has to be quantized. `thresholds` has to be a list of -floating point numbers [t_0, t_n], n > 1 containing the thresholds of the bins -to quantize the values into. A categorical target will be quantized into bins - -```asciimath -b_0 = [t_0; t_1], b_1 = (t_1; b_2], ... b_(n-1) = (t_(n-1); t_n] -``` - -The bins will be treated as categories with names -`f"[{t_0:+1.2e};{t_1:+1.2e}]"` for the first bin and -`f"({t_i:+1.2e};{t_(i+1):+1.2e}]"` for all other bins - -To avoid confusion, we recommend to also format the `thresholds` list the same -way. - -The bins can also be weighted. See _Categorical Target Label: Category Weights_ -for details. - - > Experience has shown that many labels contain non-negative values with a - > disproportionate amount (more than n_samples/n_bins) of zeroes. We thus - > decided to make the _right_ side of each bin inclusive, as the bin (-A,0] - > then naturally includes those zero values. -""" - -import logging -from pathlib import Path -from typing import ( - Any, - Dict, - List, - NamedTuple, - Optional, - Sequence, - TextIO, - Tuple, -) - -import numpy as np -import numpy.typing as npt -import pandas as pd -import torch -import torch.nn.functional as F -from packaging.specifiers import Specifier - - -def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: - if not isinstance(path, Path): - return pd.read_csv(path, **kwargs) - elif path.suffix == ".xlsx": - return pd.read_excel(path, **kwargs) - elif path.suffix == ".csv": - return pd.read_csv(path, **kwargs) - else: - raise ValueError( - "table to load has to either be an excel (`*.xlsx`) or csv (`*.csv`) file." - ) - - -__all__ = ["build_targets", "decode_targets"] - - -class TargetSpec(NamedTuple): - version: str - targets: Dict[str, Dict[str, Any]] - - -class EncodedTarget(NamedTuple): - categories: List[str] - encoded: torch.Tensor - weight: torch.Tensor - - -def encode_category( - *, - clini_df: pd.DataFrame, - target_label: str, - categories: Sequence[List[str]], - class_weights: Optional[Dict[str, float]] = None, - **ignored, -) -> Tuple[List[str], torch.Tensor, torch.Tensor]: - # Map each category to its index - category_map = {member: idx for idx, cat in enumerate(categories) for member in cat} - - # Map each item to it's category's index, mapping nans to num_classes+1 - # This way we can easily discard the NaN column later - indexes = clini_df[target_label].map(lambda c: category_map.get(c, len(categories))) - indexes = torch.tensor(indexes.values) - - # Discard nan column - one_hot = F.one_hot(indexes, num_classes=len(categories) + 1)[:, :-1] - - # Class weights - if class_weights is not None: - weight = torch.tensor([class_weights[c[0]] for c in categories]) - else: - # No class weights given; use normalized inverse frequency - counts = one_hot.sum(dim=0) - weight = (w := (counts.sum() / counts)) / w.sum() - - # Warn user of unused labels - if ignored: - logging.warn(f"ignored labels in target {target_label}: {ignored}") - - return [c[0] for c in categories], one_hot, weight - - -def encode_quantize( - *, - clini_df: pd.DataFrame, - target_label: str, - thresholds: npt.NDArray[np.floating[Any]], - class_weights: Optional[Dict[str, float]] = None, - **ignored, -) -> Tuple[List[str], torch.Tensor, torch.Tensor]: - # Warn user of unused labels - if ignored: - logging.warn(f"ignored labels in target {target_label}: {ignored}") - - n_bins = len(thresholds) - 1 - numeric_vals = torch.tensor(pd.to_numeric(clini_df[target_label]).values).reshape( - -1, 1 - ) - - # Map each value to a class index as follows: - # 1. If the value is NaN or less than the left-most threshold, use class - # index 0 - # 2. If it is between the left-most and the right-most threshold, set it to - # the bin number (starting from 1) - # 3. If it is larger than the right-most threshold, set it to N_bins + 1 - bin_index = ( - (numeric_vals > torch.tensor(thresholds).reshape(1, -1)).count_nonzero(1) - # For the first bucket, we have to include the lower threshold - + (numeric_vals.reshape(-1) == thresholds[0]) - ) - # One hot encode and discard nan columns (first and last col) - one_hot = F.one_hot(bin_index, num_classes=n_bins + 2)[:, 1:-1] - - # Class weights - categories = [ - f"[{thresholds[0]:+1.2e};{thresholds[1]:+1.2e}]", - *( - f"({lower:+1.2e};{upper:+1.2e}]" - for lower, upper in zip(thresholds[1:-1], thresholds[2:], strict=True) - ), - ] - - if class_weights is not None: - weight = torch.tensor([class_weights[c] for c in categories]) - else: - # No class weights given; use normalized inverse frequency - counts = one_hot.sum(0) - weight = (w := (np.divide(counts.sum(), counts, where=counts > 0))) / w.sum() - - return categories, one_hot, weight - - -def decode_targets( - encoded: torch.Tensor, - *, - target_labels: Sequence[str], - targets: Dict[str, Any], - version: str = "barspoon-targets 2.0", - **ignored, -) -> List[np.ndarray]: - name, version = version.split(" ") - spec = Specifier("~=2.0") - - if not (name == "barspoon-targets" and spec.contains(version)): - raise ValueError( - f"incompatible target file: expected barspoon-targets{spec}, found `{name} {version}`" - ) - - # Warn user of unused labels - if ignored: - logging.warn(f"ignored parameters: {ignored}") - - decoded_targets = [] - curr_col = 0 - for target_label in target_labels: - info = targets[target_label] - - if (categories := info.get("categories")) is not None: - # Add another column which is one iff all the other values are zero - encoded_target = encoded[:, curr_col : curr_col + len(categories)] - is_none = ~encoded_target.any(dim=1).view(-1, 1) - encoded_target = torch.cat([encoded_target, is_none], dim=1) - - # Decode to class labels - representatives = np.array([c[0] for c in categories] + [None]) - category_index = encoded_target.argmax(dim=1) - decoded = representatives[category_index] - decoded_targets.append(decoded) - - curr_col += len(categories) - - elif (thresholds := info.get("thresholds")) is not None: - n_bins = len(thresholds) - 1 - encoded_target = encoded[:, curr_col : curr_col + n_bins] - is_none = ~encoded_target.any(dim=1).view(-1, 1) - encoded_target = torch.cat([encoded_target, is_none], dim=1) - - bin_edges = [-np.inf, *thresholds, np.inf] - representatives = np.array( - [ - f"[{lower:+1.2e};{upper:+1.2e})" - for lower, upper in zip(bin_edges[:-1], bin_edges[1:]) - ] - ) - decoded = representatives[encoded_target.argmax(dim=1)] - - decoded_targets.append(decoded) - - curr_col += n_bins - - else: - raise ValueError(f"cannot decode {target_label}: no target info") - - return decoded_targets - - -def build_targets( - *, - clini_tables: Sequence[Path], - categorical_labels: Sequence[str], - category_min_count: int = 32, - quantize: Sequence[tuple[str, int]] = (), -) -> Dict[str, EncodedTarget]: - clini_df = pd.concat([read_table(c) for c in clini_tables]) - encoded_targets: Dict[str, EncodedTarget] = {} - - # categorical targets - for target_label in categorical_labels: - counts = clini_df[target_label].value_counts() - well_supported = counts[counts >= category_min_count] - - if len(well_supported) <= 1: - continue - - categories = [[str(cat)] for cat in well_supported.index] - - weights = well_supported.sum() / well_supported - weights /= weights.sum() - - representatives, encoded, weight = encode_category( - clini_df=clini_df, - target_label=target_label, - categories=categories, - class_weights=weights.to_dict(), - ) - - encoded_targets[target_label] = EncodedTarget( - categories=representatives, - encoded=encoded, - weight=weight, - ) - - # quantized targets - for target_label, bincount in quantize: - vals = pd.to_numeric(clini_df[target_label]).dropna() - - if vals.empty: - continue - - vals_clamped = vals.replace( - { - -np.inf: vals[vals != -np.inf].min(), - np.inf: vals[vals != np.inf].max(), - } - ) - - thresholds = np.array( - [ - -np.inf, - *np.quantile(vals_clamped, q=np.linspace(0, 1, bincount + 1))[1:-1], - np.inf, - ], - dtype=float, - ) - - representatives, encoded, weight = encode_quantize( - clini_df=clini_df, - target_label=target_label, - thresholds=thresholds, - ) - - if encoded.shape[1] <= 1: - continue - - encoded_targets[target_label] = EncodedTarget( - categories=representatives, - encoded=encoded, - weight=weight, - ) - - return encoded_targets - - -if __name__ == "__main__": - encoded = build_targets( - clini_tables=[ - Path( - "/mnt/bulk-neptune/nguyenmin/stamp-dev/experiments/survival_prediction/TCGA-CRC-DX_CLINI.xlsx" - ) - ], - categorical_labels=["BRAF", "KRAS", "NRAS"], - category_min_count=32, - quantize=[], - ) - for name, enc in encoded.items(): - print(name, enc.encoded.shape) diff --git a/tests/random_data.py b/tests/random_data.py index c7c36880..b79c9f42 100644 --- a/tests/random_data.py +++ b/tests/random_data.py @@ -74,13 +74,13 @@ def create_random_dataset( clini_df = pd.DataFrame( patient_to_ground_truth.items(), - columns=["patient", "ground-truth"], # pyright: ignore[reportArgumentType] + columns=["patient", "ground-truth"], ) clini_df.to_csv(clini_path, index=False) slide_df = pd.DataFrame( slide_path_to_patient.items(), - columns=["slide_path", "patient"], # pyright: ignore[reportArgumentType] + columns=["slide_path", "patient"], ) slide_df.to_csv(slide_path, index=False) @@ -130,7 +130,7 @@ def create_random_regression_dataset( # --- Write clini + slide tables --- clini_df = pd.DataFrame(patient_to_target, columns=["patient", "target"]) - clini_df["target"] = clini_df["target"].astype(float) # ✅ ensure numeric dtype + clini_df["target"] = clini_df["target"].astype(float) # ensure numeric dtype clini_df.to_csv(clini_path, index=False) slide_df = pd.DataFrame( diff --git a/tests/test_data.py b/tests/test_data.py index 3c86a931..a60d830e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -3,6 +3,7 @@ from pathlib import Path import h5py +import pandas as pd import pytest import torch from random_data import ( @@ -21,6 +22,7 @@ PatientFeatureDataset, filter_complete_patient_data_, get_coords, + patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, ) from stamp.types import ( @@ -85,6 +87,31 @@ def test_get_cohort_df(tmp_path: Path) -> None: } +def test_patient_to_ground_truth_multi_target(tmp_path: Path) -> None: + """Verify multi-target clini parsing returns dicts and drops rows missing all targets.""" + df = pd.DataFrame( + { + "patient": ["p1", "p2", "p3", "p4"], + "subtype": ["A", None, "B", None], + "grade": ["1", "2", None, None], + } + ) + df.to_csv(tmp_path / "clini.csv", index=False) + + result = patient_to_ground_truth_from_clini_table_( + clini_table_path=tmp_path / "clini.csv", + patient_label="patient", + ground_truth_label=["subtype", "grade"], + ) + + # p4 has both targets missing → dropped + assert "p4" not in result + + assert result["p1"] == {"subtype": "A", "grade": "1"} + assert result["p2"] == {"subtype": None, "grade": "2"} + assert result["p3"] == {"subtype": "B", "grade": None} + + @pytest.mark.parametrize( "feature_file_creator", [make_feature_file, make_old_feature_file], diff --git a/tests/test_deployment.py b/tests/test_deployment.py index 4e1570cc..7d1d6589 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -167,7 +167,7 @@ def test_to_prediction_df(task: str) -> None: ) if task == "classification": preds_df = _to_prediction_df( - categories=list(model.categories), # type: ignore + categories=list(cast(list, model.categories)), patient_to_ground_truth={ PatientId("pat5"): GroundTruth("foo"), PatientId("pat6"): None, @@ -196,13 +196,13 @@ def test_to_prediction_df(task: str) -> None: # Check if no loss / target is given for targets with missing ground truths no_ground_truth = preds_df[preds_df["patient"].isin(["pat6"])] - assert no_ground_truth["target"].isna().all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] - assert no_ground_truth["loss"].isna().all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] + assert no_ground_truth["target"].isna().all() + assert no_ground_truth["loss"].isna().all() # Check if loss / target is given for targets with ground truths with_ground_truth = preds_df[preds_df["patient"].isin(["pat5", "pat7"])] - assert (~with_ground_truth["target"].isna()).all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] - assert (~with_ground_truth["loss"].isna()).all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] + assert (~with_ground_truth["target"].isna()).all() + assert (~with_ground_truth["loss"].isna()).all() elif task == "regression": patient_to_ground_truth = {} From f08e33ad705e3d028df75c643d2b1cee00d8ebf6 Mon Sep 17 00:00:00 2001 From: mducducd Date: Tue, 17 Feb 2026 16:16:41 +0000 Subject: [PATCH 5/9] refactor survival training/validation --- src/stamp/modeling/crossval.py | 218 ++++++++++++++++++-------- src/stamp/modeling/data.py | 140 +++++++++++------ src/stamp/modeling/deploy.py | 14 +- src/stamp/modeling/models/__init__.py | 43 +++-- src/stamp/modeling/train.py | 173 ++++++++++++++++++-- src/stamp/types.py | 2 + 6 files changed, 453 insertions(+), 137 deletions(-) diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 8ddfb03d..26196065 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -22,7 +22,7 @@ _to_survival_prediction_df, load_model_from_ckpt, ) -from stamp.modeling.train import setup_model_for_training, train_model_ +from stamp.modeling.train import setup_model_from_dataloaders, train_model_ from stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import ( GroundTruth, @@ -80,19 +80,28 @@ def categorical_crossval_( # Generate the splits, or load them from the splits file if they already exist if not splits_file.exists(): - splits = ( - _get_splits( - patient_to_data=patient_to_data, - n_splits=config.n_splits, - spliter=KFold, - ) - if config.task == "regression" - else _get_splits( - patient_to_data=patient_to_data, - n_splits=config.n_splits, - spliter=StratifiedKFold, - ) + # Detect multi-target classification (ground_truth is a dict) + is_multitarget = any( + isinstance(pd.ground_truth, dict) for pd in patient_to_data.values() ) + + # Use KFold for regression or multi-target classification; otherwise StratifiedKFold. + # For survival we want StratifiedKFold so folds are balanced by event status. + spliter = ( + KFold + if (config.task == "regression" or is_multitarget) + else StratifiedKFold + ) + + _logger.info(f"Using {spliter.__name__} for cross-validation splits") + + splits = _get_splits( + patient_to_data=patient_to_data, + n_splits=config.n_splits, + spliter=spliter, + task=config.task, + ) + with open(splits_file, "w") as fp: fp.write(splits.model_dump_json(indent=4)) else: @@ -183,48 +192,93 @@ def categorical_crossval_( # Train the model if not (split_dir / "model.ckpt").exists(): - model, train_dl, valid_dl = setup_model_for_training( - clini_table=config.clini_table, - slide_table=config.slide_table, - feature_dir=config.feature_dir, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - advanced=advanced, - task=config.task, - patient_to_data={ - patient_id: patient_data - for patient_id, patient_data in patient_to_data.items() - if patient_id in split.train_patients - }, - categories=( - categories - if categories is not None - else ( - sorted( - { - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - and not isinstance(patient_data.ground_truth, dict) - } - ) - if not isinstance(config.ground_truth_label, Sequence) - else None + # Build train and test dataloaders directly (pure 2-way k-fold split) + train_patient_ids = [ + pid for pid in split.train_patients if pid in patient_to_data + ] + test_patient_ids = [ + pid for pid in split.test_patients if pid in patient_to_data + ] + train_patient_data = [patient_to_data[pid] for pid in train_patient_ids] + test_patient_data = [patient_to_data[pid] for pid in test_patient_ids] + + fold_categories = ( + categories + if categories is not None + else ( + sorted( + { + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + and not isinstance(patient_data.ground_truth, dict) + } ) - ), - train_transform=( - VaryPrecisionTransform(min_fraction_bits=1) - if config.use_vary_precision_transform + if not isinstance(config.ground_truth_label, Sequence) else None - ), + ) + ) + + train_transform = ( + VaryPrecisionTransform(min_fraction_bits=1) + if config.use_vary_precision_transform + else None + ) + + train_dl, train_categories = create_dataloader( + feature_type=feature_type, + task=config.task, + patient_data=train_patient_data, + bag_size=advanced.bag_size, + batch_size=advanced.batch_size, + shuffle=True, + num_workers=advanced.num_workers, + transform=train_transform, + categories=fold_categories, + ) + test_dl, _ = create_dataloader( + feature_type=feature_type, + task=config.task, + patient_data=test_patient_data, + bag_size=None, + batch_size=1, + shuffle=False, + num_workers=advanced.num_workers, + transform=None, + categories=train_categories, + ) + + # Infer feature dimension + batch = next(iter(train_dl)) + if feature_type == "tile": + bags, _, _, _ = batch + dim_feats = bags.shape[-1] + else: + feats, _ = batch + dim_feats = feats.shape[-1] + + model = setup_model_from_dataloaders( + train_dl=train_dl, + valid_dl=test_dl, + task=config.task, + train_categories=train_categories, + dim_feats=dim_feats, + train_patients=train_patient_ids, + valid_patients=test_patient_ids, feature_type=feature_type, + advanced=advanced, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + clini_table=config.clini_table, + slide_table=config.slide_table, + feature_dir=config.feature_dir, ) model = train_model_( output_dir=split_dir, model=model, train_dl=train_dl, - valid_dl=valid_dl, + valid_dl=test_dl, max_epochs=advanced.max_epochs, patience=advanced.patience, accelerator=advanced.accelerator, @@ -235,9 +289,9 @@ def categorical_crossval_( else: model = load_model_from_ckpt(split_dir / "model.ckpt") - # Deploy on test set + # Deploy on test fold (used as validation/prediction set) if not (split_dir / "patient-preds.csv").exists(): - # Prepare test dataloader + # Prepare validation dataloader for predictions test_patients = [ pid for pid in split.test_patients if pid in patient_to_data ] @@ -262,19 +316,24 @@ def categorical_crossval_( ) if config.task == "survival": - if isinstance(config.ground_truth_label, str): + # Export only when patients have single-target survival labels ("time status"). + # Don't rely on `ground_truth_label` for survival — check the loaded patient data. + if any(isinstance(gt, dict) for gt in patient_to_ground_truth.values()): + _logger.warning( + "Multi-target survival prediction export not yet supported; skipping CSV save" + ) + else: _to_survival_prediction_df( patient_to_ground_truth=cast( - Mapping[PatientId, str | None], patient_to_ground_truth + Mapping[ + PatientId, str | tuple[float | None, int | None] | None + ], + patient_to_ground_truth, ), predictions=cast(Mapping[PatientId, torch.Tensor], predictions), patient_label=config.patient_label, cut_off=getattr(model.hparams, "train_pred_median", None), ).to_csv(split_dir / "patient-preds.csv", index=False) - else: - _logger.warning( - "Multi-target survival prediction export not yet supported; skipping CSV save" - ) elif config.task == "regression": if config.ground_truth_label is None: raise RuntimeError("Grounf truth label is required for regression") @@ -310,27 +369,56 @@ def categorical_crossval_( def _get_splits( - *, patient_to_data: Mapping[PatientId, PatientData[Any]], n_splits: int, spliter + *, + patient_to_data: Mapping[PatientId, PatientData[Any]], + n_splits: int, + spliter, + task: str | None = None, ) -> _Splits: patients = np.array(list(patient_to_data.keys())) - # Extract ground truth for stratification. - # For multi-target (dict), use the first target's value - y_strat = np.array( - [ - next(iter(gt.values())) if isinstance(gt, dict) else gt - for gt in [patient.ground_truth for patient in patient_to_data.values()] - ] - ) + # Build stratification labels depending on the task + gts = [patient.ground_truth for patient in patient_to_data.values()] + + if task == "survival": + # use event status (0/1) for stratification + # Ground-truths are expected to be pre-parsed into (time, status) tuples + # by the data loading pipeline; extract the status element directly. + statuses: list[int] = [] + for gt in gts: + # support multi-target fallback: use first target + val = next(iter(gt.values())) if isinstance(gt, dict) else gt + # If structured (time, status) use second element, otherwise assume val is status + if isinstance(val, (tuple, list)) and len(val) == 2: + status_val = val[1] + else: + status_val = val + statuses.append(int(cast(int, status_val)) if status_val is not None else 0) + + y_strat = np.array(statuses) + elif task == "classification": + # For multi-target (dict), use the first target's value + y_strat = np.array( + [next(iter(gt.values())) if isinstance(gt, dict) else gt for gt in gts] + ) + else: + # regression or unknown: do not stratify (KFold will ignore y) + y_strat = None skf = spliter(n_splits=n_splits, shuffle=True, random_state=0) + + if y_strat is None: + splits_iter = skf.split(patients) + else: + splits_iter = skf.split(patients, y_strat) + splits = _Splits( splits=[ _Split( train_patients=set(patients[train_indices]), test_patients=set(patients[test_indices]), ) - for train_indices, test_indices in skf.split(patients, y_strat) + for train_indices, test_indices in splits_iter ] ) return splits diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index eadb42f8..6ddc00fd 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -121,6 +121,7 @@ def tile_bag_dataloader( bag_size=bag_size, ground_truths=targets, transform=transform, + deterministic=(not shuffle), ) dl = DataLoader( ds, @@ -229,9 +230,20 @@ def _parse_targets( events.append(np.nan) continue - time_str, status_str = gt.split(" ", 1) - times.append(np.nan if time_str.lower() == "nan" else float(time_str)) - events.append(_parse_survival_status(status_str)) + # Accept either structured tuple/list (time, event) or the legacy + # string form "time status". + if isinstance(gt, (tuple, list)) and len(gt) == 2: + t_val, e_val = gt + times.append( + np.nan + if t_val is None or str(t_val).lower() == "nan" + else float(t_val) + ) + events.append(float(e_val) if e_val is not None else np.nan) + else: + time_str, status_str = str(gt).split(" ", 1) + times.append(np.nan if time_str.lower() == "nan" else float(time_str)) + events.append(_parse_survival_status(status_str)) y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) return y, [] @@ -356,22 +368,20 @@ def create_dataloader( if gt is None: continue if isinstance(gt, dict): - # Use first value for multi-target regression - first_val = next(iter(gt.values())) - values.append(float(first_val)) - else: - values.append(float(gt)) + raise ValueError( + "Multi-target regression is not supported; provide a single numeric target per patient" + ) + values.append(float(gt)) labels = torch.tensor(values, dtype=torch.float32).reshape(-1, 1) elif task == "survival": times, events = [], [] for p in patient_data: if isinstance(p.ground_truth, dict): - # Multi-target survival: use first target - val = list(p.ground_truth.values())[0] - t, e = (val or "nan nan").split(" ", 1) - else: - t, e = (p.ground_truth or "nan nan").split(" ", 1) + raise ValueError( + "Multi-target survival is not supported; provide a single survival time/status per patient" + ) + t, e = (p.ground_truth or "nan nan").split(" ", 1) times.append(float(t) if t.lower() != "nan" else np.nan) events.append(_parse_survival_status(e)) @@ -449,6 +459,15 @@ def load_patient_level_data( """ # Load ground truth mapping + # Multi-target ground truths are only supported for classification. + if task is not None and task != "classification": + if isinstance(ground_truth_label, Sequence) and not isinstance( + ground_truth_label, str + ): + raise ValueError( + "Multi-target ground_truth_label is only supported for classification tasks" + ) + if task == "survival" and time_label is not None and status_label is not None: # Survival: use the existing helper patient_to_ground_truth = patient_to_survival_from_clini_table_( @@ -519,6 +538,7 @@ class BagDataset(Dataset[tuple[_Bag, _Coordinates, BagSize, _EncodedTarget]]): # """The ground truth for each bag, one-hot encoded.""" transform: Callable[[Tensor], Tensor] | None + deterministic: bool = False def __post_init__(self) -> None: if len(self.bags) != len(self.ground_truths): @@ -564,7 +584,12 @@ def __getitem__( # Sample a subset, if required if self.bag_size is not None: return ( - *_to_fixed_size_bag(feats, coords=coords_um, bag_size=self.bag_size), + *_to_fixed_size_bag( + feats, + coords=coords_um, + bag_size=self.bag_size, + deterministic=self.deterministic, + ), self.ground_truths[index], ) else: @@ -697,15 +722,7 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: ) if not tile_size_px and "tile_size_px" in feature_h5.attrs: - tile_size_px_attr = feature_h5.attrs.get("tile_size_px") - if tile_size_px_attr is not None and isinstance( - tile_size_px_attr, (int, float) - ): - tile_size_px = TilePixels(int(tile_size_px_attr)) - else: - raise RuntimeError( - "Invalid or missing 'tile_size_px' attribute in the feature file." - ) + tile_size_px = TilePixels(int(feature_h5.attrs["tile_size_px"])) # pyright: ignore[reportArgumentType] if not tile_size_um or coords_um is None: raise RuntimeError( @@ -716,7 +733,7 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: def _to_fixed_size_bag( - bag: _Bag, coords: _Coordinates, bag_size: BagSize + bag: _Bag, coords: _Coordinates, bag_size: BagSize, deterministic: bool = False ) -> tuple[_Bag, _Coordinates, BagSize]: """Samples a fixed-size bag of tiles from an arbitrary one. @@ -725,23 +742,47 @@ def _to_fixed_size_bag( """ # get up to bag_size elements n_tiles, _dim_feats = bag.shape - bag_idxs = torch.randperm(n_tiles)[:bag_size] + if n_tiles <= bag_size: + # take all and pad later + bag_idxs = torch.arange(n_tiles, device=bag.device) + else: + if deterministic: + # equidistant indices across the bag + idxs = torch.linspace(0, n_tiles - 1, steps=bag_size, device=bag.device) + bag_idxs = idxs.round().long() + else: + bag_idxs = torch.randperm(n_tiles, device=bag.device)[:bag_size] + bag_samples = bag[bag_idxs] coord_samples = coords[bag_idxs] # zero-pad if we don't have enough samples - zero_padded_bag = torch.cat( - ( - bag_samples, - torch.zeros(bag_size - bag_samples.shape[0], bag_samples.shape[1]), + if bag_samples.shape[0] < bag_size: + zero_padded_bag = torch.cat( + ( + bag_samples, + torch.zeros( + bag_size - bag_samples.shape[0], + bag_samples.shape[1], + device=bag.device, + dtype=bag.dtype, + ), + ) ) - ) - zero_padded_coord = torch.cat( - ( - coord_samples, - torch.zeros(bag_size - coord_samples.shape[0], coord_samples.shape[1]), + zero_padded_coord = torch.cat( + ( + coord_samples, + torch.zeros( + bag_size - coord_samples.shape[0], + coord_samples.shape[1], + device=coords.device, + dtype=coords.dtype, + ), + ) ) - ) + else: + zero_padded_bag = bag_samples + zero_padded_coord = coord_samples return zero_padded_bag, zero_padded_coord, min(bag_size, len(bag)) @@ -822,13 +863,14 @@ def patient_to_survival_from_clini_table_( patient_label: PandasLabel, time_label: PandasLabel, status_label: PandasLabel, -) -> dict[PatientId, GroundTruth]: +) -> dict[PatientId, tuple[float | None, int | None]]: """ Loads patients and their survival ground truths (time + event) from a clini table. Returns dict[PatientId, GroundTruth] - Mapping patient_id -> "time status" (e.g. "302 dead", "476 alive"). + Mapping patient_id -> tuple (time, event) where `time` is a numeric follow-up + value and `event` is an integer indicator (1=event/death, 0=censored). """ clini_df = read_table( clini_table_path, @@ -864,7 +906,7 @@ def patient_to_survival_from_clini_table_( # Only drop rows where BOTH time and status are missing clini_df = clini_df.dropna(subset=[time_label, status_label], how="all") - patient_to_ground_truth: dict[PatientId, GroundTruth] = {} + patient_to_ground_truth: dict[PatientId, tuple[float | None, int | None]] = {} for _, row in clini_df.iterrows(): pid = row[patient_label] time_str = row[time_label] @@ -877,10 +919,9 @@ def patient_to_survival_from_clini_table_( # Encode status: keep both dead (event=1) and alive (event=0) status = _parse_survival_status(status_str) - # Encode back to "alive"/"dead" like before - # status = "dead" if status_val == 1 else "alive" - - patient_to_ground_truth[pid] = f"{time_str} {status}" + # Store structured ground-truth as (time, event) to avoid repeated string parsing + time_val = None if pd.isna(time_str) else float(time_str) + patient_to_ground_truth[pid] = (time_val, status) return patient_to_ground_truth @@ -940,7 +981,7 @@ def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: def filter_complete_patient_data_( *, patient_to_ground_truth: Mapping[ - PatientId, GroundTruth | dict[str, GroundTruth] | None + PatientId, GroundTruth | dict[str, GroundTruth] | tuple[float | None, int | None] | None ], slide_to_patient: Mapping[FeaturePath, PatientId], drop_patients_with_missing_ground_truth: bool, @@ -1021,8 +1062,10 @@ def _log_patient_slide_feature_inconsistencies( if slides_without_features := { slide for slide in slide_to_patient.keys() if not slide.exists() }: + slides_list = sorted(str(s) for s in slides_without_features) _logger.warning( - f"some feature files could not be found: {slides_without_features}" + "some feature files could not be found: %s", + ", ".join(slides_list), ) @@ -1131,6 +1174,15 @@ def load_patient_data_( raise ValueError( "Ground truth label is required for classification or regression modeling" ) + # Disallow multi-target ground truth for non-classification tasks + if ( + task != "classification" + and isinstance(ground_truth_label, Sequence) + and not isinstance(ground_truth_label, str) + ): + raise ValueError( + "Multi-target ground_truth_label is only supported for classification tasks" + ) patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( clini_table_path=clini_table, ground_truth_label=ground_truth_label, diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index d3b29ebd..c61a2512 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -20,7 +20,13 @@ slide_to_patient_from_slide_table_, ) from stamp.modeling.registry import ModelName, load_model_class -from stamp.types import Category, GroundTruth, PandasLabel, PatientId +from stamp.types import ( + Category, + GroundTruth, + PandasLabel, + PatientId, + SurvivalGroundTruth, +) __all__ = ["deploy_categorical_model_"] @@ -218,7 +224,7 @@ def deploy_categorical_model_( } patient_to_data = filter_complete_patient_data_( patient_to_ground_truth=cast( - Mapping[PatientId, GroundTruth | None], + Mapping[PatientId, GroundTruth | dict[str, GroundTruth] | None], patient_to_ground_truth, ), slide_to_patient=slide_to_patient, @@ -622,7 +628,9 @@ def _to_regression_prediction_df( def _to_survival_prediction_df( *, - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], + patient_to_ground_truth: Mapping[ + PatientId, GroundTruth | SurvivalGroundTruth | None + ], predictions: Mapping[PatientId, torch.Tensor], patient_label: PandasLabel, cut_off: float | None = None, diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index 0b6a3885..b5a59b5f 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -6,7 +6,9 @@ from typing import Any, TypeAlias import lightning +import numpy as np import torch +from lifelines.utils import concordance_index as lifelines_cindex # Use beartype.typing.Mapping to avoid PEP-585 deprecation warnings in beartype from beartype.typing import Mapping @@ -653,27 +655,36 @@ def cox_loss( def c_index( scores: torch.Tensor, times: torch.Tensor, events: torch.Tensor ) -> torch.Tensor: - # """ - # Concordance index: proportion of correctly ordered comparable pairs. - # """ - N = len(times) - if N <= 1: + """ + Concordance index using lifelines implementation for consistency with statistics. + + Uses the same convention as stamp.statistics.survival: + - Higher risk scores should correspond to shorter survival (worse outcome). + - We negate scores so lifelines interprets higher values as longer survival. + """ + # Convert to numpy for lifelines + scores_np = scores.detach().cpu().numpy().flatten() + times_np = times.detach().cpu().numpy().flatten() + events_np = events.detach().cpu().numpy().flatten() + + # Filter out NaN values + valid_mask = ~(np.isnan(times_np) | np.isnan(events_np) | np.isnan(scores_np)) + if valid_mask.sum() <= 1: return torch.tensor(float("nan"), device=scores.device) - t_i = times.view(-1, 1).expand(N, N) - t_j = times.view(1, -1).expand(N, N) - e_i = events.view(-1, 1).expand(N, N) + times_np = times_np[valid_mask] + events_np = events_np[valid_mask] + scores_np = scores_np[valid_mask] - mask = (t_i < t_j) & e_i.bool() - if mask.sum() == 0: + # Use lifelines concordance_index with negated risk (same as statistics module) + # lifelines expects: higher predicted value = longer survival + # Cox outputs: higher risk = shorter survival, so we negate + try: + ci = lifelines_cindex(times_np, -scores_np, events_np) + except Exception: return torch.tensor(float("nan"), device=scores.device) - s_i = scores.view(-1, 1).expand(N, N)[mask] - s_j = scores.view(1, -1).expand(N, N)[mask] - - conc = (s_i > s_j).float() - ties = (s_i == s_j).float() * 0.5 - return (conc + ties).sum() / mask.sum() + return torch.tensor(ci, device=scores.device, dtype=scores.dtype) def on_validation_epoch_end(self): if ( diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index fb2bdb3b..7031eed6 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -17,6 +17,7 @@ BagDataset, PatientData, PatientFeatureDataset, + _parse_survival_status, create_dataloader, load_patient_data_, ) @@ -154,11 +155,21 @@ def setup_model_for_training( f"No model specified, defaulting to '{advanced.model_name.value}' for feature type '{feature_type}'" ) + # Prevent selecting `barspoon` for single-target classification + if ( + task == "classification" + and isinstance(ground_truth_label, str) + and advanced.model_name == ModelName.BARSPOON + ): + raise ValueError( + "Model 'barspoon' requires multi-target classification. " + "For single-target classification set model_name to 'vit', 'trans_mil', or 'mlp'." + ) + # 2. Instantiate the lightning wrapper (based on provided task, feature type) and model backbone dynamically LitModelClass, ModelClass = load_model_class( task, feature_type, advanced.model_name ) - print(f"Using Lightning wrapper class: {LitModelClass}") # 3. Validate that the chosen model supports the feature type if feature_type not in LitModelClass.supported_features: @@ -220,6 +231,124 @@ def setup_model_for_training( return model, train_dl, valid_dl +def setup_model_from_dataloaders( + *, + train_dl: DataLoader, + valid_dl: DataLoader, + task: Task, + train_categories: Sequence[Category] | Mapping[str, Sequence[Category]], + dim_feats: int, + train_patients: Sequence[PatientId], + valid_patients: Sequence[PatientId], + feature_type: str, + advanced: AdvancedConfig, + # Metadata, has no effect on model training + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + time_label: PandasLabel | None, + status_label: PandasLabel | None, + clini_table: Path, + slide_table: Path | None, + feature_dir: Path, +) -> lightning.LightningModule: + """Creates a model from pre-built dataloaders (no internal split).""" + + _logger.info( + "Training dataloaders: task=%s, feature_type=%s", + task, + feature_type, + ) + + category_weights: torch.Tensor | dict[str, torch.Tensor] | list = [] + if task == "classification": + category_weights = _compute_class_weights_and_check_categories( + train_dl=train_dl, + feature_type=feature_type, + train_categories=train_categories, + ) + + # 1. Default to a model if none is specified + if advanced.model_name is None: + advanced.model_name = ModelName.VIT if feature_type == "tile" else ModelName.MLP + _logger.info( + f"No model specified, defaulting to '{advanced.model_name.value}' for feature type '{feature_type}'" + ) + + # Prevent selecting `barspoon` for single-target classification + if ( + task == "classification" + and isinstance(ground_truth_label, str) + and advanced.model_name == ModelName.BARSPOON + ): + raise ValueError( + "Model 'barspoon' requires multi-target classification. " + "For single-target classification set model_name to 'vit', 'trans_mil', or 'mlp'." + ) + + # 2. Instantiate the lightning wrapper (based on provided task, feature type) and model backbone dynamically + LitModelClass, ModelClass = load_model_class( + task, feature_type, advanced.model_name + ) + + # 3. Validate that the chosen model supports the feature type + if feature_type not in LitModelClass.supported_features: + raise ValueError( + f"Model '{advanced.model_name.value}' does not support feature type '{feature_type}'. " + f"Supported types are: {LitModelClass.supported_features}" + ) + elif ( + feature_type in ("slide", "patient") + and advanced.model_name.value.lower() != "mlp" + ): + raise ValueError( + f"Feature type '{feature_type}' only supports MLP backbones. " + f"Got '{advanced.model_name.value}'. Please set model_name='mlp'." + ) + + # 4. Get model-specific hyperparameters + model_specific_params = ( + advanced.model_params.model_dump().get(advanced.model_name.value) or {} + ) + + # 5. Calculate total steps for scheduler + steps_per_epoch = len(train_dl) + total_steps = steps_per_epoch * advanced.max_epochs + + # 6. Prepare common parameters + common_params = { + "categories": train_categories, + "category_weights": category_weights, + "dim_input": dim_feats, + "total_steps": total_steps, + "max_lr": advanced.max_lr, + "div_factor": advanced.div_factor, + # Metadata, has no effect on model training + "model_name": advanced.model_name.value, + "ground_truth_label": ground_truth_label, + "time_label": time_label, + "status_label": status_label, + "train_patients": train_patients, + "valid_patients": valid_patients, + "clini_table": clini_table, + "slide_table": slide_table, + "feature_dir": feature_dir, + } + + all_params = {**common_params, **model_specific_params} + + _logger.info( + f"Instantiating model '{advanced.model_name.value}' with parameters: {model_specific_params}" + ) + _logger.info( + "Other params: max_epochs=%s, patience=%s", + advanced.max_epochs, + advanced.patience, + ) + + model = LitModelClass(model_class=ModelClass, **all_params) + + return model + + def setup_dataloaders_for_training( *, patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], @@ -259,6 +388,12 @@ def setup_dataloaders_for_training( "patient_to_data must have a ground truth defined for all targets!" ) + # Multi-target ground truths are only supported for classification + if task != "classification" and any(isinstance(gt, dict) for gt in ground_truths): + raise ValueError( + "Multi-target ground truths are only supported for classification tasks" + ) + if task == "classification": # Handle both single and multi-target cases if ground_truths and isinstance(ground_truths[0], dict): @@ -268,17 +403,37 @@ def setup_dataloaders_for_training( else: stratify = ground_truths elif task == "survival": - # Extract event indicator (status) - handle both single and multi-target - statuses = [] + # Extract event indicator (status). Accept either structured (time,event) + # or legacy string "time status" formats. + statuses: list[int] = [] for gt in ground_truths: if isinstance(gt, dict): - # Multi-target survival: extract from first target - first_key = list(gt.keys())[0] - val = cast(dict, gt)[first_key] - if val: - statuses.append(int(val.split()[1])) + raise ValueError( + "Multi-target survival is not supported; provide a single survival time/status per patient" + ) + if isinstance(gt, (tuple, list)) and len(gt) == 2: + status_val = gt[1] + if status_val is None: + raise ValueError( + "Missing survival status for a patient; cannot stratify" + ) + statuses.append(int(status_val)) else: - statuses.append(int(gt.split()[1])) + parts = str(gt).split() + if len(parts) >= 2: + status_token = parts[1] + elif len(parts) == 1: + status_token = parts[0] + else: + raise ValueError( + "Unrecognized survival ground-truth format for stratification" + ) + parsed_status = _parse_survival_status(status_token) + if parsed_status is None: + raise ValueError( + f"Unrecognized survival status token for stratification: {status_token!r}" + ) + statuses.append(int(parsed_status)) stratify = statuses elif task == "regression": stratify = None diff --git a/src/stamp/types.py b/src/stamp/types.py index c1ff6873..fddb9d4c 100644 --- a/src/stamp/types.py +++ b/src/stamp/types.py @@ -37,6 +37,8 @@ PatientId: TypeAlias = str GroundTruth: TypeAlias = str +# Survival ground-truth is represented as (time, event) +SurvivalGroundTruth: TypeAlias = tuple[float | None, int | None] MultiClassGroundTruth: TypeAlias = tuple[str, ...] FeaturePath = NewType("FeaturePath", Path) From 2daf92193c0fd70cc0f0e6285f09bb11497ff5e1 Mon Sep 17 00:00:00 2001 From: mducducd Date: Tue, 17 Feb 2026 16:17:02 +0000 Subject: [PATCH 6/9] refactor survival training/validation --- src/stamp/modeling/data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 6ddc00fd..bb295bcf 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -981,7 +981,8 @@ def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: def filter_complete_patient_data_( *, patient_to_ground_truth: Mapping[ - PatientId, GroundTruth | dict[str, GroundTruth] | tuple[float | None, int | None] | None + PatientId, + GroundTruth | dict[str, GroundTruth] | tuple[float | None, int | None] | None, ], slide_to_patient: Mapping[FeaturePath, PatientId], drop_patients_with_missing_ground_truth: bool, From 95be87c48ec5fc30989802be854ade4b3e95267c Mon Sep 17 00:00:00 2001 From: mducducd Date: Thu, 12 Mar 2026 11:55:28 +0000 Subject: [PATCH 7/9] core: squash updates from 5785e37 to latest --- getting-started.md | 6 + pyproject.toml | 2 +- src/stamp/__main__.py | 42 +- src/stamp/config.yaml | 5 +- src/stamp/encoding/encoder/__init__.py | 6 +- src/stamp/encoding/encoder/eagle.py | 4 +- src/stamp/encoding/encoder/gigapath.py | 72 +- src/stamp/encoding/encoder/titan.py | 10 +- src/stamp/heatmaps/__init__.py | 2 +- src/stamp/modeling/crossval.py | 4 +- src/stamp/modeling/data.py | 183 +++-- src/stamp/modeling/deploy.py | 48 +- src/stamp/modeling/models/__init__.py | 60 +- src/stamp/modeling/models/barspoon.py | 8 +- .../modeling/models/vision_tranformer.py | 4 +- src/stamp/modeling/train.py | 12 +- src/stamp/preprocessing/__init__.py | 51 +- src/stamp/preprocessing/config.py | 10 +- src/stamp/preprocessing/extractor/keep.py | 49 ++ src/stamp/preprocessing/extractor/reddino.py | 64 ++ src/stamp/preprocessing/extractor/ticon.py | 739 ++++++++++++++++++ src/stamp/preprocessing/tiling.py | 4 +- src/stamp/statistics/categorical.py | 14 +- src/stamp/statistics/prc.py | 13 +- src/stamp/statistics/roc.py | 10 +- src/stamp/statistics/survival.py | 17 +- src/stamp/utils/cache.py | 48 +- tests/conftest.py | 10 + tests/test_deployment.py | 12 +- tests/test_train_deploy.py | 84 +- uv.lock | 2 +- 31 files changed, 1291 insertions(+), 304 deletions(-) create mode 100644 src/stamp/preprocessing/extractor/keep.py create mode 100644 src/stamp/preprocessing/extractor/reddino.py create mode 100644 src/stamp/preprocessing/extractor/ticon.py diff --git a/getting-started.md b/getting-started.md index 93f1a0e7..9e1069cc 100644 --- a/getting-started.md +++ b/getting-started.md @@ -58,6 +58,9 @@ Stamp currently supports the following feature extractors: - [mSTAR][mstar] - [MUSK][musk] - [PLIP][plip] + - [KEEP][keep] + - [TICON][ticon] + - [RedDino][reddino] As some of the above require you to request access to the model on huggingface, @@ -153,11 +156,14 @@ meaning ignored that it was ignored during feature extraction. [mstar]: https://huggingface.co/Wangyh/mSTAR [musk]: https://huggingface.co/xiangjx/musk [plip]: https://github.com/PathologyFoundation/plip +[keep]: https://loiesun.github.io/keep/ "A Knowledge-enhanced Pathology Vision-language Foundation Model for Cancer Diagnosis" [TITAN]: https://huggingface.co/MahmoodLab/TITAN [COBRA2]: https://huggingface.co/KatherLab/COBRA [EAGLE]: https://github.com/KatherLab/EAGLE [MADELEINE]: https://huggingface.co/MahmoodLab/madeleine [PRISM]: https://huggingface.co/paige-ai/Prism +[TICON]: https://cvlab-stonybrook.github.io/TICON/ "TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning" +[reddino]: https://github.com/Snarci/RedDino "RedDino: A Foundation Model for Red Blood Cell Analysis" diff --git a/pyproject.toml b/pyproject.toml index fc79a26b..ce840af8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "stamp" -version = "2.4.0" +version = "2.5.0" authors = [ { name = "Omar El Nahhas", email = "omar.el_nahhas@tu-dresden.de" }, { name = "Marko van Treeck", email = "markovantreeck@gmail.com" }, diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index ffa98bae..0b252d6f 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -6,15 +6,6 @@ import yaml -from stamp.modeling.config import ( - AdvancedConfig, - MlpModelParams, - ModelParams, - VitModelParams, -) -from stamp.utils.config import StampConfig -from stamp.utils.seed import Seed - STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml") # Set up the logger @@ -41,23 +32,38 @@ def _create_config_file(config_file: Path) -> None: def _run_cli(args: argparse.Namespace) -> None: - # Handle init command + # Handle init command before any stamp-internal imports so that + # `stamp init` and `stamp --help` don't pay the full torch/pydantic + # import cost. if args.command == "init": _create_config_file(args.config_file_path) return + # Deferred imports: only reached for real commands, not --help / init. + from stamp.modeling.config import ( + AdvancedConfig, + MlpModelParams, + ModelParams, + VitModelParams, + ) + from stamp.utils.config import StampConfig + from stamp.utils.seed import Seed + # Load YAML configuration with open(args.config_file_path, "r") as config_yaml: config = StampConfig.model_validate(yaml.safe_load(config_yaml)) - # use default advanced config in case none is provided - if config.advanced_config is None: - config.advanced_config = AdvancedConfig( - model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()), - ) + # Only build a default AdvancedConfig (with model-params) for commands + # that actually use it. Preprocess / encode / statistics / heatmaps + # never touch config.advanced_config, so don't pay the construction cost. + if args.command in {"train", "crossval"}: + if config.advanced_config is None: + config.advanced_config = AdvancedConfig( + model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()), + ) - # Set global random seed - if config.advanced_config.seed is not None: + # Apply the global seed for any command that has one configured. + if config.advanced_config is not None and config.advanced_config.seed is not None: Seed.set(config.advanced_config.seed) match args.command: @@ -153,6 +159,7 @@ def _run_cli(args: argparse.Namespace) -> None: if config.training.task is None: raise ValueError("task must be set in training configuration") + assert config.advanced_config is not None # guaranteed above for "train" train_categorical_model_( config=config.training, advanced=config.advanced_config ) @@ -198,6 +205,7 @@ def _run_cli(args: argparse.Namespace) -> None: f"{yaml.dump(config.crossval.model_dump(mode='json', exclude_none=True))}" ) + assert config.advanced_config is not None # guaranteed above for "crossval" categorical_crossval_( config=config.crossval, advanced=config.advanced_config, diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 4f16dcb2..80d105bf 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -4,7 +4,7 @@ preprocessing: # Extractor to use for feature extractor. Possible options are "ctranspath", # "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom", # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", - # "virchow-full", "musk", "mstar", "plip", "ticon" + # "virchow-full", "musk", "mstar", "plip", "ticon", "red-dino", "keep" # Some of them require requesting access to the respective authors beforehand. extractor: "chief-ctranspath" @@ -26,6 +26,9 @@ preprocessing: #tile_size_um: 256.0 #tile_size_px: 224 + # Magnification level to extract tiles from + #default_slide_mpp: 1.0 + # How many workers to use for tile extraction. Should be less or equal to # the number of cores of your system. #max_workers: 8 diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 4720ef9b..3b4c3ac4 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -62,7 +62,8 @@ def encode_slides_( if self.precision == torch.float16: self.model.half() - for tile_feats_filename in (progress := tqdm(os.listdir(feat_dir))): + h5_files = sorted(f for f in os.listdir(feat_dir) if f.endswith(".h5")) + for tile_feats_filename in (progress := tqdm(h5_files)): h5_path = os.path.join(feat_dir, tile_feats_filename) slide_name: str = Path(tile_feats_filename).stem progress.set_description(slide_name) @@ -185,7 +186,8 @@ def _read_h5( raise ValueError(f"File is not of type .h5: {os.path.basename(h5_path)}") with h5py.File(h5_path, "r") as f: feats_ds = cast(h5py.Dataset, f["feats"]) - feats: Tensor = torch.tensor(feats_ds[:], dtype=self.precision) + # torch.from_numpy avoids a redundant data copy vs torch.tensor(array) + feats: Tensor = torch.from_numpy(feats_ds[()]).to(dtype=self.precision) coords: CoordsInfo = get_coords(f) extractor: str = f.attrs.get("extractor", "") if extractor == "": diff --git a/src/stamp/encoding/encoder/eagle.py b/src/stamp/encoding/encoder/eagle.py index 9266f315..45092f4f 100644 --- a/src/stamp/encoding/encoder/eagle.py +++ b/src/stamp/encoding/encoder/eagle.py @@ -271,10 +271,12 @@ def _align_vir2_to_ctp_by_coords( decimals: int = 5, ) -> tuple[torch.Tensor, np.ndarray]: """Align vir2 features to ctp features based on coordinates.""" + # round coordinates to avoid floating-point precision mismatches ref = np.round(np.asarray(ref_coords_um, dtype=np.float64), decimals) oth = np.round(np.asarray(other_coords_um, dtype=np.float64), decimals) - # coord -> queue(indices) + # build mapping: coordinate -> queue of indices + # using deque ensures stable matching when duplicates exist buckets = defaultdict(deque) for j, key in enumerate(map(tuple, oth)): buckets[key].append(j) diff --git a/src/stamp/encoding/encoder/gigapath.py b/src/stamp/encoding/encoder/gigapath.py index 4c0a2f6b..09688ad0 100644 --- a/src/stamp/encoding/encoder/gigapath.py +++ b/src/stamp/encoding/encoder/gigapath.py @@ -31,7 +31,10 @@ class Gigapath(Encoder): def __init__(self) -> None: try: model = slide_encoder.create_model( - "hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536 + "hf_hub:prov-gigapath/prov-gigapath", + "gigapath_slide_enc12l768d", + 1536, + global_pool=True, ) except AssertionError: raise ModuleNotFoundError( @@ -51,20 +54,9 @@ def _generate_slide_embedding( if not coords: raise ValueError("Tile coords are required for encoding") - # Calculate slide dimensions - slide_width = max(coords.coords_um[:, 0]) + coords.tile_size_um - slide_height = max(coords.coords_um[:, 1]) + coords.tile_size_um - - # Normalize coordinates to a [0, 1000] grid - n_grid = 1000 - norm_coords = self._convert_coords( - coords.coords_um, slide_width, slide_height, n_grid, current_x_offset=0 - ) + coords_px = coords.coords_um / coords.mpp norm_coords = ( - torch.tensor(norm_coords, dtype=torch.float32) - .unsqueeze(0) - .to(device) - .half() + torch.tensor(coords_px, dtype=torch.float32).unsqueeze(0).to(device).half() ) feats = feats.unsqueeze(0).half().to(device) @@ -119,8 +111,6 @@ def encode_patients_( all_feats_list = [] all_coords_list = [] - total_wsi_width = 0 - max_wsi_height = 0 slides_mpp = SlideMPP(-1) slide_info = [] @@ -151,31 +141,20 @@ def encode_patients_( ) wsi_width = max(coords.coords_um[:, 0]) + coords.tile_size_um - wsi_height = max(coords.coords_um[:, 1]) + coords.tile_size_um - - total_wsi_width += wsi_width # Sum the widths of all slides - max_wsi_height = max(max_wsi_height, wsi_height) # Track the max height - - slide_info.append((wsi_width, wsi_height, feats, coords)) + slide_info.append((wsi_width, feats, coords)) current_x_offset = 0 - for wsi_width, wsi_height, feats, coords in slide_info: - norm_coords = self._convert_coords( - coords=coords.coords_um, - total_wsi_width=total_wsi_width, - max_wsi_height=max_wsi_height, - n_grid=1000, - current_x_offset=current_x_offset, - ) + for wsi_width, feats, coords in slide_info: + offset_coords_um = coords.coords_um.copy() + offset_coords_um[:, 0] += current_x_offset - # Update x-coordinates by shifting them based on the current_x_offset - current_x_offset += ( - wsi_width # Move the x_offset forward for the next slide - ) + current_x_offset += wsi_width + + coords_px = offset_coords_um / coords.mpp norm_coords = ( - torch.tensor(norm_coords, dtype=torch.float32) + torch.tensor(coords_px, dtype=torch.float32) .unsqueeze(0) .to(device) .half() @@ -211,26 +190,3 @@ def _generate_patient_embedding( patient_embedding = torch.cat(patient_embedding, dim=0) return patient_embedding.detach().squeeze().cpu().numpy() - - def _convert_coords( - self, - coords, - total_wsi_width, - max_wsi_height, - n_grid, - current_x_offset, - ) -> np.ndarray: - """ - Normalize the x and y coordinates relative to the total WSI width and max height, using the same grid [0, 1000]. - Thanks Peter! - """ - # Normalize x-coordinates based on total WSI width (taking into account the current x offset) - normalized_x = (coords[:, 0] + current_x_offset) / total_wsi_width * n_grid - - # Normalize y-coordinates based on the maximum WSI height - normalized_y = coords[:, 1] / max_wsi_height * n_grid - - # Stack normalized x and y coordinates - converted_coords = np.stack([normalized_x, normalized_y], axis=-1) - - return np.array(converted_coords, dtype=np.float32) diff --git a/src/stamp/encoding/encoder/titan.py b/src/stamp/encoding/encoder/titan.py index 568254ca..920b3db8 100644 --- a/src/stamp/encoding/encoder/titan.py +++ b/src/stamp/encoding/encoder/titan.py @@ -49,7 +49,15 @@ def _generate_slide_embedding( coords_tensor = torch.tensor(coords.coords_um, dtype=self.precision) # Convert coordinates from microns to pixels - patch_size_lvl0 = math.floor(256 / coords.mpp) # Inferred from TITAN docs + xs = torch.unique(coords_tensor[:, 0]) + ys = torch.unique(coords_tensor[:, 1]) + patch_size_lvl0 = int( + min( + (xs[1:] - xs[:-1])[(xs[1:] - xs[:-1]) > 0].min(), + (ys[1:] - ys[:-1])[(ys[1:] - ys[:-1]) > 0].min(), + ) + ) + coords_px = coords_tensor / coords.mpp # Convert to pixels coords_px = coords_px.to(torch.int64).to(device) # Convert to integer diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index fb704fe6..b903ef89 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -330,7 +330,7 @@ def heatmaps_( # TODO: Update version when a newer model logic breaks heatmaps. stamp_version = str(getattr(model, "stamp_version", "")) - if Version(stamp_version) < Version("2.4.0"): + if Version(stamp_version) < Version("2.5.0"): raise ValueError( f"model has been built with stamp version {stamp_version} " f"which is incompatible with the current version." diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 26196065..4ee71563 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -336,7 +336,7 @@ def categorical_crossval_( ).to_csv(split_dir / "patient-preds.csv", index=False) elif config.task == "regression": if config.ground_truth_label is None: - raise RuntimeError("Grounf truth label is required for regression") + raise RuntimeError("Ground truth label is required for regression") if isinstance(config.ground_truth_label, str): _to_regression_prediction_df( patient_to_ground_truth=cast( @@ -353,7 +353,7 @@ def categorical_crossval_( else: if config.ground_truth_label is None: raise RuntimeError( - "Grounf truth label is required for classification" + "Ground truth label is required for classification" ) _to_prediction_df( categories=categories_for_export, diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index bb295bcf..c3696c0f 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -1,6 +1,7 @@ """Helper classes to manage pytorch data.""" import logging +from collections import OrderedDict from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import KW_ONLY, dataclass from io import BytesIO # accept in _BinaryIOLike at runtime @@ -130,6 +131,8 @@ def tile_bag_dataloader( num_workers=num_workers, collate_fn=collate_fn, worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, + persistent_workers=(num_workers > 0), + pin_memory=torch.cuda.is_available(), ) return ( @@ -229,9 +232,7 @@ def _parse_targets( times.append(np.nan) events.append(np.nan) continue - - # Accept either structured tuple/list (time, event) or the legacy - # string form "time status". + # Expect a structured tuple/list (time, event). if isinstance(gt, (tuple, list)) and len(gt) == 2: t_val, e_val = gt times.append( @@ -241,9 +242,9 @@ def _parse_targets( ) events.append(float(e_val) if e_val is not None else np.nan) else: - time_str, status_str = str(gt).split(" ", 1) - times.append(np.nan if time_str.lower() == "nan" else float(time_str)) - events.append(_parse_survival_status(status_str)) + raise ValueError( + "survival ground truth must be a (time, event) tuple/list" + ) y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) return y, [] @@ -358,10 +359,8 @@ def create_dataloader( if task == "classification": raw = np.array([p.ground_truth for p in patient_data]) - categories_out = categories or list(np.unique(raw)) - labels = torch.tensor( - raw.reshape(-1, 1) == categories_out, dtype=torch.float32 - ) + categories = categories or list(np.unique(raw)) + labels = torch.tensor(raw.reshape(-1, 1) == categories, dtype=torch.float32) elif task == "regression": values: list[float] = [] for gt in (p.ground_truth for p in patient_data): @@ -381,8 +380,28 @@ def create_dataloader( raise ValueError( "Multi-target survival is not supported; provide a single survival time/status per patient" ) - t, e = (p.ground_truth or "nan nan").split(" ", 1) - times.append(float(t) if t.lower() != "nan" else np.nan) + gt = p.ground_truth + # Prefer structured (time, event) tuples. Do NOT call .split + # on the ground-truth value. If not a tuple/list, treat time + # as the whole value and event as unknown. + if isinstance(gt, (tuple, list)) and len(gt) == 2: + t, e = gt + elif gt is None: + t, e = None, None + else: + t, e = str(gt), "nan" + + # Parse time defensively + if t is None: + times.append(np.nan) + elif isinstance(t, str): + try: + times.append(np.nan if t.lower() == "nan" else float(t)) + except Exception: + times.append(np.nan) + else: + times.append(float(t)) + events.append(_parse_survival_status(e)) labels = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) @@ -396,6 +415,8 @@ def create_dataloader( shuffle=shuffle, num_workers=num_workers, worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, + persistent_workers=(num_workers > 0), + pin_memory=torch.cuda.is_available(), ) return dl, categories or [] else: @@ -545,6 +566,19 @@ def __post_init__(self) -> None: raise ValueError( "the number of ground truths has to match the number of bags" ) + # Initialise per-worker HDF5 handle cache here so __getitem__ avoids + # a hasattr() call on every tile read. + self._h5_handle_cache: OrderedDict[FeaturePath | _BinaryIOLike, h5py.File] = ( + OrderedDict() + ) + + def __getstate__(self) -> dict: + # h5py file handles cannot be pickled (required when DataLoader uses + # spawn-based multiprocessing). Drop the cache; each worker reopens + # files lazily on the first __getitem__ access. + state = self.__dict__.copy() + state["_h5_handle_cache"] = OrderedDict() + return state def __len__(self) -> int: return len(self.bags) @@ -556,24 +590,46 @@ def __getitem__( feats = [] coords_um = [] for bag_file in self.bags[index]: - with h5py.File(bag_file, "r") as h5: - if "feats" in h5: - feats_obj = h5["feats"] - if not isinstance(feats_obj, h5py.Dataset): - raise RuntimeError( - f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" - ) - arr = feats_obj[:] # original STAMP files - else: - embeddings_obj = h5["patch_embeddings"] - if not isinstance(embeddings_obj, h5py.Dataset): - raise RuntimeError( - f"expected 'patch_embeddings' to be an HDF5 dataset but got {type(embeddings_obj)}" - ) - arr = embeddings_obj[:] # your Kronos files + if bag_file not in self._h5_handle_cache: + # Limit open handles to avoid reaching OS ulimits + if len(self._h5_handle_cache) >= 128: + _, h = self._h5_handle_cache.popitem(last=False) + h.close() + + try: + # libver='latest' and swmr=True can provide better performance + # on some network/HPC filesystems + self._h5_handle_cache[bag_file] = h5py.File( + bag_file, "r", swmr=True, libver="latest" + ) + except Exception: + # Fallback for older HDF5 files or unconventional storage + self._h5_handle_cache[bag_file] = h5py.File(bag_file, "r") + else: + # Move recently accessed file to end (mark as recently used) + self._h5_handle_cache.move_to_end(bag_file) + + h5 = self._h5_handle_cache[bag_file] - feats.append(torch.from_numpy(arr)) - coords_um.append(torch.from_numpy(get_coords(h5).coords_um)) + if "feats" in h5: + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + arr = feats_obj[ + () + ] # uses [()] instead of [:] for clarity, both read entire dataset + else: + embeddings_obj = h5["patch_embeddings"] + if not isinstance(embeddings_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'patch_embeddings' to be an HDF5 dataset but got {type(embeddings_obj)}" + ) + arr = embeddings_obj[()] # your Kronos files + + feats.append(torch.from_numpy(arr)) + coords_um.append(torch.from_numpy(get_coords(h5).coords_um)) feats = torch.concat(feats).float() coords_um = torch.concat(coords_um).float() @@ -618,31 +674,53 @@ def __init__( self.feature_files = feature_files self.ground_truths = ground_truths self.transform = transform + # Initialise per-worker HDF5 handle cache eagerly so __getitem__ avoids + # a hasattr() call on every sample read. + self._h5_handle_cache: dict[FeaturePath | _BinaryIOLike, h5py.File] = {} + + def __getstate__(self) -> dict: + # h5py file handles cannot be pickled (required when DataLoader uses + # spawn-based multiprocessing). Drop the cache; each worker reopens + # files lazily on the first __getitem__ access. + state = self.__dict__.copy() + state["_h5_handle_cache"] = {} + return state def __len__(self): return len(self.feature_files) def __getitem__(self, idx: int): feature_file = self.feature_files[idx] - with h5py.File(feature_file, "r") as h5: - feats_obj = h5["feats"] - if not isinstance(feats_obj, h5py.Dataset): - raise RuntimeError( - f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + if feature_file not in self._h5_handle_cache: + if len(self._h5_handle_cache) >= 128: + _, h = self._h5_handle_cache.popitem() + h.close() + try: + self._h5_handle_cache[feature_file] = h5py.File( + feature_file, "r", swmr=True, libver="latest" ) - feats = torch.from_numpy(feats_obj[:]) - # Accept [V] or [1, V] - if feats.ndim == 2 and feats.shape[0] == 1: - feats = feats[0] - elif feats.ndim == 1: - pass - else: - raise RuntimeError( - f"Expected single feature vector (shape [F] or [1, F]), got {feats.shape} in {feature_file}." - "Check that the features are patient-level." - ) - if self.transform is not None: - feats = self.transform(feats) + except Exception: + self._h5_handle_cache[feature_file] = h5py.File(feature_file, "r") + + h5 = self._h5_handle_cache[feature_file] + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + feats = torch.from_numpy(feats_obj[()]) + # Accept [V] or [1, V] + if feats.ndim == 2 and feats.shape[0] == 1: + feats = feats[0] + elif feats.ndim == 1: + pass + else: + raise RuntimeError( + f"Expected single feature vector (shape [F] or [1, F]), got {feats.shape} in {feature_file}." + "Check that the features are patient-level." + ) + if self.transform is not None: + feats = self.transform(feats) label = self.ground_truths[idx] return feats, label @@ -1063,7 +1141,8 @@ def _log_patient_slide_feature_inconsistencies( if slides_without_features := { slide for slide in slide_to_patient.keys() if not slide.exists() }: - slides_list = sorted(str(s) for s in slides_without_features) + # Log only the filenames (not full paths) to keep warnings concise. + slides_list = sorted(s.name for s in slides_without_features) _logger.warning( "some feature files could not be found: %s", ", ".join(slides_list), @@ -1102,15 +1181,7 @@ def _parse_survival_status(value) -> int | None: None, NaN, '' -> None """ - # Handle missing inputs gracefully - # if value is None: - # return 0 # treat empty/missing as censored - # if isinstance(value, float) and math.isnan(value): - # return 0 # treat empty/missing as censored - s = str(value).strip().lower() - # if s in {"", "nan", "none"}: - # return 0 # treat empty/missing as censored # Known mappings positives = {"1", "event", "dead", "deceased", "yes", "y", "True", "true"} diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index c61a2512..e0444f15 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -280,6 +280,22 @@ def deploy_categorical_model_( Sequence[Category] | Mapping[str, Sequence[Category]] | None ) = cast(Sequence[Category] | Mapping[str, Sequence[Category]] | None, None) for model_i, model in enumerate(models): + # Check for data leakage: if the deployment patient set overlaps with + # the patients used during model training/validation, log a critical + # message. This check is intentionally performed at the deploy level + # (not inside `_predict`) so prediction helpers can be reused without + # side-effects in other contexts (e.g., cross-validation). + patients_used_for_training: set[PatientId] = set( + getattr(model, "train_patients", []) + ) | set(getattr(model, "valid_patients", [])) + if overlap := patients_used_for_training & set(patient_ids): + _logger.critical( + "DATA LEAKAGE DETECTED: %d patient(s) in deployment set were used " + "during training/validation. Overlapping IDs: %s", + len(overlap), + sorted(overlap), + ) + predictions = _predict( model=model, test_dl=test_dl, @@ -374,17 +390,7 @@ def _predict( model = model.eval() torch.set_float32_matmul_precision("medium") - # Check for data leakage - patients_used_for_training: set[PatientId] = set( - getattr(model, "train_patients", []) - ) | set(getattr(model, "valid_patients", [])) - if overlap := patients_used_for_training & set(patient_ids): - _logger.critical( - "DATA LEAKAGE DETECTED: %d patient(s) in deployment set were used " - "during training/validation. Overlapping IDs: %s", - len(overlap), - sorted(overlap), - ) + # Note: data-leakage check intentionally performed at deploy level. trainer = lightning.Trainer( accelerator=accelerator, @@ -659,21 +665,11 @@ def _to_survival_prediction_df( else: row["pred_score"] = pred.cpu().tolist() - # Ground truth: time + event - if gt is not None: - if isinstance(gt, str) and " " in gt: - time_str, status_str = gt.split(" ", 1) - row["time"] = float(time_str) if time_str.lower() != "nan" else None - if status_str.lower() in {"dead", "event", "1"}: - row["event"] = 1 - elif status_str.lower() in {"alive", "censored", "0"}: - row["event"] = 0 - else: - row["event"] = None - elif isinstance(gt, (tuple, list)) and len(gt) == 2: - row["time"], row["event"] = gt - else: - row["time"], row["event"] = None, None + # Ground truth: prefer structured tuple/list (time, event). Do not + # call .split on ground-truth values — assume structured input. If + # the value is not a 2-tuple/list, treat both fields as unknown. + if isinstance(gt, (tuple, list)) and len(gt) == 2: + row["time"], row["event"] = gt else: row["time"], row["event"] = None, None diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index b5a59b5f..217ef8b2 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -8,11 +8,11 @@ import lightning import numpy as np import torch -from lifelines.utils import concordance_index as lifelines_cindex # Use beartype.typing.Mapping to avoid PEP-585 deprecation warnings in beartype from beartype.typing import Mapping from jaxtyping import Bool, Float +from lifelines.utils import concordance_index as lifelines_cindex from packaging.version import Version from torch import Tensor, nn, optim from torchmetrics.classification import MulticlassAUROC @@ -89,7 +89,7 @@ def __init__( # This should only happen when the model is loaded, # otherwise the default value will make these checks pass. # TODO: Change this on version change - if stamp_version < Version("2.4.0"): + if stamp_version < Version("2.5.0"): # Update this as we change our model in incompatible ways! raise ValueError( f"model has been built with stamp version {stamp_version} " @@ -239,7 +239,7 @@ def forward( def _step( self, *, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], step_name: str, use_mask: bool, ) -> Loss: @@ -280,32 +280,36 @@ def _step( def training_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="training", use_mask=False) def validation_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="validation", use_mask=False) def test_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="test", use_mask=False) def predict_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Float[Tensor, "batch logit"]: bags, coords, bag_sizes, _ = batch # adding a mask here will *drastically* and *unbearably* increase memory usage + # Ensure input dtype matches model weights to avoid dtype-mismatch errors + param_dtype = next(self.model.parameters()).dtype + bags = bags.to(dtype=param_dtype) + coords = coords.to(dtype=param_dtype) return self.model(bags, coords=coords, mask=None) @@ -365,6 +369,9 @@ def predict_step( self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int ) -> Tensor: feats, _ = batch if isinstance(batch, tuple) else batch + # Cast inputs to model parameter dtype to avoid Half/Float mismatches + param_dtype = next(self.model.parameters()).dtype + feats = feats.to(dtype=param_dtype) return self.model(feats) @@ -437,7 +444,7 @@ def forward( def _step( self, *, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], step_name: str, use_mask: bool, ) -> Loss: @@ -477,28 +484,28 @@ def _step( def training_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="training", use_mask=False) def validation_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="validation", use_mask=False) def test_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="test", use_mask=False) def predict_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Float[Tensor, "batch 1"]: bags, coords, bag_sizes, _ = batch @@ -741,7 +748,11 @@ def forward( # (most ViT backbones accept coords/mask even if unused) return self.model(bags, coords=coords, mask=mask) - def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], + batch_idx: int, + ) -> Loss: bags, coords, bag_sizes, targets = batch preds = self.model(bags, coords=coords, mask=None) y = targets.to(preds.device, dtype=torch.float32) @@ -766,7 +777,7 @@ def training_step(self, batch, batch_idx): def validation_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Any: bags, coords, bag_sizes, targets = batch @@ -780,9 +791,13 @@ def validation_step( self._val_times.append(times.detach().cpu()) self._val_events.append(events.detach().cpu()) - def predict_step(self, batch, batch_idx): - feats, coords, n_tiles, survival_target = batch - return self.model(feats.float(), coords=coords, mask=None) + def predict_step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], + batch_idx: int, + ) -> Float[Tensor, "batch 1"]: + bags, coords, bag_sizes, survival_target = batch + return self.model(bags, coords=coords, mask=None) class LitSlideSurvival(LitSurvivalBase): @@ -858,6 +873,10 @@ def __init__( positional_encoding: bool = True, # Other hparams learning_rate: float = 1e-4, + # Deployment metadata (optional) — keep parity with `Base` + train_patients: Iterable[PatientId] = (), + valid_patients: Iterable[PatientId] = (), + stamp_version: Version = Version(stamp.__version__), **hparams: Any, ) -> None: weights_dict: dict[TargetLabel, torch.Tensor] = dict(category_weights) @@ -905,6 +924,13 @@ def __init__( self.ground_truth_label = ground_truth_label self.categories = normalized_categories + # Deployment metadata — mirror `Base` behavior so checkpoints include + # train/valid patient lists and stamp version for leak-detection and + # compatibility checks. + self.train_patients = train_patients + self.valid_patients = valid_patients + self.stamp_version = str(stamp_version) + self.save_hyperparameters() def forward(self, *args): diff --git a/src/stamp/modeling/models/barspoon.py b/src/stamp/modeling/models/barspoon.py index f841bb3d..92c1a15d 100644 --- a/src/stamp/modeling/models/barspoon.py +++ b/src/stamp/modeling/models/barspoon.py @@ -78,12 +78,12 @@ class tokens, one per output label. Finally, we forward each of the decoded 2. Adding absolute positions to the feature vector, scaled down so the maximum value in the training dataset is 1. - Since neither reduced performance and the author percieves the first one to + Since neither reduced performance and the author perceives the first one to be more elegant (as the magnitude of the positional encodings is bounded), we opted to keep the positional encoding regardless in the hopes of it improving performance on future tasks. - The architecture _differs_ from the one descibed in [Attention Is All You + The architecture _differs_ from the one described in [Attention Is All You Need][1] as follows: 1. There is an initial projection stage to reduce the dimension of the @@ -223,7 +223,7 @@ def __init__( _ = hparams # So we don't get unused parameter warnings # Check if version is compatible. - if stamp_version < Version("2.4.0"): + if stamp_version < Version("2.5.0"): # Update this as we change our model in incompatible ways! raise ValueError( f"model has been built with stamp version {stamp_version} " @@ -261,7 +261,7 @@ def __init__( def step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, dict[str, torch.Tensor]], + batch: tuple[Bags, CoordinatesBatch, BagSizes, dict[str, torch.Tensor]] | list, step_name=None, ): """Process a batch with structure (feats, coords, bag_sizes, targets). diff --git a/src/stamp/modeling/models/vision_tranformer.py b/src/stamp/modeling/models/vision_tranformer.py index b936c5c9..fcd60c12 100644 --- a/src/stamp/modeling/models/vision_tranformer.py +++ b/src/stamp/modeling/models/vision_tranformer.py @@ -56,9 +56,7 @@ def forward( Which query-key pairs to mask from ALiBi (i.e. don't apply ALiBi to). """ weight_logits = torch.einsum("bqf,bkf->bqk", q, k) * (k.size(-1) ** -0.5) - distances = torch.linalg.norm( - coords_q.unsqueeze(2) - coords_k.unsqueeze(1), dim=-1 - ) + distances = torch.cdist(coords_q, coords_k) scaled_distances = self.scale_distance(distances) * self.bias_scale if alibi_mask is not None: diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index 7031eed6..a55ec1e3 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -177,13 +177,13 @@ def setup_model_for_training( f"Model '{advanced.model_name.value}' does not support feature type '{feature_type}'. " f"Supported types are: {LitModelClass.supported_features}" ) - elif ( - feature_type in ("slide", "patient") - and advanced.model_name.value.lower() != "mlp" - ): + elif feature_type in ( + "slide", + "patient", + ) and advanced.model_name.value.lower() not in {"mlp", "linear"}: raise ValueError( - f"Feature type '{feature_type}' only supports MLP backbones. " - f"Got '{advanced.model_name.value}'. Please set model_name='mlp'." + f"Feature type '{feature_type}' only supports MLP or Linear. " + f"Got '{advanced.model_name.value}'. Please set model_name='mlp' or 'linear'." ) # 4. Get model-specific hyperparameters diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index 84cb48e9..b2daa386 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -85,16 +85,9 @@ def __init__( self.canny_cutoff = canny_cutoff self.default_slide_mpp = default_slide_mpp - # Already check if we can extract the MPP here. - # We don't want to kill our dataloader later, - # because that leads to _a lot_ of error messages which are difficult to read - if ( - get_slide_mpp_( - openslide.open_slide(slide_path), default_mpp=default_slide_mpp - ) - is None - ): - raise MPPExtractionError() + # MPP is validated by the caller (extract_()) before constructing this dataset, + # so we no longer open the slide here for a redundant MPP check. + # This removes one openslide.open_slide() call per WSI. def __iter__(self) -> Iterator[tuple[Tensor, Microns, Microns]]: return ( @@ -177,6 +170,11 @@ def extract_( extractor = dino_bloom() + case ExtractorName.RED_DINO: + from stamp.preprocessing.extractor.reddino import red_dino + + extractor = red_dino() + case ExtractorName.VIRCHOW: from stamp.preprocessing.extractor.virchow import virchow @@ -222,6 +220,15 @@ def extract_( extractor = plip() + case ExtractorName.KEEP: + from stamp.preprocessing.extractor.keep import keep + + extractor = keep() + case ExtractorName.TICON: + from stamp.preprocessing.extractor.ticon import ticon + + extractor = ticon() + case ExtractorName.EMPTY: from stamp.preprocessing.extractor.empty import empty @@ -281,6 +288,15 @@ def extract_( feature_output_path.parent.mkdir(parents=True, exist_ok=True) try: + # Validate MPP here once (avoids a second openslide.open_slide inside _TileDataset.__init__). + if ( + get_slide_mpp_( + openslide.open_slide(slide_path), default_mpp=default_slide_mpp + ) + is None + ): + raise MPPExtractionError() + ds = _TileDataset( slide_path=slide_path, cache_dir=cache_dir, @@ -295,7 +311,15 @@ def extract_( default_slide_mpp=default_slide_mpp, ) # Parallelism is implemented in the dataset iterator already, so one worker is enough! - dl = DataLoader(ds, batch_size=64, num_workers=1, drop_last=False) + # pin_memory speeds up CPU→GPU DMA for tile batches. + # num_workers=1 is intentional: WSI read parallelism is inside _supertiles. + dl = DataLoader( + ds, + batch_size=64, + num_workers=1, + drop_last=False, + pin_memory=torch.cuda.is_available(), + ) feats, xs_um, ys_um = [], [], [] for tiles, xs, ys in tqdm(dl, leave=False): @@ -379,8 +403,9 @@ def _get_rejection_thumb( dtype=bool, ) - for y, x in np.floor(coords_um / tile_size_um).astype(np.uint32): - inclusion_map[y, x] = True + # Vectorized: set all tile positions at once instead of a Python loop. + tile_indices = np.floor(coords_um / tile_size_um).astype(np.uint32) + inclusion_map[tile_indices[:, 0], tile_indices[:, 1]] = True thumb = slide.get_thumbnail(size).convert("RGBA") discarded_im = Image.fromarray( diff --git a/src/stamp/preprocessing/config.py b/src/stamp/preprocessing/config.py index 244d70dd..e017daf8 100644 --- a/src/stamp/preprocessing/config.py +++ b/src/stamp/preprocessing/config.py @@ -1,7 +1,6 @@ from enum import StrEnum from pathlib import Path -import torch from pydantic import BaseModel, ConfigDict, Field from stamp.types import ImageExtension, Microns, SlideMPP, TilePixels @@ -28,7 +27,10 @@ class ExtractorName(StrEnum): MUSK = "musk" MSTAR = "mstar" PLIP = "plip" + KEEP = "keep" + TICON = "ticon" EMPTY = "empty" + RED_DINO = "red-dino" class PreprocessingConfig(BaseModel, arbitrary_types_allowed=True): @@ -45,7 +47,11 @@ class PreprocessingConfig(BaseModel, arbitrary_types_allowed=True): tile_size_px: TilePixels = TilePixels(224) extractor: ExtractorName max_workers: int = 8 - device: str = "cuda" if torch.cuda.is_available() else "cpu" + device: str = Field( + default_factory=lambda: ( + "cuda" if __import__("torch").cuda.is_available() else "cpu" + ) + ) generate_hash: bool = True default_slide_mpp: SlideMPP | None = None diff --git a/src/stamp/preprocessing/extractor/keep.py b/src/stamp/preprocessing/extractor/keep.py new file mode 100644 index 00000000..4adc964e --- /dev/null +++ b/src/stamp/preprocessing/extractor/keep.py @@ -0,0 +1,49 @@ +""" +Adopted from https://github.com/MAGIC-AI4Med/KEEP +KEEP (KnowledgE-Enhanced Pathology) +""" + +try: + import torch + from torchvision import transforms + from transformers import AutoModel +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "keep dependencies not installed." + " Please reinstall stamp using `pip install 'stamp[keep]'`" + ) from e + +from stamp.preprocessing.config import ExtractorName +from stamp.preprocessing.extractor import Extractor + + +class KEEPWrapper(torch.nn.Module): + def __init__(self, model) -> None: + super().__init__() + self.model = model + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + return self.model.encode_image(batch) + + +def keep() -> Extractor[KEEPWrapper]: + """Extracts features from slide tiles using the KEEP tile encoder.""" + model = AutoModel.from_pretrained("Astaxanthin/KEEP", trust_remote_code=True) + model.eval() + + transform = transforms.Compose( + [ + transforms.Resize( + size=224, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.CenterCrop(size=(224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] + ) + + return Extractor( + model=KEEPWrapper(model), + transform=transform, + identifier=ExtractorName.KEEP, + ) diff --git a/src/stamp/preprocessing/extractor/reddino.py b/src/stamp/preprocessing/extractor/reddino.py new file mode 100644 index 00000000..b370ea2d --- /dev/null +++ b/src/stamp/preprocessing/extractor/reddino.py @@ -0,0 +1,64 @@ +""" +Port from https://github.com/Snarci/RedDino +RedDino: A Foundation Model for Red Blood Cell Analysis +""" + +from typing import Callable, cast + +try: + import timm + import torch + from PIL import Image + from torchvision import transforms +except ModuleNotFoundError as e: + raise ModuleNotFoundError("red-dino dependencies not installed.") from e + +from stamp.preprocessing.config import ExtractorName +from stamp.preprocessing.extractor import Extractor + +__license__ = "MIT" + + +class RedDinoClsOnly(torch.nn.Module): + def __init__(self, model) -> None: + super().__init__() + self.model = model + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + out = self.model(batch) + if isinstance(out, tuple): + out = out[0] + # if model returns tokens, return class token + if getattr(out, "ndim", 0) >= 2 and out.shape[1] > 1: + return out[:, 0] + return out + + +def red_dino() -> Extractor[RedDinoClsOnly]: + """Extracts features from single image using RedDino encoder.""" + + model = timm.create_model( + "hf-hub:Snarcy/RedDino-large", + pretrained=True, + num_classes=0, + pretrained_strict=False, + ) + + transform = cast( + Callable[[Image.Image], torch.Tensor], + transforms.Compose( + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ), + ) + + return Extractor( + model=RedDinoClsOnly(model), + transform=transform, + identifier=ExtractorName.RED_DINO, + ) diff --git a/src/stamp/preprocessing/extractor/ticon.py b/src/stamp/preprocessing/extractor/ticon.py new file mode 100644 index 00000000..02aac13e --- /dev/null +++ b/src/stamp/preprocessing/extractor/ticon.py @@ -0,0 +1,739 @@ +""" +This file contains code adapted from: +TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning +https://github.com/cvlab-stonybrook/TICON +""" + +import math +from collections.abc import Callable, Mapping +from functools import partial +from typing import Any + +import timm +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from jaxtyping import Float +from torch import Tensor +from torchvision import transforms + +from stamp.preprocessing.extractor import Extractor + +try: + import timm + from torchvision import transforms +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "h_optimus_1 dependencies not installed." + " Please reinstall stamp using `pip install 'stamp[h_optimus_1]'`" + ) from e + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.init_values = init_values + self.gamma = nn.Parameter(torch.empty(dim)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.gamma, self.init_values) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + mlp_ratio: int | float | None = (16 / 3), + bias: bool = True, + ) -> None: + super().__init__() + if hidden_features is None: + assert mlp_ratio is not None + hidden_features = int(in_features * mlp_ratio) + else: + assert mlp_ratio is None + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features // 2, in_features, bias=bias) + + def forward(self, x: Float[Tensor, "*b d"]) -> Float[Tensor, "*b d"]: + x = self.fc1(x) + x1, x2 = x.chunk(2, dim=-1) + x = self.act(x1) * x2 + x = self.fc2(x) + return x + + +class ProjectionMlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int, + bias: bool = True, + ) -> None: + super().__init__() + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.norm = nn.LayerNorm(out_features) + + def forward(self, x: Float[Tensor, "*b d"]) -> Float[Tensor, "*b d"]: + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.norm(x) + return x + + +def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2( + n + ) # In the paper, we only train models that have 2^a heads for some a. This function has + else: # some good properties that only occur when the input is a power of 2. To maintain that even + closest_power_of_2 = 2 ** math.floor( + math.log2(n) + ) # when the number of heads is not a power of 2, we use this workaround. + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + +def scaled_dot_product_attention_custom( + query, + key, + value, + attn_bias=None, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, +) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + # attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) # pyright: ignore[reportOptionalMemberAccess] + attn_bias.to(query.dtype) # pyright: ignore[reportOptionalMemberAccess] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) # pyright: ignore[reportOptionalMemberAccess] + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + context_dim: int | None = None, + # rope_kwargs: Mapping = {}, + ) -> None: + super().__init__() + self.num_heads = num_heads + context_dim = context_dim or dim + + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(context_dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(context_dim, dim, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + # self.rope = Rope(dim=head_dim, **rope_kwargs) + slopes = torch.Tensor(get_slopes(num_heads)) + self.slopes = slopes[ + None, :, None, None + ] # einops.rearrange(slopes, 'b -> 1 b 1 1') + + def forward( + self, + x: Float[Tensor, "b n_q d"], + coords: Float[Tensor, "b n_q 2"], + context: Float[Tensor, "b n_k d_k"] | None = None, + context_coords: Float[Tensor, "b n_k 2"] | None = None, + ) -> Float[Tensor, "b n_q d"]: + if context is None or context_coords is None: + context = x + context_coords = coords + b, n_q, d = x.shape + b, n_k, _ = context.shape + h = self.num_heads + + q = self.q_proj(x).reshape(b, n_q, h, d // h).transpose(1, 2) + k = self.k_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) + v = self.v_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) + + corrds_expanded = coords.unsqueeze(2).expand( + -1, -1, n_k, -1 + ) # (b, m, d) -> (b, m, 1, d) -> (b, m, n, d) + context_coords_expanded = context_coords.unsqueeze(1).expand(-1, n_q, -1, -1) + euclid_dist = torch.sqrt( + torch.sum((corrds_expanded - context_coords_expanded) ** 2, dim=-1) + ) + self.slopes = self.slopes.to(x.device) + attn_bias = (-1) * self.slopes * euclid_dist[:, None, :, :] + + # x = F.scaled_dot_product_attention(q, k, v) + x = scaled_dot_product_attention_custom(q, k, v, attn_bias=attn_bias) + x = x.transpose(1, 2).reshape([b, n_q, d]) + x = self.proj(x) + return x + + +class NaiveResidual(nn.Module): + def __init__( + self, + drop_prob: float | int, + norm: nn.Module, + fn: nn.Module, + gamma: nn.Parameter, + ): + super().__init__() + self.norm = norm + self.fn = fn + self.keep_prob = 1 - drop_prob + self.gamma = gamma + + def forward( + self, + x: Float[Tensor, "b n d"], + **kwargs: Float[Tensor, "b ..."] | None, + ) -> Float[Tensor, "b n d"]: + fn_out = self.fn(self.norm(x), **kwargs) + if self.gamma is not None: + if self.keep_prob == 1.0 or not self.training: + return x + self.gamma * fn_out + mask = fn_out.new_empty(x.shape[0]).bernoulli_(self.keep_prob)[ + :, None, None + ] + return x + self.gamma * fn_out * mask / self.keep_prob + else: + if self.keep_prob == 1.0 or not self.training: + return x + fn_out + mask = fn_out.new_empty(x.shape[0]).bernoulli_(self.keep_prob)[ + :, None, None + ] + return x + fn_out * mask / self.keep_prob + + +class EfficientResidual(NaiveResidual): + def forward( + self, + x: Float[Tensor, "b n d"], + **kwargs: Float[Tensor, "b ..."] | None, + ) -> Float[Tensor, "b n d"]: + if self.keep_prob == 1.0 or not self.training: + if self.gamma is not None: + return x + self.gamma * self.fn(self.norm(x), **kwargs) + else: + return x + self.fn(self.norm(x), **kwargs) + + b, _, _ = x.shape + n_keep = max(int(b * self.keep_prob), 1) + indices = torch.randperm(b, device=x.device)[:n_keep] + for k, v in kwargs.items(): + if v is not None: + kwargs[k] = v[indices] + if self.gamma is not None: + return torch.index_add( + x, + dim=0, + source=self.gamma * self.fn(self.norm(x[indices]), **kwargs), + index=indices, + alpha=b / n_keep, + ) + else: + return torch.index_add( + x, + dim=0, + source=self.fn(self.norm(x[indices]), **kwargs), + index=indices, + alpha=b / n_keep, + ) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + drop_path: float | int, + norm_layer: Callable[[int], nn.Module], + context_dim: int | None, + drop_path_type: str = "efficient", + layer_scale: int = True, + attn_kwargs: Mapping = {}, + ) -> None: + super().__init__() + residual_module = { + "naive": NaiveResidual, + "efficient": EfficientResidual, + }[drop_path_type] + + self.layer_scale = layer_scale + if layer_scale: + gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True) + gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True) + else: + gamma1 = None + gamma2 = None + + self.residual1 = residual_module( + drop_path, + norm_layer(dim), + Attention( + dim, + context_dim=context_dim, + **attn_kwargs, + ), + gamma1, + ) + self.residual2 = residual_module( + drop_path, norm_layer(dim), Mlp(in_features=dim), gamma2 + ) + + def forward( + self, + x: Float[Tensor, "b n d"], + context: Float[Tensor, "b n_k d_k"] | None = None, + coords: Float[Tensor, "b n 2"] | None = None, + context_coords: Float[Tensor, "b n_k 2"] | None = None, + ) -> Float[Tensor, "b n d"]: + x = self.residual1( + x, + context=context, + coords=coords, + context_coords=context_coords, + ) + x = self.residual2(x) + return x + + +class Transformer(nn.Module): + def __init__( + self, + embed_dim: int, + norm_layer: Callable[[int], nn.Module], + depth: int, + drop_path_rate: float | int, + context_dim: int | None = None, + block_kwargs: Mapping[str, Any] = {}, + ): + super().__init__() + self.embed_dim = embed_dim + self.n_blocks = depth + + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + drop_path=drop_path_rate, + norm_layer=norm_layer, + context_dim=context_dim, + **block_kwargs, + ) + for i in range(depth) + ], + ) + + def forward( + self, + x: Float[Tensor, "b n d"], + return_layers: set[int], + contexts: list[Float[Tensor, "b n_k d_k"]] | None = None, + coords: Float[Tensor, "b n 2"] | None = None, + context_coords: Float[Tensor, "b n_k 2"] | None = None, + ) -> dict[int, Float[Tensor, "b n d"]]: + outputs = {} + if 0 in return_layers: + outputs[0] = x + for blk_idx, blk in enumerate(self.blocks): + context = contexts[blk_idx] if contexts is not None else None + x = blk( + x, + context=context, + coords=coords, + context_coords=context_coords, + ) + if blk_idx + 1 in return_layers: + outputs[blk_idx + 1] = x + return outputs + + +class EncoderDecoder(nn.Module): + def __init__( + self, + patch_size: int = 14, + in_dims: list = [], + tile_encoder_keys: list = [], + norm_layer_type: str = "LayerNorm", + transformers_kwargs: Mapping[str, Any] = {}, + encoder_kwargs: Mapping[str, Any] = {}, + decoder_kwargs: Mapping[str, Any] = {}, + norm_layer_kwargs: Mapping[str, Any] = {"eps": 1e-5}, + final_norm_kwargs: Mapping[str, Any] = {"elementwise_affine": True}, + out_layer: int = -1, + num_decoders: int = 1, + decoder_out_dims: list = [], + ): + super().__init__() + self.patch_size = patch_size + + norm_layer: Callable[[int], nn.Module] = partial( + getattr(torch.nn, norm_layer_type), **norm_layer_kwargs + ) + + self.encoder = Transformer( + **transformers_kwargs, + **encoder_kwargs, + norm_layer=norm_layer, + ) + + self.tile_encoder_keys = tile_encoder_keys + self.embed_dim = self.encoder.embed_dim + self.n_blocks = len(self.encoder.blocks) + self.out_layer = out_layer % (len(self.encoder.blocks) + 1) + self.enc_norm = norm_layer(self.embed_dim, **final_norm_kwargs) + self.num_decoders = num_decoders + self.decoder_out_dims = decoder_out_dims + + self.decoder_dict = nn.ModuleDict({}) + self.mask_dict = nn.ParameterDict({}) + self.input_proj_dict = nn.ModuleDict({}) + self.output_proj_dict = nn.ModuleDict({}) + + for i in range(len(in_dims)): + self.input_proj_dict[f"input_proj_{self.tile_encoder_keys[i]}"] = ( + ProjectionMlp( + in_features=in_dims[i], + hidden_features=self.encoder.embed_dim, + out_features=self.encoder.embed_dim, + ) + ) + + for i in range(self.num_decoders): + self.decoder_dict[f"decoder_{i}"] = nn.ModuleDict({}) + self.decoder_dict[f"decoder_{i}"]["transformer"] = Transformer( # pyright: ignore[reportIndexIssue] + **transformers_kwargs, + **decoder_kwargs, + context_dim=self.encoder.embed_dim, + norm_layer=norm_layer, + ) + + self.decoder_dict[f"decoder_{i}"]["norm"] = norm_layer( # pyright: ignore[reportIndexIssue] + self.decoder_dict[f"decoder_{i}"]["transformer"].embed_dim, # pyright: ignore[reportIndexIssue] + **final_norm_kwargs, + ) + self.mask_dict[f"mask_token_{i}"] = nn.Parameter( + torch.empty( + 1, + self.decoder_dict[f"decoder_{i}"]["transformer"].embed_dim, # pyright: ignore[reportIndexIssue] + ) + ) + + for i in range(len(self.decoder_out_dims)): + self.output_proj_dict[f"output_proj_{self.tile_encoder_keys[i]}"] = ( + ProjectionMlp( + in_features=self.encoder.embed_dim, + hidden_features=self.encoder.embed_dim, + out_features=self.decoder_out_dims[i], + ) + ) + + assert self.num_decoders <= 1 + + def init_weights(self): + for mask_key in self.mask_dict.keys(): + nn.init.normal_(self.mask_dict[mask_key], std=0.02) + self.apply(_init_weights) + return self + + def forward_features( + self, + x: Float[Tensor, "b n d"], + relative_coords: Float[Tensor, "b n 2"] | None, + predict_coords: Float[Tensor, "b n 2"] | None, + enc_layer: int, + dec_layer: int | None, + tile_encoder_key: str | None, + ) -> tuple[Float[Tensor, "b n d"], dict | None]: + b, _, _ = x.shape + + # these are the layers we need + enc_layers = {enc_layer} + if dec_layer is not None: + enc_layers.add(len(self.encoder.blocks)) + + # encoder fwd + coords_enc = relative_coords + coords_dec = predict_coords + x = self.input_proj_dict[f"input_proj_{tile_encoder_key}"](x) + encoder_outputs = self.encoder(x, coords=coords_enc, return_layers=enc_layers) + encoder_outputs = {k: self.enc_norm(v) for k, v in encoder_outputs.items()} + + # decoder fwd + if dec_layer is not None: + dec_final_output = {} + assert self.num_decoders == 1 + for dec_index in range(self.num_decoders): + decoder_outputs = self.decoder_dict[ + f"decoder_{dec_index}" + ][ # pyright: ignore[reportIndexIssue] + "transformer" + ]( + self.mask_dict[f"mask_token_{dec_index}"][None].expand( + *coords_dec.shape[:2], # pyright: ignore[reportOptionalMemberAccess] + -1, # pyright: ignore[reportOptionalMemberAccess] + ), + contexts=[encoder_outputs[len(self.encoder.blocks)]] + * self.decoder_dict[f"decoder_{dec_index}"]["transformer"].n_blocks, # pyright: ignore[reportIndexIssue] + coords=coords_dec, + context_coords=coords_enc, + return_layers={dec_layer}, + ) + dec_output = self.decoder_dict[f"decoder_{dec_index}"]["norm"]( # pyright: ignore[reportIndexIssue] + decoder_outputs[dec_layer] + ) + + for out_index in range(len(self.decoder_out_dims)): + dec_final_output[self.tile_encoder_keys[out_index]] = ( + self.output_proj_dict[ + f"output_proj_{self.tile_encoder_keys[out_index]}" + ](dec_output) + ) + else: + dec_final_output = None + enc_output = encoder_outputs[enc_layer] + return (enc_output, dec_final_output) + + def forward( + self, + x: Float[Tensor, "b n d"], + relative_coords: Float[Tensor, "b n 2"] | None = None, + tile_encoder_key: str | None = None, + ) -> Float[Tensor, "b n d"]: + # print("Input feature range", torch.min(x), torch.max(x)) + # print("Input coords range", torch.min(relative_coords), torch.max(relative_coords)) + enc_output, dec_output = self.forward_features( + x, + relative_coords=relative_coords, + predict_coords=None, + enc_layer=self.out_layer, + dec_layer=None, + tile_encoder_key=tile_encoder_key, + ) + + # print(torch.min(enc_output), torch.max(enc_output)) + return enc_output + + +# from https://github.com/facebookresearch/mae/blob/main/models_mae.py +def _init_weights(m: nn.Module, xavier_gain=1) -> None: + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight, gain=xavier_gain) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm | nn.RMSNorm) and m.elementwise_affine: + nn.init.constant_(m.weight, 1.0) + if hasattr(m, "bias") and m.bias is not None: + nn.init.constant_(m.bias, 0) # pyright: ignore[reportArgumentType] + if hasattr(m, "_device_weight_init"): + m._device_weight_init() # pyright: ignore[reportCallIssue] + + +def load_ticon(device: str = "cuda") -> nn.Module: + model_cfg = { + "transformers_kwargs": { + "embed_dim": 1536, + "drop_path_rate": 0.0, + "block_kwargs": { + "attn_kwargs": {"num_heads": 24}, + }, + }, + "encoder_kwargs": {"depth": 6}, + "decoder_kwargs": {"depth": 1}, + "in_dims": [768, 1536, 1536, 1536, 1280], + "tile_encoder_keys": [ + "conchv15", + "hoptimus1", + "uni2h", + "gigapath", + "virchow2", + ], + "num_decoders": 1, + "decoder_out_dims": [768, 1536, 1536, 1536, 1280], + } + + ckpt = hf_hub_download( + repo_id="varunb/TICON", + filename="backbone/checkpoint.pth", + repo_type="model", + ) + + with torch.device("meta"): + model = EncoderDecoder(**model_cfg) + + model.to_empty(device=device) + model.init_weights() + + sd = torch.load(ckpt, map_location="cpu", weights_only=True) + sd = { + k.removeprefix("backbone."): v + for k, v in sd.items() + if k.startswith("backbone.") + } + + model.load_state_dict(sd, strict=False) + model.eval() + return model + + +class HOptimusTICON(nn.Module): + def __init__(self, device: torch.device): + super().__init__() + + # ---------------------------- + # Stage 1: H-OptimUS + # ---------------------------- + self.tile_encoder = timm.create_model( + "hf-hub:bioptimus/H-optimus-1", + pretrained=True, + init_values=1e-5, + dynamic_img_size=False, + num_classes=0, + pretrained_strict=False, + ) + + # ---------------------------- + # Stage 2: TICON + # ---------------------------- + ticon_cfg = { + "transformers_kwargs": { + "embed_dim": 1536, + "drop_path_rate": 0.0, + "block_kwargs": { + "attn_kwargs": {"num_heads": 24}, + }, + }, + "encoder_kwargs": {"depth": 6}, + "decoder_kwargs": {"depth": 1}, + "in_dims": [768, 1536, 1536, 1536, 1280], + "tile_encoder_keys": [ + "conchv15", + "hoptimus1", + "uni2h", + "gigapath", + "virchow2", + ], + "num_decoders": 1, + "decoder_out_dims": [768, 1536, 1536, 1536, 1280], + } + + with torch.device("meta"): + self.ticon = EncoderDecoder(**ticon_cfg) + + self.ticon.to_empty(device=device) + self.ticon.init_weights() + + ckpt = hf_hub_download( + repo_id="varunb/TICON", + filename="backbone/checkpoint.pth", + repo_type="model", + ) + + sd = torch.load(ckpt, map_location="cpu", weights_only=True) + sd = { + k.removeprefix("backbone."): v + for k, v in sd.items() + if k.startswith("backbone.") + } + self.ticon.load_state_dict(sd, strict=False) + + self.to(device) + self.eval() + + @torch.inference_mode() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: [B, 3, 224, 224] (CPU or CUDA) + """ + # Respect the current module device (it may be moved after construction). + device = next(self.parameters()).device + x = x.to(device, non_blocking=True) + + # H-Optimus_1 + emb = self.tile_encoder(x) # [B, 1536] + emb = emb.unsqueeze(1) # [B, 1, 1536] + # TICON + # single-tile → zero relative coords + coords = torch.zeros( + emb.size(0), + 1, + 2, + device=device, + dtype=torch.float32, + ) + + out = self.ticon( + x=emb, + relative_coords=coords, + tile_encoder_key="hoptimus1", + ) + + return out.squeeze(1) # [B, 1536] + + +def ticon(device: str | None = None) -> Extractor[nn.Module]: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + model = HOptimusTICON(torch.device(device)) + + transform = transforms.Compose( + [ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.707223, 0.578729, 0.703617), + std=(0.211883, 0.230117, 0.177517), + ), + ] + ) + + return Extractor( + model=model, + transform=transform, + identifier="ticon", + ) diff --git a/src/stamp/preprocessing/tiling.py b/src/stamp/preprocessing/tiling.py index ce684ba4..1f560981 100644 --- a/src/stamp/preprocessing/tiling.py +++ b/src/stamp/preprocessing/tiling.py @@ -397,8 +397,10 @@ def _tiles_from_cache_file(cache_file_path: Path) -> Iterator[_Tile]: x_um, y_um = Microns(float(x_um_str)), Microns(float(y_um_str)) with zip_fp.open(name, "r") as tile_fp: + img = Image.open(tile_fp) + img.load() # force eager pixel decode while tile_fp is still open yield _Tile( - image=Image.open(tile_fp), + image=img, coordinates=_XYCoords(x_um, y_um), size=tiler_params["tile_size_um"], ) diff --git a/src/stamp/statistics/categorical.py b/src/stamp/statistics/categorical.py index 9d6c4c12..e19b1659 100755 --- a/src/stamp/statistics/categorical.py +++ b/src/stamp/statistics/categorical.py @@ -53,7 +53,9 @@ def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: """ categories = preds_df[target_label].unique() y_true = preds_df[target_label] - y_pred = preds_df[[f"{target_label}_{cat}" for cat in categories]].map(float).values + y_pred = ( + preds_df[[f"{target_label}_{cat}" for cat in categories]].astype(float).values + ) stats_df = pd.DataFrame(index=categories) @@ -156,15 +158,19 @@ def categorical_aggregated_multitarget_( all_target_stats = {} + # Read each CSV once and cache so we don't re-read N×M times. + csv_cache: dict[str, pd.DataFrame] = { + Path(p).parent.name: pd.read_csv(p, dtype=str) for p in preds_csvs + } + for target_label in target_labels: # Process each target separately preds_dfs = {} - for p in preds_csvs: - df = pd.read_csv(p, dtype=str) + for fold_name, df in csv_cache.items(): # Drop rows where this target's ground truth is missing df_clean = df.dropna(subset=[target_label]) if len(df_clean) > 0: - preds_dfs[Path(p).parent.name] = _categorical(df_clean, target_label) + preds_dfs[fold_name] = _categorical(df_clean, target_label) if not preds_dfs: continue diff --git a/src/stamp/statistics/prc.py b/src/stamp/statistics/prc.py index 867885e9..dd58ea2e 100755 --- a/src/stamp/statistics/prc.py +++ b/src/stamp/statistics/prc.py @@ -6,7 +6,6 @@ import scipy.stats as st from jaxtyping import Bool, Float from matplotlib.axes import Axes -from scipy.interpolate import interp1d from sklearn.metrics import ( auc, average_precision_score, @@ -56,15 +55,9 @@ def _plot_bootstrapped_pr_curve( continue precision, recall, _ = precision_recall_curve(sample_y_true, sample_y_pred) - # Create an interpolation function with decreasing values - interp_func = interp1d( - recall[::-1], - precision[::-1], - kind="linear", - fill_value=np.nan, - bounds_error=False, - ) - interp_prc = interp_func(interp_recall) + # np.interp requires increasing x; precision_recall_curve returns + # decreasing recall, so reverse both arrays. + interp_prc = np.interp(interp_recall, recall[::-1], precision[::-1]) interp_prcs[i] = interp_prc bootstrapped_auprc = auc(interp_recall, interp_prc) bootstrap_auprcs.append(bootstrapped_auprc) diff --git a/src/stamp/statistics/roc.py b/src/stamp/statistics/roc.py index d42413a4..338cf876 100755 --- a/src/stamp/statistics/roc.py +++ b/src/stamp/statistics/roc.py @@ -180,9 +180,11 @@ def _plot_bootstrapped_roc_curve( # and then sample the bottom 0.025 / top 0.975 quantile point # for each sampled fpr-position rng = np.random.default_rng() - interp_rocs = [] interp_fpr = np.linspace(0, 1, num=1000) + # Pre-allocate; rows that correspond to skipped samples stay NaN. + interp_rocs = np.full((n_bootstrap_samples, len(interp_fpr)), np.nan) bootstrap_aucs: list[float] = [] + valid_row = 0 for _ in trange(n_bootstrap_samples, desc="Bootstrapping ROC curves", leave=False): sample_idxs = rng.choice(len(y_true), len(y_true)) sample_y_true = y_true[sample_idxs] @@ -190,15 +192,17 @@ def _plot_bootstrapped_roc_curve( if len(np.unique(sample_y_true)) != 2: continue fpr, tpr, thresh = roc_curve(sample_y_true, sample_y_score) - interp_rocs.append(np.interp(interp_fpr, fpr, tpr)) + interp_rocs[valid_row] = np.interp(interp_fpr, fpr, tpr) + valid_row += 1 bootstrap_aucs.append(float(roc_auc_score(sample_y_true, sample_y_score))) + interp_rocs = interp_rocs[:valid_row] # trim unused rows roc_lower, roc_upper = cast( tuple[ Float[np.ndarray, "fpr"], # noqa: F821 Float[np.ndarray, "fpr"], # noqa: F821 ], - np.quantile(interp_rocs, [0.025, 0.975], axis=0), + np.nanquantile(interp_rocs, [0.025, 0.975], axis=0), ) ax.fill_between(interp_fpr, roc_lower, roc_upper, alpha=0.5) diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index 7c298a54..8fbf5d63 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -4,11 +4,11 @@ from pathlib import Path +import lifelines.plotting as lifelines_plotting import matplotlib.pyplot as plt import numpy as np import pandas as pd from lifelines import KaplanMeierFitter -from lifelines.plotting import add_at_risk_counts from lifelines.statistics import logrank_test from lifelines.utils import concordance_index @@ -24,7 +24,7 @@ def _comparable_pairs_count(times: np.ndarray, events: np.ndarray) -> int: def _cindex( time: np.ndarray, event: np.ndarray, - risk: np.ndarray, # will be flipped in function + risk: np.ndarray, ) -> tuple[float, int]: """Compute C-index using Lifelines convention: higher risk → shorter survival (worse outcome). @@ -40,7 +40,7 @@ def _survival_stats_for_csv( time_label: str, status_label: str, risk_label: str | None = None, - cut_off: float | None = None, # will be flipped in function + cut_off: float | None = None, ) -> pd.Series: """Compute C-index and log-rank p for one CSV.""" if risk_label is None: @@ -136,7 +136,16 @@ def _plot_km( ) kmf_high.plot_survival_function(ax=ax, ci_show=False, color="red") - add_at_risk_counts(kmf_low, kmf_high, ax=ax) + # add at-risk counts only for fitted curves + fitters = [] + if len(low_df) > 0: + fitters.append(kmf_low) + if len(high_df) > 0: + fitters.append(kmf_high) + + # add at-risk table for fitted curves + if len(fitters) > 0: + lifelines_plotting.add_at_risk_counts(*fitters, ax=ax) # log-rank and c-index res = logrank_test( diff --git a/src/stamp/utils/cache.py b/src/stamp/utils/cache.py index d65c7dcd..c0b00a51 100644 --- a/src/stamp/utils/cache.py +++ b/src/stamp/utils/cache.py @@ -9,28 +9,28 @@ STAMP_CACHE_DIR: Final[Path] = ( Path(os.environ.get("XDG_CACHE_HOME") or (Path.home() / ".cache")) / "stamp" ) - -# If we imported this, we probably want to use it, -# so it's okay creating the directory now -STAMP_CACHE_DIR.mkdir(exist_ok=True, parents=True) +# Directory is created on demand (inside functions that write to it) +# so that a bare import of this module does not cause filesystem I/O. def download_file(*, url: str, file_name: str, sha256sum: str) -> Path: - """Downloads a file, or loads it from cache if it has been downloaded before""" + """Downloads a file, or loads it from cache if it has been downloaded before. + + The checksum is only verified on the initial download. Once the file + exists in the cache it is trusted as-is to avoid re-reading large weight + files (which can be ~1 GB) on every run. + """ + STAMP_CACHE_DIR.mkdir(exist_ok=True, parents=True) outfile_path = STAMP_CACHE_DIR / file_name if outfile_path.is_file(): - with open(outfile_path, "rb") as weight_file: - digest = hashlib.file_digest(weight_file, "sha256") - assert digest.hexdigest() == sha256sum, ( - f"{outfile_path} has the wrong checksum. Try deleting it and rerunning this script." - ) - else: - filename, _ = urllib.request.urlretrieve(url) - with open(filename, "rb") as weight_file: - digest = hashlib.file_digest(weight_file, "sha256") - assert digest.hexdigest() == sha256sum, "hash of downloaded file did not match" - shutil.move(filename, outfile_path) - + # File already cached and verified on first download — skip re-hash. + return outfile_path + + filename, _ = urllib.request.urlretrieve(url) + with open(filename, "rb") as weight_file: + digest = hashlib.file_digest(weight_file, "sha256") + assert digest.hexdigest() == sha256sum, "hash of downloaded file did not match" + shutil.move(filename, outfile_path) return outfile_path @@ -40,14 +40,16 @@ def file_digest(file: str | Path) -> str: @cache -def get_processing_code_hash(file_path) -> str: +def get_processing_code_hash(file_path: Path) -> str: """The hash of the entire process codebase. - It is used to assure that features extracted with different versions of this code base - can be identified as such after the fact. + It is used to assure that features extracted with different versions of + this code base can be identified as such after the fact. """ hasher = hashlib.sha256() - for file_path in sorted(file_path.parent.glob("*.py")): - with open(file_path, "rb") as fp: - hasher.update(fp.read()) + for py_file in sorted(file_path.parent.glob("*.py")): + # Use file_digest to stream the file in chunks instead of reading + # the entire source into memory at once. + with open(py_file, "rb") as fp: + hasher.update(hashlib.file_digest(fp, "sha256").digest()) return hasher.hexdigest() diff --git a/tests/conftest.py b/tests/conftest.py index f7f4f005..16dcae5b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,15 @@ +import multiprocessing from stamp.preprocessing import ExtractorName +# Ensure tests use a safe multiprocessing start method to avoid +# fork-from-multi-threaded-process warnings on Linux. +if multiprocessing.get_start_method(allow_none=True) != "spawn": + try: + multiprocessing.set_start_method("spawn") + except RuntimeError: + # start method already set by the test runner/environment + pass + # This lets you choose which extractors to run on pytest. Useful for the # CI pipeline as extractors like MUSK take an eternity on CPU. diff --git a/tests/test_deployment.py b/tests/test_deployment.py index 7d1d6589..55939e66 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -221,8 +221,8 @@ def test_to_prediction_df(task: str) -> None: assert preds_df["loss"].isna().all() else: patient_to_ground_truth = { - PatientId("p1"): "10.0 1", - PatientId("p2"): "12.3 0", + PatientId("p1"): (10.0, 1), + PatientId("p2"): (12.3, 0), } predictions = { PatientId("p1"): torch.tensor([0.8]), @@ -304,17 +304,19 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: feature_file = make_old_feature_file( feats=torch.rand(23, dim_feats), coords=torch.rand(23, 2) ) - gt = GroundTruth("foo") + gt = cast(GroundTruth, "foo") elif task == "regression": feature_file = make_old_feature_file( feats=torch.rand(30, dim_feats), coords=torch.rand(30, 2) ) - gt = GroundTruth(42.5) # numeric target wrapped for typing + gt = cast(GroundTruth, 42.5) # numeric target wrapped for typing else: # survival feature_file = make_old_feature_file( feats=torch.rand(40, dim_feats), coords=torch.rand(40, 2) ) - gt = GroundTruth("12 0") # (time, status) + gt = cast( + GroundTruth, (12.0, 0) + ) # (time, status) - use raw tuple (GroundTruth is a str alias) patient_to_data = { PatientId("pat_test"): PatientData( diff --git a/tests/test_train_deploy.py b/tests/test_train_deploy.py index ea5547f1..a21af576 100644 --- a/tests/test_train_deploy.py +++ b/tests/test_train_deploy.py @@ -54,20 +54,20 @@ def test_train_deploy_integration( create_random_dataset( dir=tmp_path / "train", n_categories=3, - n_patients=400, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = create_random_dataset( dir=tmp_path / "deploy", categories=categories, - n_patients=50, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) @@ -85,7 +85,7 @@ def test_train_deploy_integration( advanced = AdvancedConfig( # Dataset and -loader parameters - bag_size=500, + bag_size=32, num_workers=min(os.cpu_count() or 1, 16), # Training paramenters batch_size=8, @@ -142,7 +142,7 @@ def test_train_deploy_patient_level_integration( create_random_patient_level_dataset( dir=tmp_path / "train", n_categories=3, - n_patients=400, + n_patients=30, feat_dim=feat_dim, ) ) @@ -150,7 +150,7 @@ def test_train_deploy_patient_level_integration( create_random_patient_level_dataset( dir=tmp_path / "deploy", categories=categories, - n_patients=50, + n_patients=10, feat_dim=feat_dim, ) ) @@ -218,20 +218,20 @@ def test_train_deploy_regression_integration( train_clini_path, train_slide_path, train_feature_dir, _ = ( create_random_regression_dataset( dir=tmp_path / "train", - n_patients=400, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( create_random_regression_dataset( dir=tmp_path / "deploy", - n_patients=50, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) @@ -250,9 +250,9 @@ def test_train_deploy_regression_integration( ) advanced = AdvancedConfig( - bag_size=500, + bag_size=32, num_workers=min(os.cpu_count() or 1, 16), - batch_size=1, + batch_size=8, max_epochs=2, patience=1, accelerator="gpu" if torch.cuda.is_available() else "cpu", @@ -297,20 +297,20 @@ def test_train_deploy_survival_integration( train_clini_path, train_slide_path, train_feature_dir, _ = ( create_random_survival_dataset( dir=tmp_path / "train", - n_patients=400, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( create_random_survival_dataset( dir=tmp_path / "deploy", - n_patients=50, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) @@ -329,7 +329,7 @@ def test_train_deploy_survival_integration( ) advanced = AdvancedConfig( - bag_size=500, + bag_size=32, num_workers=min(os.cpu_count() or 1, 16), batch_size=8, max_epochs=2, @@ -385,7 +385,7 @@ def test_train_deploy_patient_level_regression_integration( train_feat_dir.mkdir(parents=True, exist_ok=True) deploy_feat_dir.mkdir(parents=True, exist_ok=True) - n_train, n_deploy = 300, 60 + n_train, n_deploy = 30, 10 train_rows, deploy_rows = [], [] # --- Generate random patient-level features and numeric targets --- @@ -490,14 +490,14 @@ def test_train_deploy_patient_level_survival_integration( train_clini_path, train_slide_path, train_feature_dir, _ = ( create_random_patient_level_survival_dataset( dir=tmp_path / "train", - n_patients=300, + n_patients=30, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( create_random_patient_level_survival_dataset( dir=tmp_path / "deploy", - n_patients=60, + n_patients=10, feat_dim=feat_dim, ) ) @@ -565,10 +565,10 @@ def test_train_deploy_multi_target_integration( train_clini_path, train_slide_path, train_feature_dir, _ = ( create_random_multi_target_dataset( dir=tmp_path / "train", - n_patients=400, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, target_labels=target_labels, categories_per_target=categories_per_target, @@ -577,10 +577,10 @@ def test_train_deploy_multi_target_integration( deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( create_random_multi_target_dataset( dir=tmp_path / "deploy", - n_patients=50, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, target_labels=target_labels, categories_per_target=categories_per_target, @@ -601,7 +601,7 @@ def test_train_deploy_multi_target_integration( ) advanced = AdvancedConfig( - bag_size=500, + bag_size=32, num_workers=min(os.cpu_count() or 1, 16), batch_size=8, max_epochs=2, diff --git a/uv.lock b/uv.lock index 96b4b73a..beb0a251 100644 --- a/uv.lock +++ b/uv.lock @@ -3699,7 +3699,7 @@ wheels = [ [[package]] name = "stamp" -version = "2.4.0" +version = "2.5.0" source = { editable = "." } dependencies = [ { name = "beartype" }, From 2782c96f6971e119f2b1e178e530c5677eb5cfa7 Mon Sep 17 00:00:00 2001 From: mducducd Date: Thu, 12 Mar 2026 12:12:53 +0000 Subject: [PATCH 8/9] fix: remove duplicate setup_model_from_dataloaders function --- src/stamp/modeling/train.py | 118 ------------------------------------ 1 file changed, 118 deletions(-) diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index 06d91717..a55ec1e3 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -349,124 +349,6 @@ def setup_model_from_dataloaders( return model -def setup_model_from_dataloaders( - *, - train_dl: DataLoader, - valid_dl: DataLoader, - task: Task, - train_categories: Sequence[Category] | Mapping[str, Sequence[Category]], - dim_feats: int, - train_patients: Sequence[PatientId], - valid_patients: Sequence[PatientId], - feature_type: str, - advanced: AdvancedConfig, - # Metadata, has no effect on model training - ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, - time_label: PandasLabel | None, - status_label: PandasLabel | None, - clini_table: Path, - slide_table: Path | None, - feature_dir: Path, -) -> lightning.LightningModule: - """Creates a model from pre-built dataloaders (no internal split).""" - - _logger.info( - "Training dataloaders: task=%s, feature_type=%s", - task, - feature_type, - ) - - category_weights: torch.Tensor | dict[str, torch.Tensor] | list = [] - if task == "classification": - category_weights = _compute_class_weights_and_check_categories( - train_dl=train_dl, - feature_type=feature_type, - train_categories=train_categories, - ) - - # 1. Default to a model if none is specified - if advanced.model_name is None: - advanced.model_name = ModelName.VIT if feature_type == "tile" else ModelName.MLP - _logger.info( - f"No model specified, defaulting to '{advanced.model_name.value}' for feature type '{feature_type}'" - ) - - # Prevent selecting `barspoon` for single-target classification - if ( - task == "classification" - and isinstance(ground_truth_label, str) - and advanced.model_name == ModelName.BARSPOON - ): - raise ValueError( - "Model 'barspoon' requires multi-target classification. " - "For single-target classification set model_name to 'vit', 'trans_mil', or 'mlp'." - ) - - # 2. Instantiate the lightning wrapper (based on provided task, feature type) and model backbone dynamically - LitModelClass, ModelClass = load_model_class( - task, feature_type, advanced.model_name - ) - - # 3. Validate that the chosen model supports the feature type - if feature_type not in LitModelClass.supported_features: - raise ValueError( - f"Model '{advanced.model_name.value}' does not support feature type '{feature_type}'. " - f"Supported types are: {LitModelClass.supported_features}" - ) - elif ( - feature_type in ("slide", "patient") - and advanced.model_name.value.lower() != "mlp" - ): - raise ValueError( - f"Feature type '{feature_type}' only supports MLP backbones. " - f"Got '{advanced.model_name.value}'. Please set model_name='mlp'." - ) - - # 4. Get model-specific hyperparameters - model_specific_params = ( - advanced.model_params.model_dump().get(advanced.model_name.value) or {} - ) - - # 5. Calculate total steps for scheduler - steps_per_epoch = len(train_dl) - total_steps = steps_per_epoch * advanced.max_epochs - - # 6. Prepare common parameters - common_params = { - "categories": train_categories, - "category_weights": category_weights, - "dim_input": dim_feats, - "total_steps": total_steps, - "max_lr": advanced.max_lr, - "div_factor": advanced.div_factor, - # Metadata, has no effect on model training - "model_name": advanced.model_name.value, - "ground_truth_label": ground_truth_label, - "time_label": time_label, - "status_label": status_label, - "train_patients": train_patients, - "valid_patients": valid_patients, - "clini_table": clini_table, - "slide_table": slide_table, - "feature_dir": feature_dir, - } - - all_params = {**common_params, **model_specific_params} - - _logger.info( - f"Instantiating model '{advanced.model_name.value}' with parameters: {model_specific_params}" - ) - _logger.info( - "Other params: max_epochs=%s, patience=%s", - advanced.max_epochs, - advanced.patience, - ) - - model = LitModelClass(model_class=ModelClass, **all_params) - - return model - - def setup_dataloaders_for_training( *, patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], From dc17ec5ad97126d9a0ad40c0a84736c4d38904da Mon Sep 17 00:00:00 2001 From: mducducd Date: Thu, 12 Mar 2026 12:44:08 +0000 Subject: [PATCH 9/9] docs: update multi-target usage and model options --- getting-started.md | 15 ++++++++++++--- src/stamp/config.yaml | 4 ++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/getting-started.md b/getting-started.md index 9e1069cc..d1e66688 100644 --- a/getting-started.md +++ b/getting-started.md @@ -182,8 +182,12 @@ either in excel or `.csv` format, with contents as described below. Finally, `ground_truth_label` needs to contain the column name of the data we want to train our model on. +For single-target classification, use one column name. +For multi-target classification, use a list of column names and set +`advanced_config.model_name: "barspoon"`. Stamp only can be used to train neural networks for categorical targets. -We recommend explicitly setting the possible classes using the `categories` field. +For single-target runs, we recommend explicitly setting the possible classes +using the `categories` field. ```yaml # stamp-test-experiment/config.yaml @@ -210,12 +214,15 @@ crossval: # Name of the column from the clini table to train on. ground_truth_label: "isMSIH" + # For multi-target classification with barspoon: + # ground_truth_label: ["subtype", "grade"] # Optional settings: # The categories occurring in the target label column of the clini table. # If unspecified, they will be inferred from the table itself. categories: ["yes", "no"] + # For multi-target classification, per-target categories are inferred. # Number of folds to split the data into for cross-validation #n_splits: 5 @@ -227,6 +234,7 @@ we can run it by invoking: stamp --config stamp-test-experiment/config.yaml crossval ``` + ## Generating Statistics After training and validating your model, you may want to generate statistics to evaluate its performance. @@ -501,7 +509,7 @@ advanced_config: max_lr: 1e-4 div_factor: 25. # Select a model regardless of task - # Available models are: vit, trans_mil, mlp + # Available models are: vit, trans_mil, mlp, linear, barspoon model_name: "vit" model_params: @@ -518,4 +526,5 @@ STAMP automatically adapts its **model architecture**, **loss function**, and ** **Regression** tasks only require `ground_truth_label`. **Survival analysis** tasks require `time_label` (follow-up time) and `status_label` (event indicator). -These requirements apply consistently across cross-validation, training, deployment, and statistics. \ No newline at end of file +**Multi-target classification** requires `ground_truth_label` as a list and `advanced_config.model_name: "barspoon"`. +These requirements apply consistently across cross-validation, training, deployment, and statistics. diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 80d105bf..900b40de 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -327,7 +327,7 @@ advanced_config: max_lr: 1e-4 div_factor: 25. # Select a model regardless of task - model_name: "vit" # or mlp, trans_mil, barspoon + model_name: "vit" # or mlp, linear, trans_mil, barspoon model_params: vit: # Vision Transformer @@ -357,4 +357,4 @@ advanced_config: num_encoder_layers: 2 num_decoder_layers: 2 dim_feedforward: 2048 - positional_encoding: true \ No newline at end of file + positional_encoding: true