diff --git a/README.md b/README.md index 3069dddf..9a0ed68b 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ STAMP is an **end‑to‑end, weakly‑supervised deep‑learning pipeline** tha * 🎓 **Beginner‑friendly & expert‑ready**: Zero‑code CLI and YAML config for routine use; optional code‑level customization for advanced research. * 🧩 **Model‑rich**: Out‑of‑the‑box support for **+20 foundation models** at [tile level](getting-started.md#feature-extraction) (e.g., *Virchow‑v2*, *UNI‑v2*) and [slide level](getting-started.md#slide-level-encoding) (e.g., *TITAN*, *COBRA*). * 🔬 **Weakly‑supervised**: End‑to‑end MIL with Transformer aggregation for training, cross‑validation and external deployment; no pixel‑level labels required. -* 🧮 **Multi-task learning**: Unified framework for **classification**, **regression**, and **cox-based survival analysis**. +* 🧮 **Multi-task learning**: Unified framework for **classification**, **multi-target classification**, **regression**, and **cox-based survival analysis**. * 📊 **Stats & results**: Built‑in metrics and patient‑level predictions, ready for analysis and reporting. * 🖼️ **Explainable**: Generates heatmaps and top‑tile exports out‑of‑the‑box for transparent model auditing and publication‑ready figures. * 🤝 **Collaborative by design**: Clinicians drive hypothesis & interpretation while engineers handle compute; STAMP’s modular CLI mirrors real‑world workflows and tracks every step for full reproducibility. diff --git a/getting-started.md b/getting-started.md index 6d5bffec..9e1069cc 100644 --- a/getting-started.md +++ b/getting-started.md @@ -58,7 +58,9 @@ Stamp currently supports the following feature extractors: - [mSTAR][mstar] - [MUSK][musk] - [PLIP][plip] + - [KEEP][keep] - [TICON][ticon] + - [RedDino][reddino] As some of the above require you to request access to the model on huggingface, @@ -154,12 +156,14 @@ meaning ignored that it was ignored during feature extraction. [mstar]: https://huggingface.co/Wangyh/mSTAR [musk]: https://huggingface.co/xiangjx/musk [plip]: https://github.com/PathologyFoundation/plip +[keep]: https://loiesun.github.io/keep/ "A Knowledge-enhanced Pathology Vision-language Foundation Model for Cancer Diagnosis" [TITAN]: https://huggingface.co/MahmoodLab/TITAN [COBRA2]: https://huggingface.co/KatherLab/COBRA [EAGLE]: https://github.com/KatherLab/EAGLE [MADELEINE]: https://huggingface.co/MahmoodLab/madeleine [PRISM]: https://huggingface.co/paige-ai/Prism [TICON]: https://cvlab-stonybrook.github.io/TICON/ "TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning" +[reddino]: https://github.com/Snarci/RedDino "RedDino: A Foundation Model for Red Blood Cell Analysis" diff --git a/pyproject.toml b/pyproject.toml index fc79a26b..7ccca45f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "stamp" -version = "2.4.0" +version = "2.4.1" authors = [ { name = "Omar El Nahhas", email = "omar.el_nahhas@tu-dresden.de" }, { name = "Marko van Treeck", email = "markovantreeck@gmail.com" }, diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 4ab8416f..0b252d6f 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -6,15 +6,6 @@ import yaml -from stamp.config import StampConfig -from stamp.modeling.config import ( - AdvancedConfig, - MlpModelParams, - ModelParams, - VitModelParams, -) -from stamp.seed import Seed - STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml") # Set up the logger @@ -41,23 +32,38 @@ def _create_config_file(config_file: Path) -> None: def _run_cli(args: argparse.Namespace) -> None: - # Handle init command + # Handle init command before any stamp-internal imports so that + # `stamp init` and `stamp --help` don't pay the full torch/pydantic + # import cost. if args.command == "init": _create_config_file(args.config_file_path) return + # Deferred imports: only reached for real commands, not --help / init. + from stamp.modeling.config import ( + AdvancedConfig, + MlpModelParams, + ModelParams, + VitModelParams, + ) + from stamp.utils.config import StampConfig + from stamp.utils.seed import Seed + # Load YAML configuration with open(args.config_file_path, "r") as config_yaml: config = StampConfig.model_validate(yaml.safe_load(config_yaml)) - # use default advanced config in case none is provided - if config.advanced_config is None: - config.advanced_config = AdvancedConfig( - model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()), - ) + # Only build a default AdvancedConfig (with model-params) for commands + # that actually use it. Preprocess / encode / statistics / heatmaps + # never touch config.advanced_config, so don't pay the construction cost. + if args.command in {"train", "crossval"}: + if config.advanced_config is None: + config.advanced_config = AdvancedConfig( + model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()), + ) - # Set global random seed - if config.advanced_config.seed is not None: + # Apply the global seed for any command that has one configured. + if config.advanced_config is not None and config.advanced_config.seed is not None: Seed.set(config.advanced_config.seed) match args.command: @@ -153,6 +159,7 @@ def _run_cli(args: argparse.Namespace) -> None: if config.training.task is None: raise ValueError("task must be set in training configuration") + assert config.advanced_config is not None # guaranteed above for "train" train_categorical_model_( config=config.training, advanced=config.advanced_config ) @@ -198,6 +205,7 @@ def _run_cli(args: argparse.Namespace) -> None: f"{yaml.dump(config.crossval.model_dump(mode='json', exclude_none=True))}" ) + assert config.advanced_config is not None # guaranteed above for "crossval" categorical_crossval_( config=config.crossval, advanced=config.advanced_config, diff --git a/src/stamp/cache.py b/src/stamp/cache.py deleted file mode 100644 index d65c7dcd..00000000 --- a/src/stamp/cache.py +++ /dev/null @@ -1,53 +0,0 @@ -import hashlib -import os -import shutil -import urllib.request -from functools import cache -from pathlib import Path -from typing import Final - -STAMP_CACHE_DIR: Final[Path] = ( - Path(os.environ.get("XDG_CACHE_HOME") or (Path.home() / ".cache")) / "stamp" -) - -# If we imported this, we probably want to use it, -# so it's okay creating the directory now -STAMP_CACHE_DIR.mkdir(exist_ok=True, parents=True) - - -def download_file(*, url: str, file_name: str, sha256sum: str) -> Path: - """Downloads a file, or loads it from cache if it has been downloaded before""" - outfile_path = STAMP_CACHE_DIR / file_name - if outfile_path.is_file(): - with open(outfile_path, "rb") as weight_file: - digest = hashlib.file_digest(weight_file, "sha256") - assert digest.hexdigest() == sha256sum, ( - f"{outfile_path} has the wrong checksum. Try deleting it and rerunning this script." - ) - else: - filename, _ = urllib.request.urlretrieve(url) - with open(filename, "rb") as weight_file: - digest = hashlib.file_digest(weight_file, "sha256") - assert digest.hexdigest() == sha256sum, "hash of downloaded file did not match" - shutil.move(filename, outfile_path) - - return outfile_path - - -def file_digest(file: str | Path) -> str: - with open(file, "rb") as fp: - return hashlib.file_digest(fp, "sha256").hexdigest() - - -@cache -def get_processing_code_hash(file_path) -> str: - """The hash of the entire process codebase. - - It is used to assure that features extracted with different versions of this code base - can be identified as such after the fact. - """ - hasher = hashlib.sha256() - for file_path in sorted(file_path.parent.glob("*.py")): - with open(file_path, "rb") as fp: - hasher.update(fp.read()) - return hasher.hexdigest() diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 8440560b..80d105bf 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -4,7 +4,7 @@ preprocessing: # Extractor to use for feature extractor. Possible options are "ctranspath", # "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom", # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", - # "virchow-full", "musk", "mstar", "plip" + # "virchow-full", "musk", "mstar", "plip", "ticon", "red-dino", "keep" # Some of them require requesting access to the respective authors beforehand. extractor: "chief-ctranspath" @@ -76,6 +76,8 @@ crossval: # Name of the column from the clini table to train on. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -133,6 +135,8 @@ training: # Name of the column from the clini table to train on. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -175,6 +179,8 @@ deployment: # Name of the column from the clini to compare predictions to. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -200,6 +206,8 @@ statistics: # Name of the target label. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # A lot of the statistics are computed "one-vs-all", i.e. there needs to be # a positive class to calculate the statistics for. @@ -319,7 +327,7 @@ advanced_config: max_lr: 1e-4 div_factor: 25. # Select a model regardless of task - model_name: "vit" # or mlp, trans_mil + model_name: "vit" # or mlp, trans_mil, barspoon model_params: vit: # Vision Transformer @@ -338,3 +346,15 @@ advanced_config: dim_hidden: 512 num_layers: 2 dropout: 0.25 + + # NOTE: Only the `barspoon` model supports multi-target classification + # (i.e. `ground_truth_label` can be a list of column names). Other + # models expect a single target column. + barspoon: # Encoder-Decoder Transformer for multi-target classification + d_model: 512 + num_encoder_heads: 8 + num_decoder_heads: 8 + num_encoder_layers: 2 + num_decoder_layers: 2 + dim_feedforward: 2048 + positional_encoding: true \ No newline at end of file diff --git a/src/stamp/encoding/__init__.py b/src/stamp/encoding/__init__.py index 9cb873bb..02d46594 100644 --- a/src/stamp/encoding/__init__.py +++ b/src/stamp/encoding/__init__.py @@ -73,7 +73,7 @@ def init_slide_encoder_( selected_encoder = encoder case _ as unreachable: - assert_never(unreachable) # type: ignore + assert_never(unreachable) selected_encoder.encode_slides_( output_dir=output_dir, diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 5827e884..3b4c3ac4 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from pathlib import Path from tempfile import NamedTemporaryFile +from typing import cast import h5py import numpy as np @@ -12,11 +13,11 @@ from tqdm import tqdm import stamp -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.modeling.data import CoordsInfo, get_coords, read_table from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" @@ -61,7 +62,8 @@ def encode_slides_( if self.precision == torch.float16: self.model.half() - for tile_feats_filename in (progress := tqdm(os.listdir(feat_dir))): + h5_files = sorted(f for f in os.listdir(feat_dir) if f.endswith(".h5")) + for tile_feats_filename in (progress := tqdm(h5_files)): h5_path = os.path.join(feat_dir, tile_feats_filename) slide_name: str = Path(tile_feats_filename).stem progress.set_description(slide_name) @@ -183,7 +185,9 @@ def _read_h5( elif not h5_path.endswith(".h5"): raise ValueError(f"File is not of type .h5: {os.path.basename(h5_path)}") with h5py.File(h5_path, "r") as f: - feats: Tensor = torch.tensor(f["feats"][:], dtype=self.precision) # type: ignore + feats_ds = cast(h5py.Dataset, f["feats"]) + # torch.from_numpy avoids a redundant data copy vs torch.tensor(array) + feats: Tensor = torch.from_numpy(feats_ds[()]).to(dtype=self.precision) coords: CoordsInfo = get_coords(f) extractor: str = f.attrs.get("extractor", "") if extractor == "": diff --git a/src/stamp/encoding/encoder/chief.py b/src/stamp/encoding/encoder/chief.py index 2ad4b91b..924ceebb 100644 --- a/src/stamp/encoding/encoder/chief.py +++ b/src/stamp/encoding/encoder/chief.py @@ -10,11 +10,11 @@ from numpy import ndarray from tqdm import tqdm -from stamp.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/eagle.py b/src/stamp/encoding/encoder/eagle.py index de49c369..45092f4f 100644 --- a/src/stamp/encoding/encoder/eagle.py +++ b/src/stamp/encoding/encoder/eagle.py @@ -9,13 +9,13 @@ from torch import Tensor from tqdm import tqdm -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.encoding.encoder.chief import CHIEF from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/gigapath.py b/src/stamp/encoding/encoder/gigapath.py index 9cb3f6f5..4c0a2f6b 100644 --- a/src/stamp/encoding/encoder/gigapath.py +++ b/src/stamp/encoding/encoder/gigapath.py @@ -9,12 +9,12 @@ from gigapath import slide_encoder from tqdm import tqdm -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import PandasLabel, SlideMPP +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/madeleine.py b/src/stamp/encoding/encoder/madeleine.py index 5798a592..a0c74dcd 100644 --- a/src/stamp/encoding/encoder/madeleine.py +++ b/src/stamp/encoding/encoder/madeleine.py @@ -3,10 +3,10 @@ import torch from numpy import ndarray -from stamp.cache import STAMP_CACHE_DIR from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.preprocessing.config import ExtractorName +from stamp.utils.cache import STAMP_CACHE_DIR try: from madeleine.models.factory import create_model_from_pretrained diff --git a/src/stamp/encoding/encoder/titan.py b/src/stamp/encoding/encoder/titan.py index 2dba6021..920b3db8 100644 --- a/src/stamp/encoding/encoder/titan.py +++ b/src/stamp/encoding/encoder/titan.py @@ -10,12 +10,12 @@ from tqdm import tqdm from transformers import AutoModel -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, Microns, PandasLabel, SlideMPP +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index 446d85d6..fd033aff 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -5,7 +5,7 @@ import logging from collections.abc import Collection, Iterable from pathlib import Path -from typing import cast, no_type_check +from typing import cast import h5py import matplotlib.pyplot as plt @@ -19,7 +19,7 @@ from packaging.version import Version from PIL import Image from torch import Tensor -from torch.func import jacrev # pyright: ignore[reportPrivateImportUsage] +from torch.func import jacrev from stamp.modeling.data import get_coords, get_stride from stamp.modeling.deploy import load_model_from_ckpt @@ -29,6 +29,8 @@ _logger = logging.getLogger("stamp") +_SlideLike = openslide.OpenSlide | openslide.ImageSlide + def _gradcam_per_category( model: torch.nn.Module, @@ -37,23 +39,19 @@ def _gradcam_per_category( ) -> Float[Tensor, "tile category"]: feat_dim = -1 - cam = ( - ( - feats - * jacrev( - lambda bags: model.forward( - bags.unsqueeze(0), - coords=coords.unsqueeze(0), - mask=None, - ).squeeze(0) - )(feats) - ) - .mean(feat_dim) # type: ignore - .abs() + jac = cast( + Tensor, + jacrev( + lambda bags: model.forward( + bags.unsqueeze(0), + coords=coords.unsqueeze(0), + mask=None, + ).squeeze(0) + )(feats), ) + cam = (feats * jac).mean(feat_dim).abs() cam = torch.softmax(cam, dim=-1) - return cam.permute(-1, -2) @@ -79,12 +77,19 @@ def _attention_rollout_single( # --- 2. Rollout computation --- attn_rollout: torch.Tensor | None = None - for layer in model.transformer.layers: # type: ignore - attn = getattr(layer[0], "attn_weights", None) # SelfAttention.attn_weights + transformer = getattr(model, "transformer", None) + if transformer is None: + raise RuntimeError("Model does not have a transformer attribute") + for layer in transformer.layers: + attn = getattr(layer, "attn_weights", None) + if attn is None: + first_child = next(iter(layer.children()), None) + if first_child is not None: + attn = getattr(first_child, "attn_weights", None) if attn is None: raise RuntimeError( "SelfAttention.attn_weights not found. " - "Make sure SelfAttention stores them." + "Make sure SelfAttention stores them on the layer or its first child." ) # attn: [heads, seq, seq] @@ -117,15 +122,18 @@ def _gradcam_single( """ feat_dim = -1 - jac = jacrev( - lambda bags: model.forward( - bags.unsqueeze(0), - coords=coords.unsqueeze(0), - mask=None, - ).squeeze() - )(feats) + jac = cast( + Tensor, + jacrev( + lambda bags: model.forward( + bags.unsqueeze(0), + coords=coords.unsqueeze(0), + mask=None, + ).squeeze() + )(feats), + ) - cam = (feats * jac).mean(feat_dim).abs() # type: ignore # [tile] + cam = (feats * jac).mean(feat_dim).abs() # [tile] return cam @@ -148,17 +156,21 @@ def _vals_to_im( def _show_thumb( - slide, thumb_ax: Axes, attention: Tensor, default_slide_mpp: SlideMPP | None + slide: _SlideLike, + thumb_ax: Axes, + attention: Tensor, + default_slide_mpp: SlideMPP | None, ) -> np.ndarray: mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp) dims_um = np.array(slide.dimensions) * mpp - thumb = slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int)) + thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist()) + thumb = slide.get_thumbnail(thumb_size) thumb_ax.imshow(np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8]) return np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8] def _get_thumb_array( - slide, + slide: _SlideLike, attention: torch.Tensor, default_slide_mpp: SlideMPP | None, ) -> np.ndarray: @@ -168,12 +180,12 @@ def _get_thumb_array( """ mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp) dims_um = np.array(slide.dimensions) * mpp - thumb = np.array(slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int))) + thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist()) + thumb = np.array(slide.get_thumbnail(thumb_size)) thumb_crop = thumb[: attention.shape[0] * 8, : attention.shape[1] * 8] return thumb_crop -@no_type_check # beartype<=0.19.0 breaks here for some reason def _show_class_map( class_ax: Axes, top_score_indices: Integer[Tensor, "width height"], @@ -298,13 +310,8 @@ def heatmaps_( raise ValueError( f"Feature file {h5_path} is a slide or patient level feature. Heatmaps are currently supported for tile-level features only." ) - feats = ( - torch.tensor( - h5["feats"][:] # pyright: ignore[reportIndexIssue] - ) - .float() - .to(device) - ) + feats_np = np.asarray(h5["feats"]) + feats = torch.from_numpy(feats_np).float().to(device) coords_info = get_coords(h5) coords_um = torch.from_numpy(coords_info.coords_um).float() stride_um = Microns(get_stride(coords_um)) @@ -322,9 +329,10 @@ def heatmaps_( model = load_model_from_ckpt(checkpoint_path).eval() # TODO: Update version when a newer model logic breaks heatmaps. - if Version(model.stamp_version) < Version("2.4.0"): + stamp_version = str(getattr(model, "stamp_version", "")) + if Version(stamp_version) < Version("2.4.1"): raise ValueError( - f"model has been built with stamp version {model.stamp_version} " + f"model has been built with stamp version {stamp_version} " f"which is incompatible with the current version." ) @@ -356,7 +364,7 @@ def heatmaps_( with torch.no_grad(): scores = torch.softmax( - model.model.forward( + model.model( feats.unsqueeze(-2), coords=coords_um.unsqueeze(-2), mask=torch.zeros( diff --git a/src/stamp/modeling/config.py b/src/stamp/modeling/config.py index 21ce69db..5b9a6bcc 100644 --- a/src/stamp/modeling/config.py +++ b/src/stamp/modeling/config.py @@ -21,7 +21,7 @@ class TrainConfig(BaseModel): ) feature_dir: Path = Field(description="Directory containing feature files") - ground_truth_label: PandasLabel | None = Field( + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = Field( default=None, description="Name of categorical column in clinical table to train on", ) @@ -64,7 +64,7 @@ class DeploymentConfig(BaseModel): slide_table: Path feature_dir: Path - ground_truth_label: PandasLabel | None = None + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None patient_label: PandasLabel = "PATIENT" filename_label: PandasLabel = "FILENAME" @@ -99,8 +99,29 @@ class TransMILModelParams(BaseModel): dim_hidden: int = 512 +class BarspoonParams(BaseModel): + model_config = ConfigDict(extra="forbid") + d_model: int = 512 + num_encoder_heads: int = 8 + num_decoder_heads: int = 8 + num_encoder_layers: int = 2 + num_decoder_layers: int = 2 + dim_feedforward: int = 2048 + positional_encoding: bool = True + # Other hparams + learning_rate: float = 1e-4 + + class LinearModelParams(BaseModel): model_config = ConfigDict(extra="forbid") + num_encoder_heads: int = 8 + num_decoder_heads: int = 8 + num_encoder_layers: int = 2 + num_decoder_layers: int = 2 + dim_feedforward: int = 2048 + positional_encoding: bool = True + # Other hparams + learning_rate: float = 1e-4 class ModelParams(BaseModel): @@ -109,6 +130,7 @@ class ModelParams(BaseModel): trans_mil: TransMILModelParams = Field(default_factory=TransMILModelParams) mlp: MlpModelParams = Field(default_factory=MlpModelParams) linear: LinearModelParams = Field(default_factory=LinearModelParams) + barspoon: BarspoonParams = Field(default_factory=BarspoonParams) class AdvancedConfig(BaseModel): diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 0ff037cf..4ee71563 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -1,8 +1,10 @@ import logging +from collections import Counter from collections.abc import Mapping, Sequence -from typing import Any, Final +from typing import Any, cast import numpy as np +import torch from pydantic import BaseModel from sklearn.model_selection import KFold, StratifiedKFold @@ -10,13 +12,8 @@ from stamp.modeling.data import ( PatientData, create_dataloader, - detect_feature_type, - filter_complete_patient_data_, - load_patient_level_data, + load_patient_data_, log_patient_class_summary, - patient_to_ground_truth_from_clini_table_, - patient_to_survival_from_clini_table_, - slide_to_patient_from_slide_table_, ) from stamp.modeling.deploy import ( _predict, @@ -25,10 +22,9 @@ _to_survival_prediction_df, load_model_from_ckpt, ) -from stamp.modeling.train import setup_model_for_training, train_model_ +from stamp.modeling.train import setup_model_from_dataloaders, train_model_ from stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import ( - FeaturePath, GroundTruth, PatientId, ) @@ -53,86 +49,59 @@ def categorical_crossval_( config: CrossvalConfig, advanced: AdvancedConfig, ) -> None: - feature_type = detect_feature_type(config.feature_dir) + if config.task is None: + raise ValueError( + "task must be set to 'classification' | 'regression' | 'survival'" + ) + + patient_to_data, feature_type = load_patient_data_( + feature_dir=config.feature_dir, + clini_table=config.clini_table, + slide_table=config.slide_table, + task=config.task, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + patient_label=config.patient_label, + filename_label=config.filename_label, + drop_patients_with_missing_ground_truth=True, + ) _logger.info(f"Detected feature type: {feature_type}") - if feature_type in ("tile", "slide"): - if config.slide_table is None: - raise ValueError("A slide table is required for modeling") - if config.task == "survival": - if config.time_label is None or config.status_label is None: - raise ValueError( - "Both time_label and status_label are is required for survival modeling" - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = ( - patient_to_survival_from_clini_table_( - clini_table_path=config.clini_table, - time_label=config.time_label, - status_label=config.status_label, - patient_label=config.patient_label, - ) - ) - else: - if config.ground_truth_label is None: - raise ValueError( - "Ground truth label is required for classification or regression modeling" - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = ( - patient_to_ground_truth_from_clini_table_( - clini_table_path=config.clini_table, - ground_truth_label=config.ground_truth_label, - patient_label=config.patient_label, - ) - ) - slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( - slide_to_patient_from_slide_table_( - slide_table_path=config.slide_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - filename_label=config.filename_label, - ) - ) - patient_to_data: Mapping[PatientId, PatientData] = ( - filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, - ) - ) - elif feature_type == "patient": - patient_to_data: Mapping[PatientId, PatientData] = load_patient_level_data( - task=config.task, - clini_table=config.clini_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = { - pid: pd.ground_truth for pid, pd in patient_to_data.items() - } - else: - raise RuntimeError(f"Unsupported feature type: {feature_type}") + patient_to_ground_truth = { + pid: pd.ground_truth for pid, pd in patient_to_data.items() + } + + if feature_type not in ("tile", "slide", "patient"): + raise ValueError(f"Unknown feature type: {feature_type}") config.output_dir.mkdir(parents=True, exist_ok=True) splits_file = config.output_dir / "splits.json" # Generate the splits, or load them from the splits file if they already exist if not splits_file.exists(): - splits = ( - _get_splits( - patient_to_data=patient_to_data, - n_splits=config.n_splits, - spliter=KFold, - ) - if config.task == "regression" - else _get_splits( - patient_to_data=patient_to_data, - n_splits=config.n_splits, - spliter=StratifiedKFold, - ) + # Detect multi-target classification (ground_truth is a dict) + is_multitarget = any( + isinstance(pd.ground_truth, dict) for pd in patient_to_data.values() ) + + # Use KFold for regression or multi-target classification; otherwise StratifiedKFold. + # For survival we want StratifiedKFold so folds are balanced by event status. + spliter = ( + KFold + if (config.task == "regression" or is_multitarget) + else StratifiedKFold + ) + + _logger.info(f"Using {spliter.__name__} for cross-validation splits") + + splits = _get_splits( + patient_to_data=patient_to_data, + n_splits=config.n_splits, + spliter=spliter, + task=config.task, + ) + with open(splits_file, "w") as fp: fp.write(splits.model_dump_json(indent=4)) else: @@ -158,18 +127,51 @@ def categorical_crossval_( f"{ground_truths_not_in_split}" ) + categories_for_export: ( + dict[str, list] | list + ) = [] # declare upfront to avoid unbound variable warnings + categories: Sequence[GroundTruth] | list | None = [] + if config.task == "classification": - categories = config.categories or sorted( - { - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - } - ) - log_patient_class_summary( - patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, - categories=categories, - ) + # Determine categories for training (single-target) and for export (supports multi-target) + if isinstance(config.ground_truth_label, str): + categories = config.categories or sorted( + { + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + } + ) + log_patient_class_summary( + patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, + categories=categories, + ) + categories_for_export = cast(list, categories) + else: + # Multi-target: build a mapping from target label -> sorted list of categories + categories_accum: dict[str, set[GroundTruth]] = {} + for patient_data in patient_to_data.values(): + gt = patient_data.ground_truth + if isinstance(gt, dict): + for k, v in gt.items(): + if v is not None: + categories_accum.setdefault(k, set()).add(v) + categories_for_export = {k: sorted(v) for k, v in categories_accum.items()} + # Log summary per target + for t, cats in categories_for_export.items(): + ground_truths = [ + pd.ground_truth.get(t) + for pd in patient_to_data.values() + if isinstance(pd.ground_truth, dict) + and pd.ground_truth.get(t) is not None + ] + counter = Counter(ground_truths) + _logger.info( + f"{t} | Total patients: {len(ground_truths)} | " + + " | ".join([f"Class {c}: {counter.get(c, 0)}" for c in cats]) + ) + # For training, categories can remain None (inferred later) + categories = config.categories or None else: categories = [] @@ -190,42 +192,93 @@ def categorical_crossval_( # Train the model if not (split_dir / "model.ckpt").exists(): - model, train_dl, valid_dl = setup_model_for_training( - clini_table=config.clini_table, - slide_table=config.slide_table, - feature_dir=config.feature_dir, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - advanced=advanced, - task=config.task, - patient_to_data={ - patient_id: patient_data - for patient_id, patient_data in patient_to_data.items() - if patient_id in split.train_patients - }, - categories=( - categories - or sorted( + # Build train and test dataloaders directly (pure 2-way k-fold split) + train_patient_ids = [ + pid for pid in split.train_patients if pid in patient_to_data + ] + test_patient_ids = [ + pid for pid in split.test_patients if pid in patient_to_data + ] + train_patient_data = [patient_to_data[pid] for pid in train_patient_ids] + test_patient_data = [patient_to_data[pid] for pid in test_patient_ids] + + fold_categories = ( + categories + if categories is not None + else ( + sorted( { patient_data.ground_truth for patient_data in patient_to_data.values() if patient_data.ground_truth is not None + and not isinstance(patient_data.ground_truth, dict) } ) - ), - train_transform=( - VaryPrecisionTransform(min_fraction_bits=1) - if config.use_vary_precision_transform + if not isinstance(config.ground_truth_label, Sequence) else None - ), + ) + ) + + train_transform = ( + VaryPrecisionTransform(min_fraction_bits=1) + if config.use_vary_precision_transform + else None + ) + + train_dl, train_categories = create_dataloader( feature_type=feature_type, + task=config.task, + patient_data=train_patient_data, + bag_size=advanced.bag_size, + batch_size=advanced.batch_size, + shuffle=True, + num_workers=advanced.num_workers, + transform=train_transform, + categories=fold_categories, + ) + test_dl, _ = create_dataloader( + feature_type=feature_type, + task=config.task, + patient_data=test_patient_data, + bag_size=None, + batch_size=1, + shuffle=False, + num_workers=advanced.num_workers, + transform=None, + categories=train_categories, + ) + + # Infer feature dimension + batch = next(iter(train_dl)) + if feature_type == "tile": + bags, _, _, _ = batch + dim_feats = bags.shape[-1] + else: + feats, _ = batch + dim_feats = feats.shape[-1] + + model = setup_model_from_dataloaders( + train_dl=train_dl, + valid_dl=test_dl, + task=config.task, + train_categories=train_categories, + dim_feats=dim_feats, + train_patients=train_patient_ids, + valid_patients=test_patient_ids, + feature_type=feature_type, + advanced=advanced, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + clini_table=config.clini_table, + slide_table=config.slide_table, + feature_dir=config.feature_dir, ) model = train_model_( output_dir=split_dir, model=model, train_dl=train_dl, - valid_dl=valid_dl, + valid_dl=test_dl, max_epochs=advanced.max_epochs, patience=advanced.patience, accelerator=advanced.accelerator, @@ -236,9 +289,9 @@ def categorical_crossval_( else: model = load_model_from_ckpt(split_dir / "model.ckpt") - # Deploy on test set + # Deploy on test fold (used as validation/prediction set) if not (split_dir / "patient-preds.csv").exists(): - # Prepare test dataloader + # Prepare validation dataloader for predictions test_patients = [ pid for pid in split.test_patients if pid in patient_to_data ] @@ -263,58 +316,109 @@ def categorical_crossval_( ) if config.task == "survival": - _to_survival_prediction_df( - patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, - patient_label=config.patient_label, - cut_off=getattr(model.hparams, "train_pred_median", None), - ).to_csv(split_dir / "patient-preds.csv", index=False) + # Export only when patients have single-target survival labels ("time status"). + # Don't rely on `ground_truth_label` for survival — check the loaded patient data. + if any(isinstance(gt, dict) for gt in patient_to_ground_truth.values()): + _logger.warning( + "Multi-target survival prediction export not yet supported; skipping CSV save" + ) + else: + _to_survival_prediction_df( + patient_to_ground_truth=cast( + Mapping[ + PatientId, str | tuple[float | None, int | None] | None + ], + patient_to_ground_truth, + ), + predictions=cast(Mapping[PatientId, torch.Tensor], predictions), + patient_label=config.patient_label, + cut_off=getattr(model.hparams, "train_pred_median", None), + ).to_csv(split_dir / "patient-preds.csv", index=False) elif config.task == "regression": if config.ground_truth_label is None: - raise RuntimeError("Grounf truth label is required for regression") - _to_regression_prediction_df( - patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - ).to_csv(split_dir / "patient-preds.csv", index=False) + raise RuntimeError("Ground truth label is required for regression") + if isinstance(config.ground_truth_label, str): + _to_regression_prediction_df( + patient_to_ground_truth=cast( + Mapping[PatientId, str | None], patient_to_ground_truth + ), + predictions=cast(Mapping[PatientId, torch.Tensor], predictions), + patient_label=config.patient_label, + ground_truth_label=config.ground_truth_label, + ).to_csv(split_dir / "patient-preds.csv", index=False) + else: + _logger.warning( + "Multi-target regression prediction export not yet supported; skipping CSV save" + ) else: if config.ground_truth_label is None: raise RuntimeError( - "Grounf truth label is required for classification" + "Ground truth label is required for classification" ) _to_prediction_df( - categories=categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, + predictions=cast( + Mapping[PatientId, torch.Tensor] + | Mapping[PatientId, dict[str, torch.Tensor]], + predictions, + ), patient_label=config.patient_label, ground_truth_label=config.ground_truth_label, ).to_csv(split_dir / "patient-preds.csv", index=False) def _get_splits( - *, patient_to_data: Mapping[PatientId, PatientData[Any]], n_splits: int, spliter + *, + patient_to_data: Mapping[PatientId, PatientData[Any]], + n_splits: int, + spliter, + task: str | None = None, ) -> _Splits: patients = np.array(list(patient_to_data.keys())) - # Detect survival GT: "time status" - tokens = [str(p.ground_truth).split() for p in patient_to_data.values()] + # Build stratification labels depending on the task + gts = [patient.ground_truth for patient in patient_to_data.values()] + + if task == "survival": + # use event status (0/1) for stratification + # Ground-truths are expected to be pre-parsed into (time, status) tuples + # by the data loading pipeline; extract the status element directly. + statuses: list[int] = [] + for gt in gts: + # support multi-target fallback: use first target + val = next(iter(gt.values())) if isinstance(gt, dict) else gt + # If structured (time, status) use second element, otherwise assume val is status + if isinstance(val, (tuple, list)) and len(val) == 2: + status_val = val[1] + else: + status_val = val + statuses.append(int(cast(int, status_val)) if status_val is not None else 0) + + y_strat = np.array(statuses) + elif task == "classification": + # For multi-target (dict), use the first target's value + y_strat = np.array( + [next(iter(gt.values())) if isinstance(gt, dict) else gt for gt in gts] + ) + else: + # regression or unknown: do not stratify (KFold will ignore y) + y_strat = None + + skf = spliter(n_splits=n_splits, shuffle=True, random_state=0) - if all(len(t) == 2 for t in tokens): - y = np.array([int(t[1]) for t in tokens], dtype=int) - skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0) - iterator = skf.split(patients, y) + if y_strat is None: + splits_iter = skf.split(patients) else: - skf = KFold(n_splits=n_splits, shuffle=True, random_state=0) - iterator = skf.split(patients) + splits_iter = skf.split(patients, y_strat) splits = _Splits( splits=[ _Split( - train_patients=set(patients[train_idx]), - test_patients=set(patients[test_idx]), + train_patients=set(patients[train_indices]), + test_patients=set(patients[test_indices]), ) - for train_idx, test_idx in iterator + for train_indices, test_indices in splits_iter ] ) return splits diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 21b86176..c3696c0f 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -1,23 +1,37 @@ """Helper classes to manage pytorch data.""" import logging +from collections import OrderedDict from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import KW_ONLY, dataclass +from io import BytesIO # accept in _BinaryIOLike at runtime from itertools import groupby from pathlib import Path -from typing import IO, BinaryIO, Counter, Generic, TextIO, TypeAlias, Union, cast +from typing import ( + IO, + Any, + BinaryIO, + Final, + Generic, + List, + TextIO, + TypeAlias, + Union, + cast, +) import h5py import numpy as np import pandas as pd import torch -from jaxtyping import Float + +# Use beartype's typing for PEP-585 deprecation-safe hints +from beartype.typing import Dict from packaging.version import Version from torch import Tensor from torch.utils.data import DataLoader, Dataset import stamp -from stamp.seed import Seed from stamp.types import ( Bags, BagSize, @@ -35,6 +49,7 @@ Task, TilePixels, ) +from stamp.utils.seed import Seed _logger = logging.getLogger("stamp") @@ -43,14 +58,19 @@ __copyright__ = "Copyright (C) 2022-2025 Marko van Treeck, Minh Duc Nguyen" __license__ = "MIT" -_Bag: TypeAlias = Float[Tensor, "tile feature"] -_EncodedTarget: TypeAlias = Float[Tensor, "category_is_hot"] | Float[Tensor, "1"] # noqa: F821 -_BinaryIOLike: TypeAlias = Union[BinaryIO, IO[bytes]] +_Bag: TypeAlias = Tensor +_EncodedTarget: TypeAlias = ( + Tensor | dict[str, Tensor] +) # Union of encoded targets or multi-target dict +_BinaryIOLike: TypeAlias = Union[ + BinaryIO, IO[bytes], BytesIO +] # includes io.BytesIO for runtime checks """The ground truth, encoded numerically - classification: one-hot float [C] - regression: float [1] +- multi-target: dict[target_name -> one-hot/regression value] """ -_Coordinates: TypeAlias = Float[Tensor, "tile 2"] +_Coordinates: TypeAlias = Tensor @dataclass @@ -59,12 +79,12 @@ class PatientData(Generic[GroundTruthType]): _ = KW_ONLY ground_truth: GroundTruthType - feature_files: Iterable[FeaturePath | BinaryIO] + feature_files: Iterable[FeaturePath | _BinaryIOLike] def tile_bag_dataloader( *, - patient_data: Sequence[PatientData[GroundTruth | None]], + patient_data: Sequence[PatientData[GroundTruth | None | dict]], bag_size: int | None, task: Task, categories: Sequence[Category] | None = None, @@ -74,7 +94,7 @@ def tile_bag_dataloader( transform: Callable[[Tensor], Tensor] | None, ) -> tuple[ DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - Sequence[Category], + Sequence[Category] | Mapping[str, Sequence[Category]], ]: """Creates a dataloader from patient data for tile-level (bagged) features. @@ -86,103 +106,151 @@ def tile_bag_dataloader( task='regression': returns float targets """ - if task == "classification": - raw_ground_truths = np.array([patient.ground_truth for patient in patient_data]) - categories = ( - categories if categories is not None else list(np.unique(raw_ground_truths)) - ) - # one_hot = torch.tensor(raw_ground_truths.reshape(-1, 1) == categories) - one_hot = torch.tensor( - raw_ground_truths.reshape(-1, 1) == categories, dtype=torch.float32 - ) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=one_hot, - transform=transform, - ) - cats_out: Sequence[Category] = list(categories) - elif task == "regression": - raw_targets = np.array( - [ - np.nan if p.ground_truth is None else float(p.ground_truth) - for p in patient_data - ], - dtype=np.float32, - ) - y = torch.from_numpy(raw_targets).reshape(-1, 1) + targets, cats_out = _parse_targets( + patient_data=patient_data, + task=task, + categories=categories, + ) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=y, - transform=transform, - ) - cats_out = [] + is_multitarget = isinstance(targets[0], dict) - elif task == "survival": # Not yet support logistic-harzard - times: list[float] = [] - events: list[float] = [] + collate_fn = _collate_multitarget if is_multitarget else _collate_to_tuple - for p in patient_data: - if p.ground_truth is None: - times.append(np.nan) - events.append(np.nan) - continue + ds = BagDataset( + bags=[patient.feature_files for patient in patient_data], + bag_size=bag_size, + ground_truths=targets, + transform=transform, + deterministic=(not shuffle), + ) + dl = DataLoader( + ds, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=collate_fn, + worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, + persistent_workers=(num_workers > 0), + pin_memory=torch.cuda.is_available(), + ) - try: - time_str, status_str = p.ground_truth.split(" ", 1) + return ( + cast( + DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], + dl, + ), + cats_out, + ) - # Handle missing values encoded as "nan" - if time_str.lower() == "nan": - times.append(np.nan) - else: - times.append(float(time_str)) - - if status_str.lower() == "nan": - events.append(np.nan) - elif status_str.lower() in {"dead", "event", "1", "Yes", "yes"}: - events.append(1.0) - elif status_str.lower() in {"alive", "censored", "0", "No", "no"}: - events.append(0.0) - else: - events.append(np.nan) # unknown status → mark missing - except Exception: +def _parse_targets( + *, + patient_data: Sequence, + task: Task, + categories: Sequence[Category] | None = None, + target_spec: dict[str, Any] | None = None, + target_label: str | None = None, +) -> tuple[ + Union[torch.Tensor, list[dict[str, torch.Tensor]]], + Sequence[Category] | Mapping[str, Sequence[Category]], +]: + """ + Parse raw GroundTruth (str) into model-ready tensors. + This is the ONLY place task semantics live. + """ + + gts = [p.ground_truth for p in patient_data] + + if task == "classification": + if any(isinstance(gt, dict) for gt in gts if gt is not None): + # infer target names from the first non-None dict + first_dict = next(gt for gt in gts if isinstance(gt, dict)) + target_names = list(first_dict.keys()) + + # infer categories per target (ignore None patients, ignore None values) + categories_out: dict[str, list[str]] = {t: [] for t in target_names} + for gt in gts: + if not isinstance(gt, dict): + continue + for t in target_names: + v = gt.get(t) + if v is not None: + categories_out[t].append(v) + + # make unique + sorted + categories_out = { + t: sorted(set(vals)) for t, vals in categories_out.items() + } + + # encode per patient; if gt missing -> all zeros + encoded: list[dict[str, Tensor]] = [] + for gt in gts: + patient_encoded: dict[str, Tensor] = {} + for t in target_names: + cats = categories_out[t] + if not isinstance(gt, dict) or gt.get(t) is None: + one_hot = torch.zeros(len(cats), dtype=torch.float32) + else: + one_hot = torch.tensor( + [gt[t] == c for c in cats], + dtype=torch.float32, + ) + patient_encoded[t] = one_hot + encoded.append(patient_encoded) + + # IMPORTANT: return categories as mapping, not list-of-target-names + return encoded, categories_out + + # single target + unique = {gt for gt in gts if gt is not None} + if len(unique) >= 2 or categories is not None: + raw = np.array([p.ground_truth for p in patient_data]) + categories = categories or list(sorted(unique)) + labels = torch.tensor( + raw.reshape(-1, 1) == categories, + dtype=torch.float32, + ) + return labels, categories + + raise ValueError( + "Only one unique class found in classification task. " + "This is usually a data or configuration error." + ) + + elif task == "regression": + y = torch.tensor( + [np.nan if gt is None else float(gt) for gt in gts], + dtype=torch.float32, + ).reshape(-1, 1) + return y, [] + + elif task == "survival": + times, events = [], [] + for gt in gts: + if gt is None: times.append(np.nan) events.append(np.nan) + continue + # Expect a structured tuple/list (time, event). + if isinstance(gt, (tuple, list)) and len(gt) == 2: + t_val, e_val = gt + times.append( + np.nan + if t_val is None or str(t_val).lower() == "nan" + else float(t_val) + ) + events.append(float(e_val) if e_val is not None else np.nan) + else: + raise ValueError( + "survival ground truth must be a (time, event) tuple/list" + ) - # Final tensor shape: (N, 2) y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) - - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=y, - transform=transform, - ) - cats_out: Sequence[Category] = [] # survival has no categories + return y, [] else: - raise ValueError(f"Unknown task: {task}") - - return ( - cast( - DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - DataLoader( - ds, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - collate_fn=_collate_to_tuple, - worker_init_fn=Seed.get_loader_worker_init() - if Seed._is_set() - else None, - ), - ), - cats_out, - ) + raise ValueError(f"Unsupported task: {task}") def _collate_to_tuple( @@ -210,6 +278,24 @@ def _collate_to_tuple( return (bags, coords, bag_sizes, encoded_targets) +def _collate_multitarget( + items: list[tuple[_Bag, _Coordinates, BagSize, Dict[str, Tensor]]], +) -> tuple[Bags, CoordinatesBatch, BagSizes, Dict[str, Tensor]]: + bags = torch.stack([b for b, _, _, _ in items]) + coords = torch.stack([c for _, c, _, _ in items]) + bag_sizes = torch.tensor([s for _, _, s, _ in items]) + + acc: Dict[str, List[Tensor]] = {} + + for _, _, _, tdict in items: + for k, v in tdict.items(): + acc.setdefault(k, []).append(v) + + targets: Dict[str, Tensor] = {k: torch.stack(v, dim=0) for k, v in acc.items()} + + return bags, coords, bag_sizes, targets + + def patient_feature_dataloader( *, patient_data: Sequence[PatientData[GroundTruth | None]], @@ -237,21 +323,31 @@ def create_dataloader( *, feature_type: str, task: Task, - patient_data: Sequence[PatientData[GroundTruth | None]], + patient_data: Sequence[PatientData[GroundTruth | None | dict]], bag_size: int | None = None, batch_size: int, shuffle: bool, num_workers: int, transform: Callable[[Tensor], Tensor] | None, - categories: Sequence[Category] | None = None, -) -> tuple[DataLoader, Sequence[Category]]: + categories: Sequence[Category] | Mapping[str, Sequence[Category]] | None = None, +) -> tuple[DataLoader, Sequence[Category] | Mapping[str, Sequence[Category]]]: """Unified dataloader for all feature types and tasks.""" if feature_type == "tile": + # For multi-target classification, categories may be a mapping from + # target name to per-target categories. _parse_targets (used inside + # tile_bag_dataloader) only consumes explicit categories for the + # single-target case, so we pass a sequence or None here. + cats_arg: Sequence[Category] | None + if isinstance(categories, Mapping): + cats_arg = None + else: + cats_arg = categories + return tile_bag_dataloader( patient_data=patient_data, bag_size=bag_size, task=task, - categories=categories, + categories=cats_arg, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, @@ -266,19 +362,46 @@ def create_dataloader( categories = categories or list(np.unique(raw)) labels = torch.tensor(raw.reshape(-1, 1) == categories, dtype=torch.float32) elif task == "regression": - labels = torch.tensor( - [ - float(gt) - for gt in (p.ground_truth for p in patient_data) - if gt is not None - ], - dtype=torch.float32, - ).reshape(-1, 1) + values: list[float] = [] + for gt in (p.ground_truth for p in patient_data): + if gt is None: + continue + if isinstance(gt, dict): + raise ValueError( + "Multi-target regression is not supported; provide a single numeric target per patient" + ) + values.append(float(gt)) + + labels = torch.tensor(values, dtype=torch.float32).reshape(-1, 1) elif task == "survival": times, events = [], [] for p in patient_data: - t, e = (p.ground_truth or "nan nan").split(" ", 1) - times.append(float(t) if t.lower() != "nan" else np.nan) + if isinstance(p.ground_truth, dict): + raise ValueError( + "Multi-target survival is not supported; provide a single survival time/status per patient" + ) + gt = p.ground_truth + # Prefer structured (time, event) tuples. Do NOT call .split + # on the ground-truth value. If not a tuple/list, treat time + # as the whole value and event as unknown. + if isinstance(gt, (tuple, list)) and len(gt) == 2: + t, e = gt + elif gt is None: + t, e = None, None + else: + t, e = str(gt), "nan" + + # Parse time defensively + if t is None: + times.append(np.nan) + elif isinstance(t, str): + try: + times.append(np.nan if t.lower() == "nan" else float(t)) + except Exception: + times.append(np.nan) + else: + times.append(float(t)) + events.append(_parse_survival_status(e)) labels = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) @@ -292,6 +415,8 @@ def create_dataloader( shuffle=shuffle, num_workers=num_workers, worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, + persistent_workers=(num_workers > 0), + pin_memory=torch.cuda.is_available(), ) return dl, categories or [] else: @@ -340,7 +465,7 @@ def load_patient_level_data( clini_table: Path, feature_dir: Path, patient_label: PandasLabel, - ground_truth_label: PandasLabel | None = None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None, time_label: PandasLabel | None = None, status_label: PandasLabel | None = None, feature_ext: str = ".h5", @@ -355,6 +480,15 @@ def load_patient_level_data( """ # Load ground truth mapping + # Multi-target ground truths are only supported for classification. + if task is not None and task != "classification": + if isinstance(ground_truth_label, Sequence) and not isinstance( + ground_truth_label, str + ): + raise ValueError( + "Multi-target ground_truth_label is only supported for classification tasks" + ) + if task == "survival" and time_label is not None and status_label is not None: # Survival: use the existing helper patient_to_ground_truth = patient_to_survival_from_clini_table_( @@ -419,18 +553,32 @@ class BagDataset(Dataset[tuple[_Bag, _Coordinates, BagSize, _EncodedTarget]]): If `bag_size` is None, all the samples will be used. """ - ground_truths: Float[Tensor, "index category_is_hot"] | Float[Tensor, "index 1"] + ground_truths: Tensor | list[dict[str, Tensor]] # ground_truths: Bool[Tensor, "index category_is_hot"] # """The ground truth for each bag, one-hot encoded.""" transform: Callable[[Tensor], Tensor] | None + deterministic: bool = False def __post_init__(self) -> None: if len(self.bags) != len(self.ground_truths): raise ValueError( "the number of ground truths has to match the number of bags" ) + # Initialise per-worker HDF5 handle cache here so __getitem__ avoids + # a hasattr() call on every tile read. + self._h5_handle_cache: OrderedDict[FeaturePath | _BinaryIOLike, h5py.File] = ( + OrderedDict() + ) + + def __getstate__(self) -> dict: + # h5py file handles cannot be pickled (required when DataLoader uses + # spawn-based multiprocessing). Drop the cache; each worker reopens + # files lazily on the first __getitem__ access. + state = self.__dict__.copy() + state["_h5_handle_cache"] = OrderedDict() + return state def __len__(self) -> int: return len(self.bags) @@ -442,14 +590,46 @@ def __getitem__( feats = [] coords_um = [] for bag_file in self.bags[index]: - with h5py.File(bag_file, "r") as h5: - if "feats" in h5: - arr = h5["feats"][:] # pyright: ignore[reportIndexIssue] # original STAMP files - else: - arr = h5["patch_embeddings"][:] # type: ignore # your Kronos files + if bag_file not in self._h5_handle_cache: + # Limit open handles to avoid reaching OS ulimits + if len(self._h5_handle_cache) >= 128: + _, h = self._h5_handle_cache.popitem(last=False) + h.close() + + try: + # libver='latest' and swmr=True can provide better performance + # on some network/HPC filesystems + self._h5_handle_cache[bag_file] = h5py.File( + bag_file, "r", swmr=True, libver="latest" + ) + except Exception: + # Fallback for older HDF5 files or unconventional storage + self._h5_handle_cache[bag_file] = h5py.File(bag_file, "r") + else: + # Move recently accessed file to end (mark as recently used) + self._h5_handle_cache.move_to_end(bag_file) + + h5 = self._h5_handle_cache[bag_file] + + if "feats" in h5: + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + arr = feats_obj[ + () + ] # uses [()] instead of [:] for clarity, both read entire dataset + else: + embeddings_obj = h5["patch_embeddings"] + if not isinstance(embeddings_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'patch_embeddings' to be an HDF5 dataset but got {type(embeddings_obj)}" + ) + arr = embeddings_obj[()] # your Kronos files - feats.append(torch.from_numpy(arr)) - coords_um.append(torch.from_numpy(get_coords(h5).coords_um)) + feats.append(torch.from_numpy(arr)) + coords_um.append(torch.from_numpy(get_coords(h5).coords_um)) feats = torch.concat(feats).float() coords_um = torch.concat(coords_um).float() @@ -460,7 +640,12 @@ def __getitem__( # Sample a subset, if required if self.bag_size is not None: return ( - *_to_fixed_size_bag(feats, coords=coords_um, bag_size=self.bag_size), + *_to_fixed_size_bag( + feats, + coords=coords_um, + bag_size=self.bag_size, + deterministic=self.deterministic, + ), self.ground_truths[index], ) else: @@ -480,7 +665,7 @@ class PatientFeatureDataset(Dataset): def __init__( self, - feature_files: Sequence[FeaturePath | BinaryIO], + feature_files: Sequence[FeaturePath | _BinaryIOLike], ground_truths: Tensor, # shape: [num_samples, num_classes] transform: Callable[[Tensor], Tensor] | None, ): @@ -489,26 +674,53 @@ def __init__( self.feature_files = feature_files self.ground_truths = ground_truths self.transform = transform + # Initialise per-worker HDF5 handle cache eagerly so __getitem__ avoids + # a hasattr() call on every sample read. + self._h5_handle_cache: dict[FeaturePath | _BinaryIOLike, h5py.File] = {} + + def __getstate__(self) -> dict: + # h5py file handles cannot be pickled (required when DataLoader uses + # spawn-based multiprocessing). Drop the cache; each worker reopens + # files lazily on the first __getitem__ access. + state = self.__dict__.copy() + state["_h5_handle_cache"] = {} + return state def __len__(self): return len(self.feature_files) def __getitem__(self, idx: int): feature_file = self.feature_files[idx] - with h5py.File(feature_file, "r") as h5: - feats = torch.from_numpy(h5["feats"][:]) # pyright: ignore[reportIndexIssue] - # Accept [V] or [1, V] - if feats.ndim == 2 and feats.shape[0] == 1: - feats = feats[0] - elif feats.ndim == 1: - pass - else: - raise RuntimeError( - f"Expected single feature vector (shape [F] or [1, F]), got {feats.shape} in {feature_file}." - "Check that the features are patient-level." + if feature_file not in self._h5_handle_cache: + if len(self._h5_handle_cache) >= 128: + _, h = self._h5_handle_cache.popitem() + h.close() + try: + self._h5_handle_cache[feature_file] = h5py.File( + feature_file, "r", swmr=True, libver="latest" ) - if self.transform is not None: - feats = self.transform(feats) + except Exception: + self._h5_handle_cache[feature_file] = h5py.File(feature_file, "r") + + h5 = self._h5_handle_cache[feature_file] + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + feats = torch.from_numpy(feats_obj[()]) + # Accept [V] or [1, V] + if feats.ndim == 2 and feats.shape[0] == 1: + feats = feats[0] + elif feats.ndim == 1: + pass + else: + raise RuntimeError( + f"Expected single feature vector (shape [F] or [1, F]), got {feats.shape} in {feature_file}." + "Check that the features are patient-level." + ) + if self.transform is not None: + feats = self.transform(feats) label = self.ground_truths[idx] return feats, label @@ -529,7 +741,7 @@ def mpp(self) -> SlideMPP: def get_coords(feature_h5: h5py.File) -> CoordsInfo: - # --- NEW: handle missing coords ----multiplex data bypass: no coords found; generated fake coords + # NEW: handle missing coords - multiplex data bypass: no coords found; generated fake coords if "coords" not in feature_h5: feats_obj = feature_h5["patch_embeddings"] @@ -545,10 +757,15 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: tile_size_px = TilePixels(0) return CoordsInfo(coords_um, tile_size_um, tile_size_px) - coords: np.ndarray = feature_h5["coords"][:] # type: ignore - coords_um: np.ndarray | None = None + coords_obj = feature_h5["coords"] + if not isinstance(coords_obj, h5py.Dataset): + raise RuntimeError( + f"{feature_h5.filename}: expected 'coords' to be an HDF5 dataset but got {type(coords_obj)}" + ) + coords: np.ndarray = coords_obj[:] tile_size_um: Microns | None = None tile_size_px: TilePixels | None = None + coords_um: np.ndarray | None = None if (tile_size := feature_h5.attrs.get("tile_size", None)) and feature_h5.attrs.get( "unit", None ) == "um": @@ -594,7 +811,7 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: def _to_fixed_size_bag( - bag: _Bag, coords: _Coordinates, bag_size: BagSize + bag: _Bag, coords: _Coordinates, bag_size: BagSize, deterministic: bool = False ) -> tuple[_Bag, _Coordinates, BagSize]: """Samples a fixed-size bag of tiles from an arbitrary one. @@ -603,23 +820,47 @@ def _to_fixed_size_bag( """ # get up to bag_size elements n_tiles, _dim_feats = bag.shape - bag_idxs = torch.randperm(n_tiles)[:bag_size] + if n_tiles <= bag_size: + # take all and pad later + bag_idxs = torch.arange(n_tiles, device=bag.device) + else: + if deterministic: + # equidistant indices across the bag + idxs = torch.linspace(0, n_tiles - 1, steps=bag_size, device=bag.device) + bag_idxs = idxs.round().long() + else: + bag_idxs = torch.randperm(n_tiles, device=bag.device)[:bag_size] + bag_samples = bag[bag_idxs] coord_samples = coords[bag_idxs] # zero-pad if we don't have enough samples - zero_padded_bag = torch.cat( - ( - bag_samples, - torch.zeros(bag_size - bag_samples.shape[0], bag_samples.shape[1]), + if bag_samples.shape[0] < bag_size: + zero_padded_bag = torch.cat( + ( + bag_samples, + torch.zeros( + bag_size - bag_samples.shape[0], + bag_samples.shape[1], + device=bag.device, + dtype=bag.dtype, + ), + ) ) - ) - zero_padded_coord = torch.cat( - ( - coord_samples, - torch.zeros(bag_size - coord_samples.shape[0], coord_samples.shape[1]), + zero_padded_coord = torch.cat( + ( + coord_samples, + torch.zeros( + bag_size - coord_samples.shape[0], + coord_samples.shape[1], + device=coords.device, + dtype=coords.dtype, + ), + ) ) - ) + else: + zero_padded_bag = bag_samples + zero_padded_coord = coord_samples return zero_padded_bag, zero_padded_coord, min(bag_size, len(bag)) @@ -627,33 +868,71 @@ def patient_to_ground_truth_from_clini_table_( *, clini_table_path: Path | TextIO, patient_label: PandasLabel, - ground_truth_label: PandasLabel, -) -> dict[PatientId, GroundTruth]: - """Loads the patients and their ground truths from a clini table.""" + ground_truth_label: PandasLabel | Sequence[PandasLabel], +) -> ( + dict[PatientId, GroundTruth | None] | dict[PatientId, dict[str, GroundTruth | None]] +): + """Loads the patients and their ground truths from a clini table. + + `ground_truth_label` may be either a single column name (str) or a sequence + of column names. In the latter case the returned mapping will contain a + dict mapping column -> value for each patient (supporting multi-target + setups). + """ + # Normalize to list for uniform handling + if isinstance(ground_truth_label, str): + cols = [patient_label, ground_truth_label] + multi = False + target_cols_inner: list[PandasLabel] = [] + else: + cols = [patient_label, *list(ground_truth_label)] + multi = True + target_cols_inner = [c for c in cols if c != patient_label] + clini_df = read_table( clini_table_path, - usecols=[patient_label, ground_truth_label], + usecols=cols, dtype=str, - ).dropna() + ) + + # If multi-target, keep rows where at least one target is present; for + # single target behave like before and drop rows missing the value. + if multi: + clini_df = clini_df.dropna(subset=target_cols_inner, how="all") + else: + clini_df = clini_df.dropna(subset=[ground_truth_label]) + try: - patient_to_ground_truth: Mapping[PatientId, GroundTruth] = clini_df.set_index( - patient_label, verify_integrity=True - )[ground_truth_label].to_dict() + if multi: + # Build mapping patient -> {col: value} + result: dict[PatientId, dict[str, GroundTruth | None]] = {} + for _, row in clini_df.iterrows(): + pid = row[patient_label] + # Convert pandas nan to None and keep strings otherwise + result[pid] = { + col: (None if pd.isna(row[col]) else str(row[col])) + for col in target_cols_inner + } + return result + else: + patient_to_ground_truth: Mapping[PatientId, str] = cast( + Mapping[PatientId, str], + clini_df.set_index(patient_label, verify_integrity=True)[ + cast(PandasLabel, ground_truth_label) + ].to_dict(), + ) + return cast(dict[PatientId, GroundTruth | None], patient_to_ground_truth) except KeyError as e: if patient_label not in clini_df: raise ValueError( f"{patient_label} was not found in clini table " f"(columns in clini table: {clini_df.columns})" ) from e - elif ground_truth_label not in clini_df: + else: raise ValueError( - f"{ground_truth_label} was not found in clini table " + f"One or more ground truth columns were not found in clini table " f"(columns in clini table: {clini_df.columns})" ) from e - else: - raise e from e - - return patient_to_ground_truth def patient_to_survival_from_clini_table_( @@ -662,14 +941,14 @@ def patient_to_survival_from_clini_table_( patient_label: PandasLabel, time_label: PandasLabel, status_label: PandasLabel, -) -> dict[PatientId, GroundTruth]: +) -> dict[PatientId, tuple[float | None, int | None]]: """ Loads patients and their survival ground truths (time + event) from a clini table. Returns - ------- dict[PatientId, GroundTruth] - Mapping patient_id -> "time status" (e.g. "302 dead", "476 alive"). + Mapping patient_id -> tuple (time, event) where `time` is a numeric follow-up + value and `event` is an integer indicator (1=event/death, 0=censored). """ clini_df = read_table( clini_table_path, @@ -705,7 +984,7 @@ def patient_to_survival_from_clini_table_( # Only drop rows where BOTH time and status are missing clini_df = clini_df.dropna(subset=[time_label, status_label], how="all") - patient_to_ground_truth: dict[PatientId, GroundTruth] = {} + patient_to_ground_truth: dict[PatientId, tuple[float | None, int | None]] = {} for _, row in clini_df.iterrows(): pid = row[patient_label] time_str = row[time_label] @@ -718,10 +997,9 @@ def patient_to_survival_from_clini_table_( # Encode status: keep both dead (event=1) and alive (event=0) status = _parse_survival_status(status_str) - # Encode back to "alive"/"dead" like before - # status = "dead" if status_val == 1 else "alive" - - patient_to_ground_truth[pid] = f"{time_str} {status}" + # Store structured ground-truth as (time, event) to avoid repeated string parsing + time_val = None if pd.isna(time_str) else float(time_str) + patient_to_ground_truth[pid] = (time_val, status) return patient_to_ground_truth @@ -780,7 +1058,10 @@ def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: def filter_complete_patient_data_( *, - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], + patient_to_ground_truth: Mapping[ + PatientId, + GroundTruth | dict[str, GroundTruth] | tuple[float | None, int | None] | None, + ], slide_to_patient: Mapping[FeaturePath, PatientId], drop_patients_with_missing_ground_truth: bool, ) -> Mapping[PatientId, PatientData]: @@ -860,12 +1141,15 @@ def _log_patient_slide_feature_inconsistencies( if slides_without_features := { slide for slide in slide_to_patient.keys() if not slide.exists() }: + # Log only the filenames (not full paths) to keep warnings concise. + slides_list = sorted(s.name for s in slides_without_features) _logger.warning( - f"some feature files could not be found: {slides_without_features}" + "some feature files could not be found: %s", + ", ".join(slides_list), ) -def get_stride(coords: Float[Tensor, "tile 2"]) -> float: +def get_stride(coords: Tensor) -> float: """Gets the minimum step width between any two coordintes.""" xs: Tensor = coords[:, 0].unique(sorted=True) ys: Tensor = coords[:, 1].unique(sorted=True) @@ -919,26 +1203,139 @@ def _parse_survival_status(value) -> int | None: ) +def load_patient_data_( + *, + feature_dir: Path, + clini_table: Path, + slide_table: Path | None, + task: Task, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + time_label: PandasLabel | None, + status_label: PandasLabel | None, + patient_label: PandasLabel, + filename_label: PandasLabel, + drop_patients_with_missing_ground_truth: bool = True, +) -> tuple[Mapping[PatientId, PatientData], str]: + """Load patient data based on feature type (tile, slide, or patient). + + This consolidates the common data loading logic used across train, crossval, and deploy. + + Returns: + (patient_to_data, feature_type) + """ + feature_type = detect_feature_type(feature_dir) + + if feature_type in ("tile", "slide"): + if slide_table is None: + raise ValueError("A slide table is required for tile/slide-level features") + + # Load ground truth based on task + if task == "survival": + if time_label is None or status_label is None: + raise ValueError( + "Both time_label and status_label are required for survival modeling" + ) + patient_to_ground_truth = patient_to_survival_from_clini_table_( + clini_table_path=clini_table, + time_label=time_label, + status_label=status_label, + patient_label=patient_label, + ) + else: + if ground_truth_label is None: + raise ValueError( + "Ground truth label is required for classification or regression modeling" + ) + # Disallow multi-target ground truth for non-classification tasks + if ( + task != "classification" + and isinstance(ground_truth_label, Sequence) + and not isinstance(ground_truth_label, str) + ): + raise ValueError( + "Multi-target ground_truth_label is only supported for classification tasks" + ) + patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( + clini_table_path=clini_table, + ground_truth_label=ground_truth_label, + patient_label=patient_label, + ) + + # Link slides to patients + slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( + slide_to_patient_from_slide_table_( + slide_table_path=slide_table, + feature_dir=feature_dir, + patient_label=patient_label, + filename_label=filename_label, + ) + ) + + # Filter to complete patient data + patient_to_data = filter_complete_patient_data_( + patient_to_ground_truth=cast( + Mapping[PatientId, GroundTruth | dict[str, GroundTruth] | None], + patient_to_ground_truth, + ), + slide_to_patient=slide_to_patient, + drop_patients_with_missing_ground_truth=drop_patients_with_missing_ground_truth, + ) + elif feature_type == "patient": + patient_to_data = load_patient_level_data( + task=task, + clini_table=clini_table, + feature_dir=feature_dir, + patient_label=patient_label, + ground_truth_label=ground_truth_label, + time_label=time_label, + status_label=status_label, + ) + else: + raise RuntimeError(f"Unknown feature type: {feature_type}") + + return patient_to_data, feature_type + + def log_patient_class_summary( *, patient_to_data: Mapping[PatientId, PatientData], categories: Sequence[Category] | None, - prefix: str = "", ) -> None: + """ + Logs class distribution. + Supports both single-target and multi-target classification. + """ + ground_truths = [ - pd.ground_truth - for pd in patient_to_data.values() - if pd.ground_truth is not None + p.ground_truth for p in patient_to_data.values() if p.ground_truth is not None ] if not ground_truths: - _logger.warning(f"{prefix}No ground truths available to summarize.") + _logger.warning("No ground truths available for summary.") return - cats = categories or sorted(set(ground_truths)) - counter = Counter(ground_truths) + # Multi-target + if isinstance(ground_truths[0], dict): + # Collect per-target values + per_target: dict[str, list] = {} - _logger.info( - f"{prefix}Total patients: {len(ground_truths)} | " - + " | ".join([f"Class {c}: {counter.get(c, 0)}" for c in cats]) - ) + for gt in ground_truths: + for key, value in gt.items(): + per_target.setdefault(key, []).append(value) + + for target_name, values in per_target.items(): + counts = {} + for v in values: + counts[v] = counts.get(v, 0) + 1 + + _logger.info( + f"[Multi-target] Target '{target_name}' distribution: {counts}" + ) + + # Single-target + else: + counts = {} + for gt in ground_truths: + counts[gt] = counts.get(gt, 0) + 1 + + _logger.info(f"Class distribution: {counts}") diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 905c6005..e0444f15 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd import torch -from jaxtyping import Float +import torch.nn.functional as F from lightning.pytorch.accelerators.accelerator import Accelerator from stamp.modeling.data import ( @@ -20,7 +20,13 @@ slide_to_patient_from_slide_table_, ) from stamp.modeling.registry import ModelName, load_model_class -from stamp.types import GroundTruth, PandasLabel, PatientId +from stamp.types import ( + Category, + GroundTruth, + PandasLabel, + PatientId, + SurvivalGroundTruth, +) __all__ = ["deploy_categorical_model_"] @@ -32,6 +38,13 @@ Logit: TypeAlias = float +# Prediction type aliases +PredictionSingle: TypeAlias = torch.Tensor +PredictionMulti: TypeAlias = dict[str, torch.Tensor] +PredictionsType: TypeAlias = Mapping[ + PatientId, Union[PredictionSingle, PredictionMulti] +] + def load_model_from_ckpt(path: Union[str, Path]): ckpt = torch.load(path, map_location="cpu", weights_only=False) @@ -50,7 +63,7 @@ def deploy_categorical_model_( clini_table: Path | None, slide_table: Path | None, feature_dir: Path, - ground_truth_label: PandasLabel | None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, time_label: PandasLabel | None, status_label: PandasLabel | None, patient_label: PandasLabel, @@ -126,7 +139,12 @@ def deploy_categorical_model_( # classification/regression: still use ground_truth_label if ( len( - ground_truth_labels := set(model.ground_truth_label for model in models) + ground_truth_labels := { + tuple(model.ground_truth_label) + if isinstance(model.ground_truth_label, list) + else (model.ground_truth_label,) + for model in models + } ) != 1 ): @@ -145,17 +163,21 @@ def deploy_categorical_model_( f"{ground_truth_label} vs {model_ground_truth_label}" ) - ground_truth_label = ground_truth_label or model_ground_truth_label + ground_truth_label = ground_truth_label or cast( + PandasLabel, model_ground_truth_label + ) output_dir.mkdir(exist_ok=True, parents=True) model_categories = None if task == "classification": # Ensure the categories were the same between all models - category_sets = {tuple(m.categories) for m in models} + category_sets = { + tuple(cast(Sequence[GroundTruth], m.categories)) for m in models + } if len(category_sets) != 1: raise RuntimeError(f"Categories differ between models: {category_sets}") - model_categories = list(models[0].categories) + model_categories = list(cast(Sequence[GroundTruth], models[0].categories)) # Data loading logic if feature_type in ("tile", "slide"): @@ -171,6 +193,15 @@ def deploy_categorical_model_( ) if clini_table is not None: if task == "survival": + if not hasattr(models[0], "time_label") or not isinstance( + models[0].time_label, str + ): + raise AttributeError("Model is missing valid 'time_label' (str).") + if not hasattr(models[0], "status_label") or not isinstance( + models[0].status_label, str + ): + raise AttributeError("Model is missing valid 'status_label' (str).") + patient_to_ground_truth = patient_to_survival_from_clini_table_( clini_table_path=clini_table, patient_label=patient_label, @@ -192,7 +223,10 @@ def deploy_categorical_model_( patient_id: None for patient_id in set(slide_to_patient.values()) } patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, + patient_to_ground_truth=cast( + Mapping[PatientId, GroundTruth | dict[str, GroundTruth] | None], + patient_to_ground_truth, + ), slide_to_patient=slide_to_patient, drop_patients_with_missing_ground_truth=False, ) @@ -241,16 +275,55 @@ def deploy_categorical_model_( "regression": _to_regression_prediction_df, "survival": _to_survival_prediction_df, }[task] - all_predictions: list[Mapping[PatientId, Float[torch.Tensor, "category"]]] = [] # noqa: F821 + all_predictions: list[PredictionsType] = [] + categories_for_export: ( + Sequence[Category] | Mapping[str, Sequence[Category]] | None + ) = cast(Sequence[Category] | Mapping[str, Sequence[Category]] | None, None) for model_i, model in enumerate(models): + # Check for data leakage: if the deployment patient set overlaps with + # the patients used during model training/validation, log a critical + # message. This check is intentionally performed at the deploy level + # (not inside `_predict`) so prediction helpers can be reused without + # side-effects in other contexts (e.g., cross-validation). + patients_used_for_training: set[PatientId] = set( + getattr(model, "train_patients", []) + ) | set(getattr(model, "valid_patients", [])) + if overlap := patients_used_for_training & set(patient_ids): + _logger.critical( + "DATA LEAKAGE DETECTED: %d patient(s) in deployment set were used " + "during training/validation. Overlapping IDs: %s", + len(overlap), + sorted(overlap), + ) + predictions = _predict( model=model, - test_dl=test_dl, # pyright: ignore[reportPossiblyUnboundVariable] + test_dl=test_dl, patient_ids=patient_ids, accelerator=accelerator, ) all_predictions.append(predictions) + if isinstance(next(iter(predictions.values())), dict): + # Multi-target case: gather categories across all targets for export (use model categories if available, else infer from GT) + categories_accum: dict[str, set[GroundTruth]] = {} + + for pd_item in patient_to_data.values(): + gt = pd_item.ground_truth + if isinstance(gt, dict): + for k, v in gt.items(): + if v is not None: + categories_accum.setdefault(k, set()).add(v) + + categories_for_export = {k: sorted(v) for k, v in categories_accum.items()} + + else: + # Single-target case: use categories from model if available, else infer from GT + if task == "classification": + categories_for_export = models[0].categories + else: + categories_for_export = [] + # cut-off values from survival ckpt cut_off = ( getattr(model.hparams, "train_pred_median", None) @@ -261,7 +334,7 @@ def deploy_categorical_model_( # Only save individual model files when deploying multiple models (ensemble) if len(models) > 1: df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, patient_label=patient_label, @@ -270,7 +343,7 @@ def deploy_categorical_model_( ).to_csv(output_dir / f"patient-preds-{model_i}.csv", index=False) else: df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, patient_label=patient_label, @@ -279,17 +352,29 @@ def deploy_categorical_model_( ).to_csv(output_dir / "patient-preds.csv", index=False) if task == "classification": - # TODO we probably also want to save the 95% confidence interval in addition to the mean + # compute mean prediction across models (supports single- and multi-target) + mean_preds: dict[PatientId, object] = {} + for pid in patient_ids: + model_preds = cast( + list[torch.Tensor], [preds[pid] for preds in all_predictions] + ) + firstp = model_preds[0] + if isinstance(firstp, dict): + # per-target averaging + mean_preds[pid] = { + t: torch.stack([p[t] for p in model_preds]).mean(dim=0) + for t in firstp.keys() + } + else: + mean_preds[pid] = torch.stack(model_preds).mean(dim=0) + + assert categories_for_export is not None, ( + "categories_for_export must be set before use" + ) df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, - predictions={ - # Mean prediction - patient_id: torch.stack( - [predictions[patient_id] for predictions in all_predictions] - ).mean(dim=0) - for patient_id in patient_ids - }, + predictions=mean_preds, patient_label=patient_label, ground_truth_label=ground_truth_label, ).to_csv(output_dir / "patient-preds_95_confidence_interval.csv", index=False) @@ -301,18 +386,11 @@ def _predict( test_dl: torch.utils.data.DataLoader, patient_ids: Sequence[PatientId], accelerator: str | Accelerator, -) -> Mapping[PatientId, Float[torch.Tensor, "..."]]: +) -> PredictionsType: model = model.eval() torch.set_float32_matmul_precision("medium") - # Check for data leakage - patients_used_for_training: set[PatientId] = set( - getattr(model, "train_patients", []) - ) | set(getattr(model, "valid_patients", [])) - if overlap := patients_used_for_training & set(patient_ids): - raise ValueError( - f"some of the patients in the validation set were used during training: {overlap}" - ) + # Note: data-leakage check intentionally performed at deploy level. trainer = lightning.Trainer( accelerator=accelerator, @@ -320,51 +398,190 @@ def _predict( logger=False, ) - raw_preds = torch.concat(cast(list[torch.Tensor], trainer.predict(model, test_dl))) + outs = trainer.predict(model, test_dl) + + if not outs: + return {} + + first = outs[0] + + # Multi-target case: each element of outs is a dict[target_label -> tensor] + if isinstance(first, dict): + per_target_lists: dict[str, list[torch.Tensor]] = {} + for out in outs: + if not isinstance(out, dict): + raise RuntimeError("Mixed prediction output types from model") + for k, v in out.items(): + per_target_lists.setdefault(k, []).append(v) + + per_target_tensors: dict[str, torch.Tensor] = { + k: torch.cat(vlist, dim=0) for k, vlist in per_target_lists.items() + } + + if getattr(model.hparams, "task", None) == "classification": + for k in list(per_target_tensors.keys()): + per_target_tensors[k] = torch.softmax(per_target_tensors[k], dim=1) + + # build per-patient dicts + num_preds = next(iter(per_target_tensors.values())).shape[0] + predictions: dict[PatientId, dict[str, torch.Tensor]] = {} + for i, pid in enumerate(patient_ids[:num_preds]): + predictions[pid] = { + k: per_target_tensors[k][i] for k in per_target_tensors.keys() + } + + return predictions + + # Single-target case: each element of outs is a tensor + outs_single = cast(list[torch.Tensor], outs) + + raw_preds = torch.cat(outs_single, dim=0) if getattr(model.hparams, "task", None) == "classification": - predictions = torch.softmax(raw_preds, dim=1) + raw_preds = torch.softmax(raw_preds, dim=1) elif getattr(model.hparams, "task", None) == "survival": - predictions = raw_preds.squeeze(-1) # (N,) risk scores - else: # regression - predictions = raw_preds + raw_preds = raw_preds.squeeze(-1) + + result: dict[PatientId, torch.Tensor] = { + pid: raw_preds[i] for i, pid in enumerate(patient_ids) + } - return dict(zip(patient_ids, predictions, strict=True)) + return result def _to_prediction_df( *, - categories: Sequence[GroundTruth], - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], - predictions: Mapping[PatientId, torch.Tensor], + categories: Sequence[GroundTruth] | Mapping[str, Sequence[GroundTruth]], + patient_to_ground_truth: Mapping[PatientId, GroundTruth | None] + | Mapping[PatientId, dict[str, GroundTruth | None]], + predictions: Mapping[PatientId, torch.Tensor] + | Mapping[PatientId, dict[str, torch.Tensor]], patient_label: PandasLabel, - ground_truth_label: PandasLabel, + ground_truth_label: PandasLabel | Sequence[PandasLabel], **kwargs, ) -> pd.DataFrame: - """Compiles deployment results into a DataFrame.""" - return pd.DataFrame( - [ - { - patient_label: patient_id, - ground_truth_label: patient_to_ground_truth.get(patient_id), - "pred": categories[int(prediction.argmax())], - **{ - f"{ground_truth_label}_{category}": prediction[i_cat].item() - for i_cat, category in enumerate(categories) - }, - "loss": ( - torch.nn.functional.cross_entropy( - prediction.reshape(1, -1), - torch.tensor(np.where(np.array(categories) == ground_truth)[0]), - ).item() - if (ground_truth := patient_to_ground_truth.get(patient_id)) - is not None - else None - ), - } - for patient_id, prediction in predictions.items() - ] - ).sort_values(by="loss") + """Compiles deployment results into a DataFrame. + + Supports single-target and multi-target classification. + - Single-target: `predictions` maps patient -> tensor and `categories` is a sequence. + - Multi-target: `predictions` maps patient -> dict[target_label -> tensor] and + `categories` is a mapping from target_label -> sequence of category names. + """ + first_pred = next(iter(predictions.values())) + + # Multi-target predictions: dict per patient + if isinstance(first_pred, dict): + # determine target labels + target_labels = list(cast(dict, first_pred).keys()) + + # prepare categories mapping + if isinstance(categories, dict): + cats_map = categories + else: + # try infer categories list ordering: assume categories is a sequence-of-sequences + cats_map = {} + if isinstance(categories, Sequence): + try: + for i, t in enumerate(target_labels): + cats_map[t] = list( + cast(Sequence[Sequence[GroundTruth]], categories)[i] + ) + except Exception: + cats_map = {} + + # infer missing category lists from ground truth + if any(t not in cats_map for t in target_labels): + inferred: dict[str, set] = {t: set() for t in target_labels} + for pid, gt in patient_to_ground_truth.items(): + if isinstance(gt, dict): + for t in target_labels: + val = gt.get(t) + if val is not None: + inferred[t].add(val) + for t in target_labels: + if t not in cats_map: + cats_map[t] = sorted(inferred.get(t, [])) + + rows = [] + for pid, pred_dict in predictions.items(): + row: dict = {patient_label: pid} + gt_entry = patient_to_ground_truth.get(pid) + # ground truths per target + for t in target_labels: + if isinstance(gt_entry, dict): + row[t] = gt_entry.get(t) + else: + row[t] = gt_entry + + total_loss = 0.0 + has_loss = False + for t in target_labels: + tensor = cast(dict[str, torch.Tensor], pred_dict)[t] + probs = tensor.detach().cpu() + cats: Sequence[GroundTruth] = cast( + Sequence[GroundTruth], + cats_map.get(t, []), + ) + if probs.numel() == 1: + row[f"pred_{t}"] = float(probs.item()) + else: + pred_idx = int(probs.argmax().item()) + row[f"pred_{t}"] = ( + cats[pred_idx] if pred_idx < len(cats) else pred_idx + ) + for i_cat, cat in enumerate(cats): + if i_cat < probs.shape[0]: + row[f"{t}_{cat}"] = float(probs[i_cat].item()) + else: + row[f"{t}_{cat}"] = None + + if isinstance(gt_entry, dict) and (gt := gt_entry.get(t)) is not None: + try: + target_index = int(np.where(np.array(cats) == gt)[0][0]) + loss = torch.nn.functional.cross_entropy( + probs.reshape(1, -1), torch.tensor([target_index]) + ).item() + total_loss += loss + has_loss = True + except Exception: + pass + + row["loss"] = total_loss if has_loss else None + rows.append(row) + + return pd.DataFrame(rows) + + # Single-target (original behaviour) + if not all(isinstance(p, torch.Tensor) for p in predictions.values()): + raise TypeError("Single-target block received multi-target dict predictions.") + + predictions = cast(Mapping[PatientId, torch.Tensor], predictions) + + rows = [] + for pid, prediction in predictions.items(): + gt = patient_to_ground_truth.get(pid) + cats = cast(Sequence[GroundTruth], categories) + pred_idx = int(prediction.argmax()) + row = { + patient_label: pid, + ground_truth_label: gt, + "pred": cats[pred_idx], + **{ + f"{ground_truth_label}_{category}": float(prediction[i_cat].item()) + for i_cat, category in enumerate(cats) + }, + "loss": ( + torch.nn.functional.cross_entropy( + prediction.reshape(1, -1), + torch.tensor(np.where(np.array(cats) == gt)[0]), + ).item() + if gt is not None + else None + ), + } + rows.append(row) + + return pd.DataFrame(rows).sort_values(by="loss") def _to_regression_prediction_df( @@ -383,8 +600,6 @@ def _to_regression_prediction_df( - pred (float) - loss (per-sample L1 loss if GT available, else None) """ - import torch.nn.functional as F - return pd.DataFrame( [ { @@ -419,7 +634,9 @@ def _to_regression_prediction_df( def _to_survival_prediction_df( *, - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], + patient_to_ground_truth: Mapping[ + PatientId, GroundTruth | SurvivalGroundTruth | None + ], predictions: Mapping[PatientId, torch.Tensor], patient_label: PandasLabel, cut_off: float | None = None, @@ -448,21 +665,11 @@ def _to_survival_prediction_df( else: row["pred_score"] = pred.cpu().tolist() - # Ground truth: time + event - if gt is not None: - if isinstance(gt, str) and " " in gt: - time_str, status_str = gt.split(" ", 1) - row["time"] = float(time_str) if time_str.lower() != "nan" else None - if status_str.lower() in {"dead", "event", "1"}: - row["event"] = 1 - elif status_str.lower() in {"alive", "censored", "0"}: - row["event"] = 0 - else: - row["event"] = None - elif isinstance(gt, (tuple, list)) and len(gt) == 2: - row["time"], row["event"] = gt - else: - row["time"], row["event"] = None, None + # Ground truth: prefer structured tuple/list (time, event). Do not + # call .split on ground-truth values — assume structured input. If + # the value is not a 2-tuple/list, treat both fields as unknown. + if isinstance(gt, (tuple, list)) and len(gt) == 2: + row["time"], row["event"] = gt else: row["time"], row["event"] = None, None diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index 59a0a3aa..64894ccc 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -8,12 +8,21 @@ import lightning import numpy as np import torch + +# Use beartype.typing.Mapping to avoid PEP-585 deprecation warnings in beartype +from beartype.typing import Mapping from jaxtyping import Bool, Float +from lifelines.utils import concordance_index as lifelines_cindex from packaging.version import Version from torch import Tensor, nn, optim from torchmetrics.classification import MulticlassAUROC import stamp +from stamp.modeling.models.barspoon import ( + EncDecTransformer, + LitMilClassificationMixin, + TargetLabel, +) from stamp.modeling.models.cox import neg_partial_log_likelihood from stamp.types import ( Bags, @@ -80,7 +89,7 @@ def __init__( # This should only happen when the model is loaded, # otherwise the default value will make these checks pass. # TODO: Change this on version change - if stamp_version < Version("2.4.0"): + if stamp_version < Version("2.4.1"): # Update this as we change our model in incompatible ways! raise ValueError( f"model has been built with stamp version {stamp_version} " @@ -143,6 +152,19 @@ def on_train_batch_end(self, outputs, batch, batch_idx): ) +class _TileLevelMixin: + """Mixin for tile-level models providing shared MIL masking logic.""" + + @staticmethod + def _mask_from_bags(bags: Bags, bag_sizes: BagSizes) -> Bool[Tensor, "batch tile"]: + """Create attention mask for padded tiles in variable-length bags.""" + max_possible_bag_size = bags.size(1) + mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( + 0 + ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) + return mask + + class LitBaseClassifier(Base): """ PyTorch Lightning wrapper for tile level and patient level clasification. @@ -194,15 +216,16 @@ def __init__( self.class_weights = category_weights self.valid_auroc = MulticlassAUROC(len(categories)) # Number classes - self.categories = np.array(categories) + self.categories = list(categories) self.hparams.update({"task": "classification"}) -class LitTileClassifier(LitBaseClassifier): +class LitTileClassifier(_TileLevelMixin, LitBaseClassifier): """ - PyTorch Lightning wrapper for the model used in weakly supervised - learning settings, such as Multiple Instance Learning (MIL) for whole-slide images or patch-based data. + PyTorch Lightning wrapper for tile-level MIL classification. + + Used in weakly supervised settings for whole-slide images or patch-based data. """ supported_features = ["tile"] @@ -216,7 +239,7 @@ def forward( def _step( self, *, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], step_name: str, use_mask: bool, ) -> Loss: @@ -244,7 +267,6 @@ def _step( ) if step_name == "validation": - # TODO this is a bit ugly, we'd like to have `_step` without special cases self.valid_auroc.update(logits, targets.long().argmax(dim=-1)) self.log( f"{step_name}_auroc", @@ -258,59 +280,51 @@ def _step( def training_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="training", use_mask=False) def validation_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="validation", use_mask=False) def test_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="test", use_mask=False) def predict_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Float[Tensor, "batch logit"]: bags, coords, bag_sizes, _ = batch # adding a mask here will *drastically* and *unbearably* increase memory usage + # Ensure input dtype matches model weights to avoid dtype-mismatch errors + param_dtype = next(self.model.parameters()).dtype + bags = bags.to(dtype=param_dtype) + coords = coords.to(dtype=param_dtype) return self.model(bags, coords=coords, mask=None) - def _mask_from_bags( - *, - bags: Bags, - bag_sizes: BagSizes, - ) -> Bool[Tensor, "batch tile"]: - max_possible_bag_size = bags.size(1) - mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( - 0 - ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) - - return mask - class LitSlideClassifier(LitBaseClassifier): - """ - PyTorch Lightning wrapper for MLPClassifier. - """ + """PyTorch Lightning wrapper for slide/patient-level classification.""" supported_features = ["slide"] def forward(self, x: Tensor) -> Tensor: return self.model(x) - def _step(self, batch, step_name: str): - feats, targets = batch + def _step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], step_name: str + ) -> Loss: + feats, targets = list(batch) # Works for both tuple and list logits = self.model(feats.float()) loss = nn.functional.cross_entropy( logits, @@ -336,17 +350,28 @@ def _step(self, batch, step_name: str): ) return loss - def training_step(self, batch, batch_idx): + def training_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "training") - def validation_step(self, batch, batch_idx): + def validation_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "validation") - def test_step(self, batch, batch_idx): + def test_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "test") - def predict_step(self, batch, batch_idx): - feats, _ = batch + def predict_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Tensor: + feats, _ = batch if isinstance(batch, tuple) else batch + # Cast inputs to model parameter dtype to avoid Half/Float mismatches + param_dtype = next(self.model.parameters()).dtype + feats = feats.to(dtype=param_dtype) return self.model(feats) @@ -397,9 +422,10 @@ def _compute_loss(y_true: Tensor, y_pred: Tensor) -> Loss: return nn.functional.l1_loss(y_true, y_pred) -class LitTileRegressor(LitBaseRegressor): +class LitTileRegressor(_TileLevelMixin, LitBaseRegressor): """ - PyTorch Lightning wrapper for weakly supervised / MIL regression at tile/patient level. + PyTorch Lightning wrapper for tile-level MIL regression. + Produces a single continuous output per bag (dim_output = 1). """ @@ -418,7 +444,7 @@ def forward( def _step( self, *, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], step_name: str, use_mask: bool, ) -> Loss: @@ -458,66 +484,50 @@ def _step( def training_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="training", use_mask=False) def validation_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="validation", use_mask=False) def test_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="test", use_mask=False) def predict_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Float[Tensor, "batch 1"]: bags, coords, bag_sizes, _ = batch # keep memory usage low as in classifier return self.model(bags, coords=coords, mask=None) - def _mask_from_bags( - *, - bags: Bags, - bag_sizes: BagSizes, - ) -> Bool[Tensor, "batch tile"]: - max_possible_bag_size = bags.size(1) - mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( - 0 - ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) - - return mask - class LitSlideRegressor(LitBaseRegressor): - """ - PyTorch Lightning wrapper for slide-level or patient-level regression. - Produces a single continuous output per slide (dim_output = 1). - """ + """PyTorch Lightning wrapper for slide/patient-level regression.""" supported_features = ["slide"] def forward(self, feats: Tensor) -> Tensor: - """Forward pass for slide-level features.""" return self.model(feats.float()) def _step( self, *, - batch: tuple[Tensor, Tensor], + batch: tuple[Tensor, Tensor] | list[Tensor], step_name: str, ) -> Loss: - feats, targets = batch + feats, targets = list(batch) # Works for both tuple and list preds = self.model(feats.float(), mask=None) # (B, 1) y = targets.to(preds).float() @@ -534,7 +544,6 @@ def _step( ) if step_name == "validation": - # same metrics as LitTileRegressor p = preds.squeeze(-1) t = y.squeeze(-1) self.log( @@ -547,17 +556,25 @@ def _step( return loss - def training_step(self, batch, batch_idx): + def training_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="training") - def validation_step(self, batch, batch_idx): + def validation_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="validation") - def test_step(self, batch, batch_idx): + def test_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="test") - def predict_step(self, batch, batch_idx): - feats, _ = batch + def predict_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Tensor: + feats, _ = batch if isinstance(batch, tuple) else batch return self.model(feats.float()) @@ -645,27 +662,36 @@ def cox_loss( def c_index( scores: torch.Tensor, times: torch.Tensor, events: torch.Tensor ) -> torch.Tensor: - # """ - # Concordance index: proportion of correctly ordered comparable pairs. - # """ - N = len(times) - if N <= 1: + """ + Concordance index using lifelines implementation for consistency with statistics. + + Uses the same convention as stamp.statistics.survival: + - Higher risk scores should correspond to shorter survival (worse outcome). + - We negate scores so lifelines interprets higher values as longer survival. + """ + # Convert to numpy for lifelines + scores_np = scores.detach().cpu().numpy().flatten() + times_np = times.detach().cpu().numpy().flatten() + events_np = events.detach().cpu().numpy().flatten() + + # Filter out NaN values + valid_mask = ~(np.isnan(times_np) | np.isnan(events_np) | np.isnan(scores_np)) + if valid_mask.sum() <= 1: return torch.tensor(float("nan"), device=scores.device) - t_i = times.view(-1, 1).expand(N, N) - t_j = times.view(1, -1).expand(N, N) - e_i = events.view(-1, 1).expand(N, N) + times_np = times_np[valid_mask] + events_np = events_np[valid_mask] + scores_np = scores_np[valid_mask] - mask = (t_i < t_j) & e_i.bool() - if mask.sum() == 0: + # Use lifelines concordance_index with negated risk (same as statistics module) + # lifelines expects: higher predicted value = longer survival + # Cox outputs: higher risk = shorter survival, so we negate + try: + ci = lifelines_cindex(times_np, -scores_np, events_np) + except Exception: return torch.tensor(float("nan"), device=scores.device) - s_i = scores.view(-1, 1).expand(N, N)[mask] - s_j = scores.view(1, -1).expand(N, N)[mask] - - conc = (s_i > s_j).float() - ties = (s_i == s_j).float() * 0.5 - return (conc + ties).sum() / mask.sum() + return torch.tensor(ci, device=scores.device, dtype=scores.dtype) def on_validation_epoch_end(self): if ( @@ -702,12 +728,12 @@ def on_train_epoch_end(self): self.hparams.update({"train_pred_median": self.train_pred_median}) -class LitTileSurvival(LitSurvivalBase): +class LitTileSurvival(_TileLevelMixin, LitSurvivalBase): """ - Tile-level or patch-level survival analysis. - Expects dataloader batches like: - (bags, coords, bag_sizes, targets) - where targets is shape (B,2): [:,0]=time, [:,1]=event (1=event, 0=censored). + Tile-level survival analysis with Cox proportional hazards loss. + + Expects batches: (bags, coords, bag_sizes, targets) + where targets.shape = (B, 2): [:,0]=time, [:,1]=event (0=censored, 1=event). """ supported_features = ["tile"] @@ -722,7 +748,11 @@ def forward( # (most ViT backbones accept coords/mask even if unused) return self.model(bags, coords=coords, mask=mask) - def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], + batch_idx: int, + ) -> Loss: bags, coords, bag_sizes, targets = batch preds = self.model(bags, coords=coords, mask=None) y = targets.to(preds.device, dtype=torch.float32) @@ -747,7 +777,7 @@ def training_step(self, batch, batch_idx): def validation_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Any: bags, coords, bag_sizes, targets = batch @@ -761,9 +791,13 @@ def validation_step( self._val_times.append(times.detach().cpu()) self._val_events.append(events.detach().cpu()) - def predict_step(self, batch, batch_idx): - feats, coords, n_tiles, survival_target = batch - return self.model(feats.float(), coords=coords, mask=None) + def predict_step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], + batch_idx: int, + ) -> Float[Tensor, "batch 1"]: + bags, coords, bag_sizes, survival_target = batch + return self.model(bags, coords=coords, mask=None) class LitSlideSurvival(LitSurvivalBase): @@ -818,3 +852,86 @@ class LitPatientSurvival(LitSlideSurvival): """ supported_features = ["patient"] + + +class LitEncDecTransformer(LitMilClassificationMixin): + def __init__( + self, + *, + dim_input: int, + category_weights: Mapping[TargetLabel, torch.Tensor], + model_class: type[nn.Module] | None = None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + categories: Mapping[str, Sequence[Category]], + # Model parameters + d_model: int = 512, + num_encoder_heads: int = 8, + num_decoder_heads: int = 8, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + dim_feedforward: int = 2048, + positional_encoding: bool = True, + # Other hparams + learning_rate: float = 1e-4, + # Deployment metadata (optional) — keep parity with `Base` + train_patients: Iterable[PatientId] = (), + valid_patients: Iterable[PatientId] = (), + stamp_version: Version = Version(stamp.__version__), + **hparams: Any, + ) -> None: + weights_dict: dict[TargetLabel, torch.Tensor] = dict(category_weights) + super().__init__( + weights=weights_dict, + learning_rate=learning_rate, + ) + _ = hparams # so we don't get unused parameter warnings + + self.model = EncDecTransformer( + d_features=dim_input, + target_n_outs={t: len(w) for t, w in category_weights.items()}, + d_model=d_model, + num_encoder_heads=num_encoder_heads, + num_decoder_heads=num_decoder_heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dim_feedforward=dim_feedforward, + positional_encoding=positional_encoding, + ) + + self.hparams["supported_features"] = "tile" + self.hparams.update({"task": "classification"}) + # ---- Normalize categories into strict mapping[str, list[str]] ---- + if not isinstance(categories, Mapping): + raise ValueError( + "Multi-target classification requires categories as Mapping[str, Sequence[str]]." + ) + + normalized_categories: dict[str, list[str]] = { + str(k): list(v) for k, v in categories.items() + } + + # Sanity check: head size must match category size + for t, w in category_weights.items(): + if t not in normalized_categories: + raise ValueError(f"Missing categories for target '{t}'") + if len(normalized_categories[t]) != len(w): + raise ValueError( + f"Category mismatch for target '{t}': " + f"{len(normalized_categories[t])} categories " + f"but head has {len(w)} outputs." + ) + + self.ground_truth_label = ground_truth_label + self.categories = normalized_categories + + # Deployment metadata — mirror `Base` behavior so checkpoints include + # train/valid patient lists and stamp version for leak-detection and + # compatibility checks. + self.train_patients = train_patients + self.valid_patients = valid_patients + self.stamp_version = str(stamp_version) + + self.save_hyperparameters() + + def forward(self, *args): + return self.model(*args) diff --git a/src/stamp/modeling/models/barspoon.py b/src/stamp/modeling/models/barspoon.py new file mode 100644 index 00000000..e33d3e3f --- /dev/null +++ b/src/stamp/modeling/models/barspoon.py @@ -0,0 +1,367 @@ +""" +Port from https://github.com/KatherLab/barspoon-transformer +""" + +import re +from typing import Any, TypeAlias + +import lightning +import torch +import torch.nn.functional as F +import torchmetrics +from packaging.version import Version +from torch import nn +from torchmetrics.classification import MulticlassAUROC +from torchmetrics.utilities.data import dim_zero_cat + +import stamp +from stamp.types import Bags, BagSizes, CoordinatesBatch + +__all__ = [ + "EncDecTransformer", + "LitMilClassificationMixin", + "SafeMulticlassAUROC", +] + + +TargetLabel: TypeAlias = str + + +class EncDecTransformer(nn.Module): + """An encoder decoder architecture for multilabel classification tasks + + This architecture is a modified version of the one found in [Attention Is + All You Need][1]: First, we project the features into a lower-dimensional + feature space, to prevent the transformer architecture's complexity from + exploding for high-dimensional features. We add sinusodial [positional + encodings][1]. We then encode these projected input tokens using a + transformer encoder stack. Next, we decode these tokens using a set of + class tokens, one per output label. Finally, we forward each of the decoded + tokens through a fully connected layer to get a label-wise prediction. + + PE1 + | + +--+ v +---+ + t1 ->|FC|--+-->| |--+ + . +--+ | E | | + . | x | | + . +--+ | m | | + tn ->|FC|--+-->| |--+ + +--+ ^ +---+ | + | | + PEn v + +---+ +---+ + c1 ---------------->| |-->|FC1|--> s1 + . | D | +---+ . + . | x | . + . | l | +---+ . + ck ---------------->| |-->|FCk|--> sk + +---+ +---+ + + We opted for this architecture instead of a more traditional [Vision + Transformer][2] to improve performance for multi-label predictions with many + labels. Our experiments have shown that adding too many class tokens to a + vision transformer decreases its performance, as the same weights have to + both process the tiles' information and the class token's processing. Using + an encoder-decoder architecture alleviates these issues, as the data-flow of + the class tokens is completely independent of the encoding of the tiles. + Furthermore, analysis has shown that there is almost no interaction between + the different classes in the decoder. While this points to the decoder + being more powerful than needed in practice, this also means that each + label's prediction is mostly independent of the others. As a consequence, + noisy labels will not negatively impact the accuracy of non-noisy ones. + + In our experiments so far we did not see any improvement by adding + positional encodings. We tried + + 1. [Sinusodal encodings][1] + 2. Adding absolute positions to the feature vector, scaled down so the + maximum value in the training dataset is 1. + + Since neither reduced performance and the author perceives the first one to + be more elegant (as the magnitude of the positional encodings is bounded), + we opted to keep the positional encoding regardless in the hopes of it + improving performance on future tasks. + + The architecture _differs_ from the one described in [Attention Is All You + Need][1] as follows: + + 1. There is an initial projection stage to reduce the dimension of the + feature vectors and allow us to use the transformer with arbitrary + features. + 2. Instead of the language translation task described in [Attention Is All + You Need][1], where the tokens of the words translated so far are used + to predict the next word in the sequence, we use a set of fixed, learned + class tokens in conjunction with equally as many independent fully + connected layers to predict multiple labels at once. + + [1]: https://arxiv.org/abs/1706.03762 "Attention Is All You Need" + [2]: https://arxiv.org/abs/2010.11929 + "An Image is Worth 16x16 Words: + Transformers for Image Recognition at Scale" + """ + + def __init__( + self, + d_features: int, + target_n_outs: dict[str, int], + *, + d_model: int = 512, + num_encoder_heads: int = 8, + num_decoder_heads: int = 8, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + dim_feedforward: int = 2048, + positional_encoding: bool = True, + ) -> None: + super().__init__() + + self.projector = nn.Sequential(nn.Linear(d_features, d_model), nn.ReLU()) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=num_encoder_heads, + dim_feedforward=dim_feedforward, + batch_first=True, + norm_first=True, + ) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layer, num_layers=num_encoder_layers, enable_nested_tensor=False + ) + + self.target_labels = target_n_outs.keys() + + # One class token per output label + self.class_tokens = nn.ParameterDict( + { + sanitize(target_label): torch.rand(d_model) + for target_label in target_n_outs + } + ) + + decoder_layer = nn.TransformerDecoderLayer( + d_model=d_model, + nhead=num_decoder_heads, + dim_feedforward=dim_feedforward, + batch_first=True, + norm_first=True, + ) + self.transformer_decoder = nn.TransformerDecoder( + decoder_layer, num_layers=num_decoder_layers + ) + + self.heads = nn.ModuleDict( + { + sanitize(target_label): nn.Linear( + in_features=d_model, out_features=n_out + ) + for target_label, n_out in target_n_outs.items() + } + ) + + self.positional_encoding = positional_encoding + + def forward( + self, + tile_tokens: torch.Tensor, + tile_positions: torch.Tensor, + ) -> dict[str, torch.Tensor]: + batch_size, _, _ = tile_tokens.shape + + tile_tokens = self.projector(tile_tokens) # shape: [bs, seq_len, d_model] + + if self.positional_encoding: + # Add positional encodings + d_model = tile_tokens.size(-1) + x = tile_positions.unsqueeze(-1) / 100_000 ** ( + torch.arange(d_model // 4).type_as(tile_positions) / d_model + ) + positional_encodings = torch.cat( + [ + torch.sin(x).flatten(start_dim=-2), + torch.cos(x).flatten(start_dim=-2), + ], + dim=-1, + ) + tile_tokens = tile_tokens + positional_encodings + + tile_tokens = self.transformer_encoder(tile_tokens) + + class_tokens = torch.stack( + [self.class_tokens[sanitize(t)] for t in self.target_labels] + ).expand(batch_size, -1, -1) + class_tokens = self.transformer_decoder(tgt=class_tokens, memory=tile_tokens) + + # Apply the corresponding head to each class token + logits = { + target_label: self.heads[sanitize(target_label)](class_token) + for target_label, class_token in zip( + self.target_labels, + class_tokens.permute(1, 0, 2), # Permute to [target, batch, d_model] + strict=True, + ) + } + + return logits + + +class LitMilClassificationMixin(lightning.LightningModule): + """Makes a module into a multilabel, multiclass Lightning one""" + + supported_features = ["tile"] + + def __init__( + self, + *, + weights: dict[TargetLabel, torch.Tensor], + # Other hparams + learning_rate: float = 1e-4, + stamp_version: Version = Version(stamp.__version__), + **hparams: Any, + ) -> None: + super().__init__() + _ = hparams # So we don't get unused parameter warnings + + # Check if version is compatible. + if stamp_version < Version("2.4.1"): + # Update this as we change our model in incompatible ways! + raise ValueError( + f"model has been built with stamp version {stamp_version} " + f"which is incompatible with the current version." + ) + elif stamp_version > Version(stamp.__version__): + # Let's be strict with models "from the future", + # better fail deadly than have broken results. + raise ValueError( + "model has been built with a stamp version newer than the installed one " + f"({stamp_version} > {stamp.__version__}). " + "Please upgrade stamp to a compatible version." + ) + + self.hparams.update({"task": "classification"}) + + self.learning_rate = learning_rate + + target_aurocs = torchmetrics.MetricCollection( + { + sanitize(target_label): SafeMulticlassAUROC(num_classes=len(weight)) + for target_label, weight in weights.items() + } + ) + for step_name in ["train", "validation", "test"]: + setattr( + self, + f"{step_name}_target_aurocs", + target_aurocs.clone(prefix=f"{step_name}_"), + ) + + self.weights = weights + + self.save_hyperparameters() + + def step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, dict[str, torch.Tensor]] | list, + step_name=None, + ): + """Process a batch with structure (feats, coords, bag_sizes, targets). + + Args: + batch: Tuple of (feats, coords, bag_sizes, targets) where: + - feats: bag features [batch, bag_size, feature_dim] + - coords: tile coordinates [batch, bag_size, 2] + - bag_sizes: number of tiles per bag [batch] + - targets: dict mapping target names to one-hot encoded tensors [batch, num_classes] + step_name: Optional step name for logging ('train', 'validation', 'test'). + """ + feats: Bags + coords: CoordinatesBatch + bag_sizes: BagSizes + targets: dict[str, torch.Tensor] + feats, coords, bag_sizes, targets = batch + logits = self(feats, coords) + + # Calculate the cross entropy loss for each target, then sum them + loss = sum( + F.cross_entropy( + (logit := logits[target_label]), + targets[target_label].type_as(logit), + weight=weight.type_as(logit), + ) + for target_label, weight in self.weights.items() + ) + + if step_name: + self.log( + f"{step_name}_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + # Update target-wise metrics + for target_label in self.weights: + target_auroc = getattr(self, f"{step_name}_target_aurocs")[ + sanitize(target_label) + ] + is_na = (targets[target_label] == 0).all(dim=1) + target_auroc.update( + logits[target_label][~is_na], + targets[target_label][~is_na].argmax(dim=1), + ) + self.log( + f"{step_name}_{target_label}_auroc", + target_auroc, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + + return loss + + def training_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="train") + + def validation_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="validation") + + def test_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="test") + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + if len(batch) == 2: + feats, positions = batch + else: + feats, positions, _, _ = batch + + logits = self(feats, positions) + + softmaxed = { + target_label: torch.softmax(x, 1) for target_label, x in logits.items() + } + return softmaxed + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + +def sanitize(x: str) -> str: + return re.sub(r"[^A-Za-z0-9_]", "_", x) + + +class SafeMulticlassAUROC(MulticlassAUROC): + """A Multiclass AUROC that doesn't blow up when no targets are given""" + + def compute(self) -> torch.Tensor: + # Add faux entry if there are none so far + if len(self.preds) == 0: + self.update(torch.zeros(1, self.num_classes), torch.zeros(1).long()) + elif len(dim_zero_cat(self.preds)) == 0: + self.update( + torch.zeros(1, self.num_classes).type_as(self.preds[0]), + torch.zeros(1).long().type_as(self.target[0]), + ) + return super().compute() diff --git a/src/stamp/modeling/models/mlp.py b/src/stamp/modeling/models/mlp.py index e4f8881f..e88a77ca 100644 --- a/src/stamp/modeling/models/mlp.py +++ b/src/stamp/modeling/models/mlp.py @@ -29,7 +29,7 @@ def __init__( layers.append(nn.Dropout(dropout)) in_dim = dim_hidden layers.append(nn.Linear(in_dim, dim_output)) - self.mlp = nn.Sequential(*layers) # type: ignore + self.mlp = nn.Sequential(*layers) @jaxtyped(typechecker=beartype) def forward( diff --git a/src/stamp/modeling/models/vision_tranformer.py b/src/stamp/modeling/models/vision_tranformer.py index b936c5c9..fcd60c12 100644 --- a/src/stamp/modeling/models/vision_tranformer.py +++ b/src/stamp/modeling/models/vision_tranformer.py @@ -56,9 +56,7 @@ def forward( Which query-key pairs to mask from ALiBi (i.e. don't apply ALiBi to). """ weight_logits = torch.einsum("bqf,bkf->bqk", q, k) * (k.size(-1) ** -0.5) - distances = torch.linalg.norm( - coords_q.unsqueeze(2) - coords_k.unsqueeze(1), dim=-1 - ) + distances = torch.cdist(coords_q, coords_k) scaled_distances = self.scale_distance(distances) * self.bias_scale if alibi_mask is not None: diff --git a/src/stamp/modeling/registry.py b/src/stamp/modeling/registry.py index 2205af22..14011084 100644 --- a/src/stamp/modeling/registry.py +++ b/src/stamp/modeling/registry.py @@ -1,6 +1,7 @@ from enum import StrEnum from stamp.modeling.models import ( + LitEncDecTransformer, LitPatientClassifier, LitPatientRegressor, LitPatientSurvival, @@ -21,6 +22,7 @@ class ModelName(StrEnum): MLP = "mlp" TRANS_MIL = "trans_mil" LINEAR = "linear" + BARSPOON = "barspoon" # Map (feature_type, task) → correct Lightning wrapper class @@ -34,6 +36,7 @@ class ModelName(StrEnum): ("patient", "classification"): LitPatientClassifier, ("patient", "regression"): LitPatientRegressor, ("patient", "survival"): LitPatientSurvival, + # ("tile", "multiclass"): LitEncDecTransformer, } @@ -54,6 +57,13 @@ def load_model_class(task: Task, feature_type: str, model_name: ModelName): case ModelName.MLP: from stamp.modeling.models.mlp import MLP as ModelClass + case ModelName.BARSPOON: + from stamp.modeling.models.barspoon import ( + EncDecTransformer as ModelClass, + ) + + LitModelClass = LitEncDecTransformer + case ModelName.LINEAR: from stamp.modeling.models.mlp import ( Linear as ModelClass, diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index b855e2a0..a55ec1e3 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -2,7 +2,7 @@ import shutil from collections.abc import Callable, Mapping, Sequence from pathlib import Path -from typing import cast +from typing import Any, cast import lightning import torch @@ -17,14 +17,9 @@ BagDataset, PatientData, PatientFeatureDataset, + _parse_survival_status, create_dataloader, - detect_feature_type, - filter_complete_patient_data_, - load_patient_level_data, - log_patient_class_summary, - patient_to_ground_truth_from_clini_table_, - patient_to_survival_from_clini_table_, - slide_to_patient_from_slide_table_, + load_patient_data_, ) from stamp.modeling.registry import ModelName, load_model_class from stamp.modeling.transforms import VaryPrecisionTransform @@ -53,66 +48,25 @@ def train_categorical_model_( advanced: AdvancedConfig, ) -> None: """Trains a model based on the feature type.""" - feature_type = detect_feature_type(config.feature_dir) - _logger.info(f"Detected feature type: {feature_type}") - - if feature_type in ("tile", "slide"): - if config.slide_table is None: - raise ValueError("A slide table is required for modeling") - if config.task == "survival": - if config.time_label is None or config.status_label is None: - raise ValueError( - "Both time_label and status_label is required for survival modeling" - ) - patient_to_ground_truth = patient_to_survival_from_clini_table_( - clini_table_path=config.clini_table, - time_label=config.time_label, - status_label=config.status_label, - patient_label=config.patient_label, - ) - else: - if config.ground_truth_label is None: - raise ValueError( - "Ground truth label is required for tile-level modeling" - ) - patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( - clini_table_path=config.clini_table, - ground_truth_label=config.ground_truth_label, - patient_label=config.patient_label, - ) - slide_to_patient = slide_to_patient_from_slide_table_( - slide_table_path=config.slide_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - filename_label=config.filename_label, - ) - patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, - ) - elif feature_type == "patient": - # Patient-level: ignore slide_table - if config.slide_table is not None: - _logger.warning("slide_table is ignored for patient-level features.") - - patient_to_data = load_patient_level_data( - task=config.task, - clini_table=config.clini_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - ) - else: - raise RuntimeError(f"Unknown feature type: {feature_type}") - if config.task is None: raise ValueError( "task must be set to 'classification' | 'regression' | 'survival'" ) + patient_to_data, feature_type = load_patient_data_( + feature_dir=config.feature_dir, + clini_table=config.clini_table, + slide_table=config.slide_table, + task=config.task, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + patient_label=config.patient_label, + filename_label=config.filename_label, + drop_patients_with_missing_ground_truth=True, + ) + _logger.info(f"Detected feature type: {feature_type}") + # Train the model (the rest of the logic is unchanged) model, train_dl, valid_dl = setup_model_for_training( patient_to_data=patient_to_data, @@ -145,14 +99,14 @@ def train_categorical_model_( def setup_model_for_training( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], task: Task, categories: Sequence[Category] | None, train_transform: Callable[[torch.Tensor], torch.Tensor] | None, feature_type: str, advanced: AdvancedConfig, # Metadata, has no effect on model training - ground_truth_label: PandasLabel | None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, time_label: PandasLabel | None, status_label: PandasLabel | None, clini_table: Path, @@ -193,9 +147,123 @@ def setup_model_for_training( feature_type=feature_type, train_categories=train_categories, ) - log_patient_class_summary( - patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, - categories=categories, + + # 1. Default to a model if none is specified + if advanced.model_name is None: + advanced.model_name = ModelName.VIT if feature_type == "tile" else ModelName.MLP + _logger.info( + f"No model specified, defaulting to '{advanced.model_name.value}' for feature type '{feature_type}'" + ) + + # Prevent selecting `barspoon` for single-target classification + if ( + task == "classification" + and isinstance(ground_truth_label, str) + and advanced.model_name == ModelName.BARSPOON + ): + raise ValueError( + "Model 'barspoon' requires multi-target classification. " + "For single-target classification set model_name to 'vit', 'trans_mil', or 'mlp'." + ) + + # 2. Instantiate the lightning wrapper (based on provided task, feature type) and model backbone dynamically + LitModelClass, ModelClass = load_model_class( + task, feature_type, advanced.model_name + ) + + # 3. Validate that the chosen model supports the feature type + if feature_type not in LitModelClass.supported_features: + raise ValueError( + f"Model '{advanced.model_name.value}' does not support feature type '{feature_type}'. " + f"Supported types are: {LitModelClass.supported_features}" + ) + elif feature_type in ( + "slide", + "patient", + ) and advanced.model_name.value.lower() not in {"mlp", "linear"}: + raise ValueError( + f"Feature type '{feature_type}' only supports MLP or Linear. " + f"Got '{advanced.model_name.value}'. Please set model_name='mlp' or 'linear'." + ) + + # 4. Get model-specific hyperparameters + model_specific_params = ( + advanced.model_params.model_dump().get(advanced.model_name.value) or {} + ) + + # 5. Calculate total steps for scheduler + steps_per_epoch = len(train_dl) + total_steps = steps_per_epoch * advanced.max_epochs + + # 6. Prepare common parameters + common_params = { + "categories": train_categories, + "category_weights": category_weights, + "dim_input": dim_feats, + "total_steps": total_steps, + "max_lr": advanced.max_lr, + "div_factor": advanced.div_factor, + # Metadata, has no effect on model training + "model_name": advanced.model_name.value, + "ground_truth_label": ground_truth_label, + "time_label": time_label, + "status_label": status_label, + "train_patients": train_patients, + "valid_patients": valid_patients, + "clini_table": clini_table, + "slide_table": slide_table, + "feature_dir": feature_dir, + } + + all_params = {**common_params, **model_specific_params} + + _logger.info( + f"Instantiating model '{advanced.model_name.value}' with parameters: {model_specific_params}" + ) + _logger.info( + "Other params: max_epochs=%s, patience=%s", + advanced.max_epochs, + advanced.patience, + ) + + model = LitModelClass(model_class=ModelClass, **all_params) + + return model, train_dl, valid_dl + + +def setup_model_from_dataloaders( + *, + train_dl: DataLoader, + valid_dl: DataLoader, + task: Task, + train_categories: Sequence[Category] | Mapping[str, Sequence[Category]], + dim_feats: int, + train_patients: Sequence[PatientId], + valid_patients: Sequence[PatientId], + feature_type: str, + advanced: AdvancedConfig, + # Metadata, has no effect on model training + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + time_label: PandasLabel | None, + status_label: PandasLabel | None, + clini_table: Path, + slide_table: Path | None, + feature_dir: Path, +) -> lightning.LightningModule: + """Creates a model from pre-built dataloaders (no internal split).""" + + _logger.info( + "Training dataloaders: task=%s, feature_type=%s", + task, + feature_type, + ) + + category_weights: torch.Tensor | dict[str, torch.Tensor] | list = [] + if task == "classification": + category_weights = _compute_class_weights_and_check_categories( + train_dl=train_dl, + feature_type=feature_type, + train_categories=train_categories, ) # 1. Default to a model if none is specified @@ -205,6 +273,17 @@ def setup_model_for_training( f"No model specified, defaulting to '{advanced.model_name.value}' for feature type '{feature_type}'" ) + # Prevent selecting `barspoon` for single-target classification + if ( + task == "classification" + and isinstance(ground_truth_label, str) + and advanced.model_name == ModelName.BARSPOON + ): + raise ValueError( + "Model 'barspoon' requires multi-target classification. " + "For single-target classification set model_name to 'vit', 'trans_mil', or 'mlp'." + ) + # 2. Instantiate the lightning wrapper (based on provided task, feature type) and model backbone dynamically LitModelClass, ModelClass = load_model_class( task, feature_type, advanced.model_name @@ -267,12 +346,12 @@ def setup_model_for_training( model = LitModelClass(model_class=ModelClass, **all_params) - return model, train_dl, valid_dl + return model def setup_dataloaders_for_training( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], task: Task, categories: Sequence[Category] | None, bag_size: int, @@ -283,7 +362,7 @@ def setup_dataloaders_for_training( ) -> tuple[ DataLoader, DataLoader, - Sequence[Category], + Sequence[Category] | Mapping[str, Sequence[Category]], int, Sequence[PatientId], Sequence[PatientId], @@ -309,11 +388,52 @@ def setup_dataloaders_for_training( "patient_to_data must have a ground truth defined for all targets!" ) + # Multi-target ground truths are only supported for classification + if task != "classification" and any(isinstance(gt, dict) for gt in ground_truths): + raise ValueError( + "Multi-target ground truths are only supported for classification tasks" + ) + if task == "classification": - stratify = ground_truths + # Handle both single and multi-target cases + if ground_truths and isinstance(ground_truths[0], dict): + # Multi-target: use first target for stratification + first_key = list(ground_truths[0].keys())[0] + stratify = [cast(dict, gt)[first_key] for gt in ground_truths] + else: + stratify = ground_truths elif task == "survival": - # Extract event indicator (status) - statuses = [int(gt.split()[1]) for gt in ground_truths] + # Extract event indicator (status). Accept either structured (time,event) + # or legacy string "time status" formats. + statuses: list[int] = [] + for gt in ground_truths: + if isinstance(gt, dict): + raise ValueError( + "Multi-target survival is not supported; provide a single survival time/status per patient" + ) + if isinstance(gt, (tuple, list)) and len(gt) == 2: + status_val = gt[1] + if status_val is None: + raise ValueError( + "Missing survival status for a patient; cannot stratify" + ) + statuses.append(int(status_val)) + else: + parts = str(gt).split() + if len(parts) >= 2: + status_token = parts[1] + elif len(parts) == 1: + status_token = parts[0] + else: + raise ValueError( + "Unrecognized survival ground-truth format for stratification" + ) + parsed_status = _parse_survival_status(status_token) + if parsed_status is None: + raise ValueError( + f"Unrecognized survival status token for stratification: {status_token!r}" + ) + statuses.append(int(parsed_status)) stratify = statuses elif task == "regression": stratify = None @@ -321,7 +441,10 @@ def setup_dataloaders_for_training( train_patients, valid_patients = cast( tuple[Sequence[PatientId], Sequence[PatientId]], train_test_split( - list(patient_to_data), stratify=stratify, shuffle=True, random_state=0 + list(patient_to_data), + stratify=cast(Any, stratify), + shuffle=True, + random_state=0, ), ) @@ -441,28 +564,50 @@ def _compute_class_weights_and_check_categories( *, train_dl: DataLoader, feature_type: str, - train_categories: Sequence[str], -) -> torch.Tensor: + train_categories: Sequence[str] | Mapping[str, Sequence[str]], +) -> torch.Tensor | dict[str, torch.Tensor]: """ Computes class weights and checks for category issues. Logs warnings if there are too few or underpopulated categories. Returns normalized category weights as a torch.Tensor. """ if feature_type == "tile": - category_counts = cast(BagDataset, train_dl.dataset).ground_truths.sum(dim=0) + dataset = cast(BagDataset, train_dl.dataset) + + if isinstance(dataset.ground_truths, list): + # Multi-target case: compute weights per target head + weights_per_target: dict[str, torch.Tensor] = {} + + target_keys = dataset.ground_truths[0].keys() + + for key in target_keys: + stacked = torch.stack([gt[key] for gt in dataset.ground_truths], dim=0) + counts = stacked.sum(dim=0) + w = counts.sum() / counts + weights_per_target[key] = w / w.sum() + + return weights_per_target + else: + category_counts = dataset.ground_truths.sum(dim=0) else: - category_counts = cast( - PatientFeatureDataset, train_dl.dataset - ).ground_truths.sum(dim=0) + dataset = cast(PatientFeatureDataset, train_dl.dataset) + category_counts = dataset.ground_truths.sum(dim=0) cat_ratio_reciprocal = category_counts.sum() / category_counts category_weights = cat_ratio_reciprocal / cat_ratio_reciprocal.sum() if len(train_categories) <= 1: raise ValueError(f"not enough categories to train on: {train_categories}") - elif any(category_counts < 16): + elif (category_counts < 16).any(): + category_counts_list = ( + category_counts.tolist() + if category_counts.dim() > 0 + else [category_counts.item()] + ) underpopulated_categories = { category: int(count) - for category, count in zip(train_categories, category_counts, strict=True) + for category, count in zip( + train_categories, category_counts_list, strict=True + ) if count < 16 } _logger.warning( diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index ab3ff0d2..b2daa386 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -16,7 +16,6 @@ from tqdm import tqdm import stamp -from stamp.cache import get_processing_code_hash from stamp.preprocessing.config import ExtractorName from stamp.preprocessing.extractor import Extractor from stamp.preprocessing.tiling import ( @@ -32,6 +31,7 @@ SlidePixels, TilePixels, ) +from stamp.utils.cache import get_processing_code_hash __author__ = "Marko van Treeck" __copyright__ = "Copyright (C) 2022-2024 Marko van Treeck" @@ -85,16 +85,9 @@ def __init__( self.canny_cutoff = canny_cutoff self.default_slide_mpp = default_slide_mpp - # Already check if we can extract the MPP here. - # We don't want to kill our dataloader later, - # because that leads to _a lot_ of error messages which are difficult to read - if ( - get_slide_mpp_( - openslide.open_slide(slide_path), default_mpp=default_slide_mpp - ) - is None - ): - raise MPPExtractionError() + # MPP is validated by the caller (extract_()) before constructing this dataset, + # so we no longer open the slide here for a redundant MPP check. + # This removes one openslide.open_slide() call per WSI. def __iter__(self) -> Iterator[tuple[Tensor, Microns, Microns]]: return ( @@ -177,6 +170,11 @@ def extract_( extractor = dino_bloom() + case ExtractorName.RED_DINO: + from stamp.preprocessing.extractor.reddino import red_dino + + extractor = red_dino() + case ExtractorName.VIRCHOW: from stamp.preprocessing.extractor.virchow import virchow @@ -222,6 +220,10 @@ def extract_( extractor = plip() + case ExtractorName.KEEP: + from stamp.preprocessing.extractor.keep import keep + + extractor = keep() case ExtractorName.TICON: from stamp.preprocessing.extractor.ticon import ticon @@ -286,6 +288,15 @@ def extract_( feature_output_path.parent.mkdir(parents=True, exist_ok=True) try: + # Validate MPP here once (avoids a second openslide.open_slide inside _TileDataset.__init__). + if ( + get_slide_mpp_( + openslide.open_slide(slide_path), default_mpp=default_slide_mpp + ) + is None + ): + raise MPPExtractionError() + ds = _TileDataset( slide_path=slide_path, cache_dir=cache_dir, @@ -300,7 +311,15 @@ def extract_( default_slide_mpp=default_slide_mpp, ) # Parallelism is implemented in the dataset iterator already, so one worker is enough! - dl = DataLoader(ds, batch_size=64, num_workers=1, drop_last=False) + # pin_memory speeds up CPU→GPU DMA for tile batches. + # num_workers=1 is intentional: WSI read parallelism is inside _supertiles. + dl = DataLoader( + ds, + batch_size=64, + num_workers=1, + drop_last=False, + pin_memory=torch.cuda.is_available(), + ) feats, xs_um, ys_um = [], [], [] for tiles, xs, ys in tqdm(dl, leave=False): @@ -384,8 +403,9 @@ def _get_rejection_thumb( dtype=bool, ) - for y, x in np.floor(coords_um / tile_size_um).astype(np.uint32): - inclusion_map[y, x] = True + # Vectorized: set all tile positions at once instead of a Python loop. + tile_indices = np.floor(coords_um / tile_size_um).astype(np.uint32) + inclusion_map[tile_indices[:, 0], tile_indices[:, 1]] = True thumb = slide.get_thumbnail(size).convert("RGBA") discarded_im = Image.fromarray( diff --git a/src/stamp/preprocessing/config.py b/src/stamp/preprocessing/config.py index 5eca41dd..e017daf8 100644 --- a/src/stamp/preprocessing/config.py +++ b/src/stamp/preprocessing/config.py @@ -1,7 +1,6 @@ from enum import StrEnum from pathlib import Path -import torch from pydantic import BaseModel, ConfigDict, Field from stamp.types import ImageExtension, Microns, SlideMPP, TilePixels @@ -28,8 +27,10 @@ class ExtractorName(StrEnum): MUSK = "musk" MSTAR = "mstar" PLIP = "plip" + KEEP = "keep" TICON = "ticon" EMPTY = "empty" + RED_DINO = "red-dino" class PreprocessingConfig(BaseModel, arbitrary_types_allowed=True): @@ -46,7 +47,11 @@ class PreprocessingConfig(BaseModel, arbitrary_types_allowed=True): tile_size_px: TilePixels = TilePixels(224) extractor: ExtractorName max_workers: int = 8 - device: str = "cuda" if torch.cuda.is_available() else "cpu" + device: str = Field( + default_factory=lambda: ( + "cuda" if __import__("torch").cuda.is_available() else "cpu" + ) + ) generate_hash: bool = True default_slide_mpp: SlideMPP | None = None diff --git a/src/stamp/preprocessing/extractor/chief_ctranspath.py b/src/stamp/preprocessing/extractor/chief_ctranspath.py index 2d2e6b9b..03f5bba6 100644 --- a/src/stamp/preprocessing/extractor/chief_ctranspath.py +++ b/src/stamp/preprocessing/extractor/chief_ctranspath.py @@ -1,6 +1,6 @@ from pathlib import Path -from stamp.cache import STAMP_CACHE_DIR, file_digest +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest try: import gdown diff --git a/src/stamp/preprocessing/extractor/ctranspath.py b/src/stamp/preprocessing/extractor/ctranspath.py index ba9a277c..d189e279 100644 --- a/src/stamp/preprocessing/extractor/ctranspath.py +++ b/src/stamp/preprocessing/extractor/ctranspath.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Optional, TypeVar, cast -from stamp.cache import STAMP_CACHE_DIR, file_digest +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest try: import gdown @@ -518,7 +518,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1) + self.relative_position_index.view(-1) # pyright: ignore[reportCallIssue] ].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], diff --git a/src/stamp/preprocessing/extractor/dinobloom.py b/src/stamp/preprocessing/extractor/dinobloom.py index fb7713ce..54b79c53 100644 --- a/src/stamp/preprocessing/extractor/dinobloom.py +++ b/src/stamp/preprocessing/extractor/dinobloom.py @@ -8,9 +8,9 @@ from torch import nn from torchvision import transforms -from stamp.cache import STAMP_CACHE_DIR from stamp.preprocessing.config import ExtractorName from stamp.preprocessing.extractor import Extractor +from stamp.utils.cache import STAMP_CACHE_DIR __author__ = "Marko van Treeck" __copyright__ = "Copyright (C) 2022-2025 Marko van Treeck" diff --git a/src/stamp/preprocessing/extractor/keep.py b/src/stamp/preprocessing/extractor/keep.py new file mode 100644 index 00000000..4adc964e --- /dev/null +++ b/src/stamp/preprocessing/extractor/keep.py @@ -0,0 +1,49 @@ +""" +Adopted from https://github.com/MAGIC-AI4Med/KEEP +KEEP (KnowledgE-Enhanced Pathology) +""" + +try: + import torch + from torchvision import transforms + from transformers import AutoModel +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "keep dependencies not installed." + " Please reinstall stamp using `pip install 'stamp[keep]'`" + ) from e + +from stamp.preprocessing.config import ExtractorName +from stamp.preprocessing.extractor import Extractor + + +class KEEPWrapper(torch.nn.Module): + def __init__(self, model) -> None: + super().__init__() + self.model = model + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + return self.model.encode_image(batch) + + +def keep() -> Extractor[KEEPWrapper]: + """Extracts features from slide tiles using the KEEP tile encoder.""" + model = AutoModel.from_pretrained("Astaxanthin/KEEP", trust_remote_code=True) + model.eval() + + transform = transforms.Compose( + [ + transforms.Resize( + size=224, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.CenterCrop(size=(224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] + ) + + return Extractor( + model=KEEPWrapper(model), + transform=transform, + identifier=ExtractorName.KEEP, + ) diff --git a/src/stamp/preprocessing/extractor/reddino.py b/src/stamp/preprocessing/extractor/reddino.py new file mode 100644 index 00000000..ef92c551 --- /dev/null +++ b/src/stamp/preprocessing/extractor/reddino.py @@ -0,0 +1,62 @@ +""" +Port from https://github.com/Snarci/RedDino +RedDino: A Foundation Model for Red Blood Cell Analysis +""" + +from typing import Callable, cast + +try: + import timm + import torch + from PIL import Image + from torchvision import transforms +except ModuleNotFoundError as e: + raise ModuleNotFoundError("red-dino dependencies not installed.") from e + +from stamp.preprocessing.config import ExtractorName +from stamp.preprocessing.extractor import Extractor + +__license__ = "MIT" + + +class RedDinoClsOnly(torch.nn.Module): + def __init__(self, model) -> None: + super().__init__() + self.model = model + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + out = self.model(batch) + if isinstance(out, tuple): + out = out[0] + # if model returns tokens, return class token + if getattr(out, "ndim", 0) >= 2 and out.shape[1] > 1: + return out[:, 0] + return out + + +def red_dino() -> Extractor[RedDinoClsOnly]: + """Extracts features from single image using RedDino encoder.""" + + model = timm.create_model( + "hf-hub:Snarcy/RedDino-large", + pretrained=True, + ) + + transform = cast( + Callable[[Image.Image], torch.Tensor], + transforms.Compose( + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ), + ) + + return Extractor( + model=RedDinoClsOnly(model), + transform=transform, + identifier=ExtractorName.RED_DINO, + ) diff --git a/src/stamp/preprocessing/tiling.py b/src/stamp/preprocessing/tiling.py index 82a3efba..1f560981 100644 --- a/src/stamp/preprocessing/tiling.py +++ b/src/stamp/preprocessing/tiling.py @@ -397,8 +397,10 @@ def _tiles_from_cache_file(cache_file_path: Path) -> Iterator[_Tile]: x_um, y_um = Microns(float(x_um_str)), Microns(float(y_um_str)) with zip_fp.open(name, "r") as tile_fp: + img = Image.open(tile_fp) + img.load() # force eager pixel decode while tile_fp is still open yield _Tile( - image=Image.open(tile_fp), + image=img, coordinates=_XYCoords(x_um, y_um), size=tiler_params["tile_size_um"], ) @@ -461,6 +463,9 @@ def _extract_mpp_from_metadata(slide: openslide.AbstractSlide) -> SlideMPP | Non return None doc = minidom.parseString(xml_path) collection = doc.documentElement + if collection is None: + _logger.error("Document element is None, unable to extract MPP.") + return None images = collection.getElementsByTagName("Image") pixels = images[0].getElementsByTagName("Pixels") mpp = float(pixels[0].getAttribute("PhysicalSizeX")) diff --git a/src/stamp/statistics/__init__.py b/src/stamp/statistics/__init__.py index ec09e1e0..b3243ecc 100644 --- a/src/stamp/statistics/__init__.py +++ b/src/stamp/statistics/__init__.py @@ -1,3 +1,11 @@ +"""Statistics utilities (wrappers) for classification, regression and survival. + +This module provides a small, stable wrapper `compute_stats_` that dispatches +to the task-specific statistic implementations found in the submodules. +""" + +from __future__ import annotations + from collections.abc import Sequence from pathlib import Path from typing import NewType @@ -7,7 +15,10 @@ from matplotlib import pyplot as plt from pydantic import BaseModel, ConfigDict, Field -from stamp.statistics.categorical import categorical_aggregated_ +from stamp.statistics.categorical import ( + categorical_aggregated_, + categorical_aggregated_multitarget_, +) from stamp.statistics.prc import ( plot_multiple_decorated_precision_recall_curves, plot_single_decorated_precision_recall_curve, @@ -17,23 +28,25 @@ plot_multiple_decorated_roc_curves, plot_single_decorated_roc_curve, ) -from stamp.statistics.survival import ( - _plot_km, - _survival_stats_for_csv, -) +from stamp.statistics.survival import _plot_km, _survival_stats_for_csv from stamp.types import PandasLabel, Task +__all__ = ["StatsConfig", "compute_stats_"] + + __author__ = "Marko van Treeck, Minh Duc Nguyen" __copyright__ = "Copyright (C) 2022-2024 Marko van Treeck, Minh Duc Nguyen" __license__ = "MIT" def _read_table(file: Path, **kwargs) -> pd.DataFrame: - """Loads a dataframe from a file.""" + """Load a dataframe from CSV or XLSX file path. + + This small helper centralizes file IO formatting and keeps callers simple. + """ if isinstance(file, Path) and file.suffix == ".xlsx": return pd.read_excel(file, **kwargs) - else: - return pd.read_csv(file, **kwargs) + return pd.read_csv(file, **kwargs) class StatsConfig(BaseModel): @@ -41,7 +54,7 @@ class StatsConfig(BaseModel): task: Task = Field(default="classification") output_dir: Path pred_csvs: list[Path] - ground_truth_label: PandasLabel | None = None + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None true_class: str | None = None time_label: str | None = None status_label: str | None = None @@ -50,47 +63,65 @@ class StatsConfig(BaseModel): _Inches = NewType("_Inches", float) -def compute_stats_( +def _compute_multitarget_classification_stats( *, - task: Task, output_dir: Path, pred_csvs: Sequence[Path], - ground_truth_label: PandasLabel | None = None, - true_class: str | None = None, - time_label: str | None = None, - status_label: str | None = None, + target_labels: Sequence[str], ) -> None: - match task: - case "classification": - if true_class is None or ground_truth_label is None: - raise ValueError( - "both true_class and ground_truth_label are required in statistic configuration" - ) - - preds_dfs = [ - _read_table( - p, - usecols=[ground_truth_label, f"{ground_truth_label}_{true_class}"], - dtype={ - ground_truth_label: str, - f"{ground_truth_label}_{true_class}": float, - }, - ) - for p in pred_csvs - ] - - y_trues = [ - np.array(df[ground_truth_label] == true_class) for df in preds_dfs - ] - y_preds = [ - np.array(df[f"{ground_truth_label}_{true_class}"].values) - for df in preds_dfs - ] - n_bootstrap_samples = 1000 - figure_width = _Inches(3.8) - threshold_cmap = None - - roc_curve_figure_aspect_ratio = 1.08 + """Compute statistics and plots for multi-target classification. + + For each target, creates ROC and PRC curves for each class, + similar to single-target classification. + """ + output_dir.mkdir(parents=True, exist_ok=True) + n_bootstrap_samples = 1000 + figure_width = _Inches(3.8) + roc_curve_figure_aspect_ratio = 1.08 + + # Validate all target labels exist in CSV + first_df = _read_table(pred_csvs[0], nrows=0) + missing_targets = [t for t in target_labels if t not in first_df.columns] + if missing_targets: + raise ValueError( + f"Target labels not found in CSV: {missing_targets}. Available columns: {list(first_df.columns)}" + ) + + # Process each target + for target_label in target_labels: + # Load data for this target + preds_dfs = [] + for p in pred_csvs: + df = _read_table(p, dtype=str) + # Only keep rows where this target has ground truth + df_clean = df.dropna(subset=[target_label]) + if len(df_clean) > 0: + preds_dfs.append(df_clean) + + if not preds_dfs: + continue + + # Get unique classes for this target + classes = sorted(preds_dfs[0][target_label].unique()) + + # Create plots for each class in this target + for true_class in classes: + # Extract ground truth and predictions for this class + y_trues = [] + y_preds = [] + + for df in preds_dfs: + prob_col = f"{target_label}_{true_class}" + if prob_col not in df.columns: + continue + + y_trues.append(np.array(df[target_label] == true_class)) + y_preds.append(np.array(df[prob_col].astype(float).values)) + + if not y_trues: + continue + + # Plot ROC curve fig, ax = plt.subplots( figsize=(figure_width, figure_width * roc_curve_figure_aspect_ratio), dpi=300, @@ -101,63 +132,206 @@ def compute_stats_( ax=ax, y_true=y_trues[0], y_score=y_preds[0], - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, - threshold_cmap=threshold_cmap, + threshold_cmap=None, ) - else: plot_multiple_decorated_roc_curves( ax=ax, y_trues=y_trues, y_scores=y_preds, - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=None, ) fig.tight_layout() - if not output_dir.exists(): - output_dir.mkdir(parents=True, exist_ok=True) - - fig.savefig(output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg") + fig.savefig(output_dir / f"roc-curve_{target_label}={true_class}.svg") plt.close(fig) + # Plot PRC curve fig, ax = plt.subplots( figsize=(figure_width, figure_width * roc_curve_figure_aspect_ratio), dpi=300, ) + if len(preds_dfs) == 1: plot_single_decorated_precision_recall_curve( ax=ax, y_true=y_trues[0], y_score=y_preds[0], - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, ) - else: plot_multiple_decorated_precision_recall_curves( ax=ax, y_trues=y_trues, y_scores=y_preds, - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", ) fig.tight_layout() - fig.savefig(output_dir / f"pr-curve_{ground_truth_label}={true_class}.svg") + fig.savefig(output_dir / f"pr-curve_{target_label}={true_class}.svg") plt.close(fig) - categorical_aggregated_( - preds_csvs=pred_csvs, - ground_truth_label=ground_truth_label, - outpath=output_dir, + # Compute aggregated statistics for all targets + categorical_aggregated_multitarget_( + preds_csvs=pred_csvs, + outpath=output_dir, + target_labels=target_labels, + ) + + +def compute_stats_( + *, + task: Task, + output_dir: Path, + pred_csvs: Sequence[Path], + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None, + true_class: str | None = None, + time_label: str | None = None, + status_label: str | None = None, +) -> None: + """Compute and save statistics for the provided task and prediction CSVs. + + This wrapper keeps the external API stable while delegating the detailed + computations and plotting to the submodules under `stamp.statistics.*`. + """ + match task: + case "classification": + # Check if multi-target based on ground_truth_label type + is_multitarget = ( + isinstance(ground_truth_label, (list, tuple)) + and len(ground_truth_label) > 1 ) + if is_multitarget: + # Multi-target classification + if not isinstance(ground_truth_label, (list, tuple)): + raise ValueError( + "ground_truth_label must be a list or tuple for multi-target classification" + ) + _compute_multitarget_classification_stats( + output_dir=output_dir, + pred_csvs=pred_csvs, + target_labels=list(ground_truth_label), + ) + else: + # Single-target classification (original behavior) + if true_class is None or ground_truth_label is None: + raise ValueError( + "both true_class and ground_truth_label are required in statistic configuration" + ) + if not isinstance(ground_truth_label, str): + raise ValueError( + "ground_truth_label must be a string for single-target classification" + ) + + preds_dfs = [ + _read_table( + p, + usecols=[ + ground_truth_label, + f"{ground_truth_label}_{true_class}", + ], + dtype={ + ground_truth_label: str, + f"{ground_truth_label}_{true_class}": float, + }, + ) + for p in pred_csvs + ] + + y_trues = [ + np.array(df[ground_truth_label] == true_class) for df in preds_dfs + ] + y_preds = [ + np.array(df[f"{ground_truth_label}_{true_class}"].values) + for df in preds_dfs + ] + n_bootstrap_samples = 1000 + figure_width = _Inches(3.8) + threshold_cmap = None + + roc_curve_figure_aspect_ratio = 1.08 + fig, ax = plt.subplots( + figsize=( + figure_width, + figure_width * roc_curve_figure_aspect_ratio, + ), + dpi=300, + ) + + if len(preds_dfs) == 1: + plot_single_decorated_roc_curve( + ax=ax, + y_true=y_trues[0], + y_score=y_preds[0], + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=n_bootstrap_samples, + threshold_cmap=threshold_cmap, + ) + else: + plot_multiple_decorated_roc_curves( + ax=ax, + y_trues=y_trues, + y_scores=y_preds, + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=None, + ) + + fig.tight_layout() + output_dir.mkdir(parents=True, exist_ok=True) + fig.savefig( + output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg" + ) + plt.close(fig) + + fig, ax = plt.subplots( + figsize=( + figure_width, + figure_width * roc_curve_figure_aspect_ratio, + ), + dpi=300, + ) + if len(preds_dfs) == 1: + plot_single_decorated_precision_recall_curve( + ax=ax, + y_true=y_trues[0], + y_score=y_preds[0], + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=n_bootstrap_samples, + ) + else: + plot_multiple_decorated_precision_recall_curves( + ax=ax, + y_trues=y_trues, + y_scores=y_preds, + title=f"{ground_truth_label} = {true_class}", + ) + + fig.tight_layout() + fig.savefig( + output_dir / f"pr-curve_{ground_truth_label}={true_class}.svg" + ) + plt.close(fig) + + categorical_aggregated_( + preds_csvs=pred_csvs, + ground_truth_label=ground_truth_label, + outpath=output_dir, + ) + case "regression": if ground_truth_label is None: raise ValueError( "no ground_truth_label configuration supplied in statistic" ) + if not isinstance(ground_truth_label, str): + raise ValueError( + "ground_truth_label must be a string for regression (multi-target regression not yet supported)" + ) regression_aggregated_( preds_csvs=pred_csvs, ground_truth_label=ground_truth_label, @@ -203,12 +377,7 @@ def compute_stats_( cut_off=cut_off, ) - # ------------------------------------------------------------------ # # Save individual and aggregated CSVs - # ------------------------------------------------------------------ # stats_df = pd.DataFrame(per_fold).transpose() stats_df.index.name = "fold_name" # label the index column stats_df.to_csv(output_dir / "survival-stats_individual.csv", index=True) - - # agg_df = _aggregate_with_ci(stats_df) - # agg_df.to_csv(output_dir / "survival-stats_aggregated.csv", index=True) diff --git a/src/stamp/statistics/categorical.py b/src/stamp/statistics/categorical.py index 2b5c859e..e19b1659 100755 --- a/src/stamp/statistics/categorical.py +++ b/src/stamp/statistics/categorical.py @@ -21,6 +21,30 @@ ] +def _detect_targets_from_columns(columns: Sequence[str]) -> list[str]: + """Detect target columns from CSV column names. + + Assumes multi-target format where each target has: + - A ground truth column (target name) + - A prediction column (pred_{target}) + - Probability columns ({target}_{class1}, {target}_{class2}, ...) + + Returns: + List of target names detected. + """ + # Convert to list to handle pandas Index + columns = list(columns) + targets = [] + for col in columns: + # Look for columns that start with "pred_" + if col.startswith("pred_"): + target_name = col[5:] # Remove "pred_" prefix + # Verify the target column exists + if target_name in columns: + targets.append(target_name) + return sorted(targets) + + def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: """Calculates some stats for categorical prediction tables. @@ -29,7 +53,9 @@ def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: """ categories = preds_df[target_label].unique() y_true = preds_df[target_label] - y_pred = preds_df[[f"{target_label}_{cat}" for cat in categories]].map(float).values + y_pred = ( + preds_df[[f"{target_label}_{cat}" for cat in categories]].astype(float).values + ) stats_df = pd.DataFrame(index=categories) @@ -38,29 +64,29 @@ def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: # roc_auc stats_df["roc_auc_score"] = [ - metrics.roc_auc_score(y_true == cat, y_pred[:, i]) # pyright: ignore[reportCallIssue,reportArgumentType] + metrics.roc_auc_score(y_true == cat, y_pred[:, i]) for i, cat in enumerate(categories) ] # average_precision stats_df["average_precision_score"] = [ - metrics.average_precision_score(y_true == cat, y_pred[:, i]) # pyright: ignore[reportCallIssue,reportArgumentType] + metrics.average_precision_score(y_true == cat, y_pred[:, i]) for i, cat in enumerate(categories) ] # f1 score y_pred_labels = categories[y_pred.argmax(axis=1)] stats_df["f1_score"] = [ - metrics.f1_score(y_true == cat, y_pred_labels == cat) # pyright: ignore[reportCallIssue,reportArgumentType] - for cat in categories + metrics.f1_score(y_true == cat, y_pred_labels == cat) for cat in categories ] # p values p_values = [] for i, cat in enumerate(categories): - pos_scores = y_pred[:, i][y_true == cat] # pyright: ignore[reportCallIssue,reportArgumentType] - neg_scores = y_pred[:, i][y_true != cat] # pyright: ignore[reportCallIssue,reportArgumentType] - p_values.append(st.ttest_ind(pos_scores, neg_scores).pvalue) # pyright: ignore[reportGeneralTypeIssues, reportAttributeAccessIssue] + pos_scores = y_pred[:, i][y_true == cat] + neg_scores = y_pred[:, i][y_true != cat] + _, p_value = st.ttest_ind(pos_scores, neg_scores) + p_values.append(p_value) stats_df["p_value"] = p_values assert set(_score_labels) & set(stats_df.columns) == set(_score_labels) @@ -110,3 +136,65 @@ def categorical_aggregated_( preds_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_individual.csv") stats_df = _aggregate_categorical_stats(preds_df.reset_index()) stats_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_aggregated.csv") + + +def categorical_aggregated_multitarget_( + *, + preds_csvs: Sequence[Path], + outpath: Path, + target_labels: Sequence[str], +) -> None: + """Calculate statistics for multi-target categorical deployments. + + Args: + preds_csvs: CSV files containing predictions. + outpath: Path to save the results to. + target_labels: List of target labels to compute statistics for. + + This will apply `_categorical` to each target in the multi-target setup, + calculate statistics per target, and save both individual and aggregated results. + """ + outpath.mkdir(parents=True, exist_ok=True) + + all_target_stats = {} + + # Read each CSV once and cache so we don't re-read N×M times. + csv_cache: dict[str, pd.DataFrame] = { + Path(p).parent.name: pd.read_csv(p, dtype=str) for p in preds_csvs + } + + for target_label in target_labels: + # Process each target separately + preds_dfs = {} + for fold_name, df in csv_cache.items(): + # Drop rows where this target's ground truth is missing + df_clean = df.dropna(subset=[target_label]) + if len(df_clean) > 0: + preds_dfs[fold_name] = _categorical(df_clean, target_label) + + if not preds_dfs: + continue + + # Concatenate and save individual stats for this target + preds_df = pd.concat(preds_dfs).sort_index() + preds_df.to_csv(outpath / f"{target_label}_categorical-stats_individual.csv") + + # Aggregate stats for this target + stats_df = _aggregate_categorical_stats(preds_df.reset_index()) + stats_df.to_csv(outpath / f"{target_label}_categorical-stats_aggregated.csv") + + # Store for summary + all_target_stats[target_label] = stats_df + + # Create a combined summary across all targets + if all_target_stats: + summary_dfs = [] + for target_name, stats_df in all_target_stats.items(): + stats_copy = stats_df.copy() + stats_copy.index = pd.MultiIndex.from_product( + [[target_name], stats_copy.index], names=["target", "class"] + ) + summary_dfs.append(stats_copy) + + combined_summary = pd.concat(summary_dfs) + combined_summary.to_csv(outpath / "multitarget_categorical-stats_summary.csv") diff --git a/src/stamp/statistics/prc.py b/src/stamp/statistics/prc.py index 867885e9..dd58ea2e 100755 --- a/src/stamp/statistics/prc.py +++ b/src/stamp/statistics/prc.py @@ -6,7 +6,6 @@ import scipy.stats as st from jaxtyping import Bool, Float from matplotlib.axes import Axes -from scipy.interpolate import interp1d from sklearn.metrics import ( auc, average_precision_score, @@ -56,15 +55,9 @@ def _plot_bootstrapped_pr_curve( continue precision, recall, _ = precision_recall_curve(sample_y_true, sample_y_pred) - # Create an interpolation function with decreasing values - interp_func = interp1d( - recall[::-1], - precision[::-1], - kind="linear", - fill_value=np.nan, - bounds_error=False, - ) - interp_prc = interp_func(interp_recall) + # np.interp requires increasing x; precision_recall_curve returns + # decreasing recall, so reverse both arrays. + interp_prc = np.interp(interp_recall, recall[::-1], precision[::-1]) interp_prcs[i] = interp_prc bootstrapped_auprc = auc(interp_recall, interp_prc) bootstrap_auprcs.append(bootstrapped_auprc) diff --git a/src/stamp/statistics/roc.py b/src/stamp/statistics/roc.py index d42413a4..338cf876 100755 --- a/src/stamp/statistics/roc.py +++ b/src/stamp/statistics/roc.py @@ -180,9 +180,11 @@ def _plot_bootstrapped_roc_curve( # and then sample the bottom 0.025 / top 0.975 quantile point # for each sampled fpr-position rng = np.random.default_rng() - interp_rocs = [] interp_fpr = np.linspace(0, 1, num=1000) + # Pre-allocate; rows that correspond to skipped samples stay NaN. + interp_rocs = np.full((n_bootstrap_samples, len(interp_fpr)), np.nan) bootstrap_aucs: list[float] = [] + valid_row = 0 for _ in trange(n_bootstrap_samples, desc="Bootstrapping ROC curves", leave=False): sample_idxs = rng.choice(len(y_true), len(y_true)) sample_y_true = y_true[sample_idxs] @@ -190,15 +192,17 @@ def _plot_bootstrapped_roc_curve( if len(np.unique(sample_y_true)) != 2: continue fpr, tpr, thresh = roc_curve(sample_y_true, sample_y_score) - interp_rocs.append(np.interp(interp_fpr, fpr, tpr)) + interp_rocs[valid_row] = np.interp(interp_fpr, fpr, tpr) + valid_row += 1 bootstrap_aucs.append(float(roc_auc_score(sample_y_true, sample_y_score))) + interp_rocs = interp_rocs[:valid_row] # trim unused rows roc_lower, roc_upper = cast( tuple[ Float[np.ndarray, "fpr"], # noqa: F821 Float[np.ndarray, "fpr"], # noqa: F821 ], - np.quantile(interp_rocs, [0.025, 0.975], axis=0), + np.nanquantile(interp_rocs, [0.025, 0.975], axis=0), ) ax.fill_between(interp_fpr, roc_lower, roc_upper, alpha=0.5) diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index 78fb51cd..8fbf5d63 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -4,11 +4,11 @@ from pathlib import Path +import lifelines.plotting as lifelines_plotting import matplotlib.pyplot as plt import numpy as np import pandas as pd from lifelines import KaplanMeierFitter -from lifelines.plotting import add_at_risk_counts from lifelines.statistics import logrank_test from lifelines.utils import concordance_index @@ -101,7 +101,7 @@ def _plot_km( if risk_label is None: risk_label = "pred_score" - # --- Clean NaNs and invalid entries --- + # Clean NaNs and invalid entries df = df.replace(["NaN", "nan", "None", "Inf", "inf"], np.nan) df = df.dropna(subset=[time_label, status_label, risk_label]).copy() df = df[df[status_label].isin([0, 1])] @@ -143,20 +143,18 @@ def _plot_km( if len(high_df) > 0: fitters.append(kmf_high) + # add at-risk table for fitted curves if len(fitters) > 0: - add_at_risk_counts(*fitters, ax=ax) - - # log-rank only if both groups exist - if len(low_df) > 0 and len(high_df) > 0: - res = logrank_test( - low_df[time_label], - high_df[time_label], - event_observed_A=low_df[status_label], - event_observed_B=high_df[status_label], - ) - logrank_p = float(res.p_value) - else: - logrank_p = float("nan") + lifelines_plotting.add_at_risk_counts(*fitters, ax=ax) + + # log-rank and c-index + res = logrank_test( + low_df[time_label], + high_df[time_label], + event_observed_A=low_df[status_label], + event_observed_B=high_df[status_label], + ) + logrank_p = float(res.p_value) c_used, used, *_ = _cindex(time, event, risk) ax.text( diff --git a/src/stamp/types.py b/src/stamp/types.py index f1f571cc..fddb9d4c 100644 --- a/src/stamp/types.py +++ b/src/stamp/types.py @@ -37,6 +37,9 @@ PatientId: TypeAlias = str GroundTruth: TypeAlias = str +# Survival ground-truth is represented as (time, event) +SurvivalGroundTruth: TypeAlias = tuple[float | None, int | None] +MultiClassGroundTruth: TypeAlias = tuple[str, ...] FeaturePath = NewType("FeaturePath", Path) Category: TypeAlias = str diff --git a/src/stamp/utils/cache.py b/src/stamp/utils/cache.py new file mode 100644 index 00000000..c0b00a51 --- /dev/null +++ b/src/stamp/utils/cache.py @@ -0,0 +1,55 @@ +import hashlib +import os +import shutil +import urllib.request +from functools import cache +from pathlib import Path +from typing import Final + +STAMP_CACHE_DIR: Final[Path] = ( + Path(os.environ.get("XDG_CACHE_HOME") or (Path.home() / ".cache")) / "stamp" +) +# Directory is created on demand (inside functions that write to it) +# so that a bare import of this module does not cause filesystem I/O. + + +def download_file(*, url: str, file_name: str, sha256sum: str) -> Path: + """Downloads a file, or loads it from cache if it has been downloaded before. + + The checksum is only verified on the initial download. Once the file + exists in the cache it is trusted as-is to avoid re-reading large weight + files (which can be ~1 GB) on every run. + """ + STAMP_CACHE_DIR.mkdir(exist_ok=True, parents=True) + outfile_path = STAMP_CACHE_DIR / file_name + if outfile_path.is_file(): + # File already cached and verified on first download — skip re-hash. + return outfile_path + + filename, _ = urllib.request.urlretrieve(url) + with open(filename, "rb") as weight_file: + digest = hashlib.file_digest(weight_file, "sha256") + assert digest.hexdigest() == sha256sum, "hash of downloaded file did not match" + shutil.move(filename, outfile_path) + return outfile_path + + +def file_digest(file: str | Path) -> str: + with open(file, "rb") as fp: + return hashlib.file_digest(fp, "sha256").hexdigest() + + +@cache +def get_processing_code_hash(file_path: Path) -> str: + """The hash of the entire process codebase. + + It is used to assure that features extracted with different versions of + this code base can be identified as such after the fact. + """ + hasher = hashlib.sha256() + for py_file in sorted(file_path.parent.glob("*.py")): + # Use file_digest to stream the file in chunks instead of reading + # the entire source into memory at once. + with open(py_file, "rb") as fp: + hasher.update(hashlib.file_digest(fp, "sha256").digest()) + return hasher.hexdigest() diff --git a/src/stamp/config.py b/src/stamp/utils/config.py similarity index 100% rename from src/stamp/config.py rename to src/stamp/utils/config.py diff --git a/src/stamp/seed.py b/src/stamp/utils/seed.py similarity index 100% rename from src/stamp/seed.py rename to src/stamp/utils/seed.py diff --git a/tests/conftest.py b/tests/conftest.py index f7f4f005..16dcae5b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,15 @@ +import multiprocessing from stamp.preprocessing import ExtractorName +# Ensure tests use a safe multiprocessing start method to avoid +# fork-from-multi-threaded-process warnings on Linux. +if multiprocessing.get_start_method(allow_none=True) != "spawn": + try: + multiprocessing.set_start_method("spawn") + except RuntimeError: + # start method already set by the test runner/environment + pass + # This lets you choose which extractors to run on pytest. Useful for the # CI pipeline as extractors like MUSK take an eternity on CPU. diff --git a/tests/random_data.py b/tests/random_data.py index bd95d1bc..b79c9f42 100644 --- a/tests/random_data.py +++ b/tests/random_data.py @@ -74,13 +74,13 @@ def create_random_dataset( clini_df = pd.DataFrame( patient_to_ground_truth.items(), - columns=["patient", "ground-truth"], # pyright: ignore[reportArgumentType] + columns=["patient", "ground-truth"], ) clini_df.to_csv(clini_path, index=False) slide_df = pd.DataFrame( slide_path_to_patient.items(), - columns=["slide_path", "patient"], # pyright: ignore[reportArgumentType] + columns=["slide_path", "patient"], ) slide_df.to_csv(slide_path, index=False) @@ -130,7 +130,7 @@ def create_random_regression_dataset( # --- Write clini + slide tables --- clini_df = pd.DataFrame(patient_to_target, columns=["patient", "target"]) - clini_df["target"] = clini_df["target"].astype(float) # ✅ ensure numeric dtype + clini_df["target"] = clini_df["target"].astype(float) # ensure numeric dtype clini_df.to_csv(clini_path, index=False) slide_df = pd.DataFrame( @@ -254,6 +254,92 @@ def create_random_patient_level_dataset( return clini_path, slide_path, feat_dir, categories +def create_random_multi_target_dataset( + *, + dir: Path, + n_patients: int, + max_slides_per_patient: int, + min_tiles_per_slide: int, + max_tiles_per_slide: int, + feat_dim: int, + target_labels: Sequence[str], + categories_per_target: Sequence[Sequence[str]], + extractor_name: str = "random-test-generator", + min_slides_per_patient: int = 1, +) -> tuple[Path, Path, Path, Sequence[Sequence[str]]]: + """ + Create a random multi-target tile-level dataset. + + Args: + dir: Directory to create dataset in + n_patients: Number of patients + max_slides_per_patient: Maximum slides per patient + min_tiles_per_slide: Minimum tiles per slide + max_tiles_per_slide: Maximum tiles per slide + feat_dim: Feature dimension + target_labels: Names of the target columns (e.g., ["subtype", "grade"]) + categories_per_target: Categories for each target (e.g., [["A", "B"], ["1", "2", "3"]]) + extractor_name: Name of the extractor + min_slides_per_patient: Minimum slides per patient + + Returns: + Tuple of (clini_path, slide_path, feat_dir, categories_per_target) + """ + if len(target_labels) != len(categories_per_target): + raise ValueError( + "target_labels and categories_per_target must have same length" + ) + + slide_path_to_patient: Mapping[Path, PatientId] = {} + patient_to_ground_truths: Mapping[PatientId, dict[str, str]] = {} + clini_path = dir / "clini.csv" + slide_path = dir / "slide.csv" + + feat_dir = dir / "feats" + feat_dir.mkdir() + + for _ in range(n_patients): + # Random patient ID + patient_id = random_string(16) + + # Generate ground truths for each target + ground_truths = {} + for target_label, categories in zip(target_labels, categories_per_target): + ground_truths[target_label] = random.choice(categories) + + patient_to_ground_truths[patient_id] = ground_truths + + # Generate some slides + for _ in range(random.randint(min_slides_per_patient, max_slides_per_patient)): + slide_path_to_patient[ + create_random_feature_file( + tmp_path=feat_dir, + min_tiles=min_tiles_per_slide, + max_tiles=max_tiles_per_slide, + feat_dim=feat_dim, + extractor_name=extractor_name, + ).relative_to(feat_dir) + ] = patient_id + + # Create clinical table with multiple target columns + clini_data = [] + for patient_id, ground_truths in patient_to_ground_truths.items(): + row = {"patient": patient_id} + row.update(ground_truths) + clini_data.append(row) + + clini_df = pd.DataFrame(clini_data) + clini_df.to_csv(clini_path, index=False) + + slide_df = pd.DataFrame( + slide_path_to_patient.items(), + columns=["slide_path", "patient"], + ) + slide_df.to_csv(slide_path, index=False) + + return clini_path, slide_path, feat_dir, categories_per_target + + def create_random_feature_file( *, tmp_path: Path, diff --git a/tests/test_cache_tiles.py b/tests/test_cache_tiles.py index a665e92c..d9f1411d 100644 --- a/tests/test_cache_tiles.py +++ b/tests/test_cache_tiles.py @@ -6,10 +6,10 @@ import numpy as np import pytest -from stamp.cache import download_file from stamp.preprocessing import Microns, TilePixels from stamp.preprocessing.tiling import _Tile, tiles_with_cache from stamp.types import ImageExtension, SlidePixels +from stamp.utils.cache import download_file def _get_tiles_and_images( diff --git a/tests/test_config.py b/tests/test_config.py index 15b5dd80..fbb53a6c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,6 @@ # %% from pathlib import Path -from stamp.config import StampConfig from stamp.heatmaps.config import HeatmapConfig from stamp.modeling.config import ( AdvancedConfig, @@ -21,6 +20,7 @@ TilePixels, ) from stamp.statistics import StatsConfig +from stamp.utils.config import StampConfig def test_config_parsing() -> None: diff --git a/tests/test_crossval.py b/tests/test_crossval.py index 184a5c23..5e53f50c 100644 --- a/tests/test_crossval.py +++ b/tests/test_crossval.py @@ -13,7 +13,7 @@ VitModelParams, ) from stamp.modeling.crossval import categorical_crossval_ -from stamp.seed import Seed +from stamp.utils.seed import Seed @pytest.mark.slow diff --git a/tests/test_data.py b/tests/test_data.py index 564d6426..a60d830e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -3,6 +3,7 @@ from pathlib import Path import h5py +import pandas as pd import pytest import torch from random_data import ( @@ -21,9 +22,9 @@ PatientFeatureDataset, filter_complete_patient_data_, get_coords, + patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, ) -from stamp.seed import Seed from stamp.types import ( BagSize, FeaturePath, @@ -33,6 +34,7 @@ SlideMPP, TilePixels, ) +from stamp.utils.seed import Seed @pytest.mark.filterwarnings("ignore:some patients have no associated slides") @@ -85,6 +87,31 @@ def test_get_cohort_df(tmp_path: Path) -> None: } +def test_patient_to_ground_truth_multi_target(tmp_path: Path) -> None: + """Verify multi-target clini parsing returns dicts and drops rows missing all targets.""" + df = pd.DataFrame( + { + "patient": ["p1", "p2", "p3", "p4"], + "subtype": ["A", None, "B", None], + "grade": ["1", "2", None, None], + } + ) + df.to_csv(tmp_path / "clini.csv", index=False) + + result = patient_to_ground_truth_from_clini_table_( + clini_table_path=tmp_path / "clini.csv", + patient_label="patient", + ground_truth_label=["subtype", "grade"], + ) + + # p4 has both targets missing → dropped + assert "p4" not in result + + assert result["p1"] == {"subtype": "A", "grade": "1"} + assert result["p2"] == {"subtype": None, "grade": "2"} + assert result["p3"] == {"subtype": "B", "grade": None} + + @pytest.mark.parametrize( "feature_file_creator", [make_feature_file, make_old_feature_file], diff --git a/tests/test_deployment.py b/tests/test_deployment.py index de20ea12..55939e66 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import cast import numpy as np import pytest @@ -24,8 +25,8 @@ ) from stamp.modeling.models.mlp import MLP from stamp.modeling.models.vision_tranformer import VisionTransformer -from stamp.seed import Seed from stamp.types import GroundTruth, PatientId, Task +from stamp.utils.seed import Seed def test_predict_patient_level( @@ -83,7 +84,8 @@ def test_predict_patient_level( assert len(predictions) == len(patient_to_data) for pid in patient_ids: - assert predictions[pid].shape == torch.Size([3]), "expected one score per class" + pred = cast(torch.Tensor, predictions[pid]) + assert pred.shape == torch.Size([3]), "expected one score per class" # Check if scores are consistent between runs and different for different patients more_patient_ids = [PatientId(f"pat{i}") for i in range(8, 11)] @@ -124,11 +126,13 @@ def test_predict_patient_level( assert len(more_predictions) == len(all_patient_ids) # Different patients should give different results assert not torch.allclose( - more_predictions[more_patient_ids[0]], more_predictions[more_patient_ids[1]] + cast(torch.Tensor, more_predictions[more_patient_ids[0]]), + cast(torch.Tensor, more_predictions[more_patient_ids[1]]), ), "different inputs should give different results" # The same patient should yield the same result assert torch.allclose( - predictions[patient_ids[0]], more_predictions[patient_ids[0]] + cast(torch.Tensor, predictions[patient_ids[0]]), + cast(torch.Tensor, more_predictions[patient_ids[0]]), ), "the same inputs should repeatedly yield the same results" @@ -163,7 +167,7 @@ def test_to_prediction_df(task: str) -> None: ) if task == "classification": preds_df = _to_prediction_df( - categories=list(model.categories), # type: ignore + categories=list(cast(list, model.categories)), patient_to_ground_truth={ PatientId("pat5"): GroundTruth("foo"), PatientId("pat6"): None, @@ -192,13 +196,13 @@ def test_to_prediction_df(task: str) -> None: # Check if no loss / target is given for targets with missing ground truths no_ground_truth = preds_df[preds_df["patient"].isin(["pat6"])] - assert no_ground_truth["target"].isna().all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] - assert no_ground_truth["loss"].isna().all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] + assert no_ground_truth["target"].isna().all() + assert no_ground_truth["loss"].isna().all() # Check if loss / target is given for targets with ground truths with_ground_truth = preds_df[preds_df["patient"].isin(["pat5", "pat7"])] - assert (~with_ground_truth["target"].isna()).all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] - assert (~with_ground_truth["loss"].isna()).all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] + assert (~with_ground_truth["target"].isna()).all() + assert (~with_ground_truth["loss"].isna()).all() elif task == "regression": patient_to_ground_truth = {} @@ -217,8 +221,8 @@ def test_to_prediction_df(task: str) -> None: assert preds_df["loss"].isna().all() else: patient_to_ground_truth = { - PatientId("p1"): "10.0 1", - PatientId("p2"): "12.3 0", + PatientId("p1"): (10.0, 1), + PatientId("p2"): (12.3, 0), } predictions = { PatientId("p1"): torch.tensor([0.8]), @@ -295,22 +299,24 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: div_factor=25.0, ) - # ---- Build tile-level feature file so batch = (bags, coords, bag_sizes, gt) + # Build tile-level feature file so batch = (bags, coords, bag_sizes, gt) if task == "classification": feature_file = make_old_feature_file( feats=torch.rand(23, dim_feats), coords=torch.rand(23, 2) ) - gt = GroundTruth("foo") + gt = cast(GroundTruth, "foo") elif task == "regression": feature_file = make_old_feature_file( feats=torch.rand(30, dim_feats), coords=torch.rand(30, 2) ) - gt = GroundTruth(42.5) # numeric target wrapped for typing + gt = cast(GroundTruth, 42.5) # numeric target wrapped for typing else: # survival feature_file = make_old_feature_file( feats=torch.rand(40, dim_feats), coords=torch.rand(40, 2) ) - gt = GroundTruth("12 0") # (time, status) + gt = cast( + GroundTruth, (12.0, 0) + ) # (time, status) - use raw tuple (GroundTruth is a str alias) patient_to_data = { PatientId("pat_test"): PatientData( @@ -319,7 +325,7 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: ) } - # ---- Use tile_bag_dataloader for ALL tasks (so batch has 4 elements) + # Use tile_bag_dataloader for ALL tasks (so batch has 4 elements) test_dl, _ = tile_bag_dataloader( task=task, # "classification" | "regression" | "survival" patient_data=list(patient_to_data.values()), @@ -341,12 +347,17 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: assert len(predictions) == 1 pred = list(predictions.values())[0] if task == "classification": - assert pred.shape == torch.Size([len(categories)]) + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.shape == torch.Size([len(categories)]) elif task == "regression": - assert pred.shape == torch.Size([1]) + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.shape == torch.Size([1]) else: # survival # Cox model → scalar log-risk, KM → vector or matrix - assert pred.ndim in (0, 1, 2), f"unexpected survival output shape: {pred.shape}" + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.ndim in (0, 1, 2), ( + f"unexpected survival output shape: {pred_tensor.shape}" + ) # Repeatability predictions2 = _predict( @@ -356,4 +367,6 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: accelerator="cpu", ) for pid in predictions: - assert torch.allclose(predictions[pid], predictions2[pid]) + assert torch.allclose( + cast(torch.Tensor, predictions[pid]), cast(torch.Tensor, predictions2[pid]) + ) diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 3edef575..28d7c2c1 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -10,13 +10,13 @@ from huggingface_hub.errors import GatedRepoError from random_data import create_random_dataset, create_random_feature_file, random_string -from stamp.cache import download_file from stamp.encoding import ( EncoderName, init_patient_encoder_, init_slide_encoder_, ) from stamp.preprocessing.config import ExtractorName +from stamp.utils.cache import download_file # Contains an accepted input patch-level feature encoder # TODO: Make a class for each extractor instead of a function. This class diff --git a/tests/test_feature_extractors.py b/tests/test_feature_extractors.py index 699c10f6..6323ee8d 100644 --- a/tests/test_feature_extractors.py +++ b/tests/test_feature_extractors.py @@ -7,8 +7,8 @@ import torch from huggingface_hub.errors import GatedRepoError -from stamp.cache import download_file from stamp.preprocessing import ExtractorName, Microns, TilePixels, extract_ +from stamp.utils.cache import download_file @pytest.mark.slow diff --git a/tests/test_mics.py b/tests/test_mics.py new file mode 100644 index 00000000..faacbe0a --- /dev/null +++ b/tests/test_mics.py @@ -0,0 +1,161 @@ +from pathlib import Path + +import h5py +import numpy as np +import torch + +from typing import cast, Sequence + +from stamp.modeling.config import TrainConfig +from stamp.modeling.data import ( + PatientData, + _log_patient_slide_feature_inconsistencies, + create_dataloader, +) +from stamp.modeling.models.__init__ import LitSlideClassifier +from stamp.modeling.models.mlp import MLP +from stamp.modeling.registry import ModelName, load_model_class +from stamp.types import FeaturePath +from stamp.utils.config import StampConfig + + +def _make_h5(path: Path, vec: np.ndarray) -> None: + with h5py.File(path, "w") as h5: + h5.create_dataset("feats", data=vec) + + +def test_create_dataloader_infers_categories_patient(tmp_path: Path) -> None: + # create two patient-level feature files + p1 = tmp_path / "p1.h5" + p2 = tmp_path / "p2.h5" + _make_h5(p1, np.zeros((1, 4), dtype=np.float32)) + _make_h5(p2, np.zeros((1, 4), dtype=np.float32)) + + pd1 = PatientData(ground_truth="A", feature_files=[FeaturePath(p1)]) + pd2 = PatientData(ground_truth="B", feature_files=[FeaturePath(p2)]) + + dl, cats = create_dataloader( + feature_type="patient", + task="classification", + patient_data=[pd1, pd2], + bag_size=None, + batch_size=1, + shuffle=False, + num_workers=0, + transform=None, + categories=None, + ) + + assert set(cats) == {"A", "B"} + + +def test_create_dataloader_survival_tuple_and_legacy_string(tmp_path: Path) -> None: + p1 = tmp_path / "s1.h5" + p2 = tmp_path / "s2.h5" + _make_h5(p1, np.zeros((1, 3), dtype=np.float32)) + _make_h5(p2, np.zeros((1, 3), dtype=np.float32)) + + # tuple (time, event) and legacy string (will be parsed defensively) + pd1 = PatientData(ground_truth=(5.0, 1), feature_files=[FeaturePath(p1)]) + pd2 = PatientData(ground_truth="7 0", feature_files=[FeaturePath(p2)]) + + dl, cats = create_dataloader( + feature_type="patient", + task="survival", + patient_data=cast(Sequence[PatientData], [pd1, pd2]), + bag_size=None, + batch_size=2, + shuffle=False, + num_workers=0, + transform=None, + categories=None, + ) + + batch_feats, labels = next(iter(dl)) + # labels: shape (2,2) -> (time, event) + assert labels.shape[0] == 2 + # first sample is (5.0, 1) + assert torch.isclose(labels[0, 0], torch.tensor(5.0)) + assert labels[0, 1] == 1.0 + # legacy string parsing yields an event token parsed to 0 (censored) + assert labels[1, 1] == 0.0 + + +def test_predict_dtype_casting_no_error() -> None: + # Build a minimal LitSlideClassifier with MLP backbone + categories = ["A", "B"] + category_weights = torch.tensor([1.0, 1.0], dtype=torch.float32) + + model = LitSlideClassifier( + model_class=MLP, + ground_truth_label="y", + categories=categories, + category_weights=category_weights, + dim_input=4, + # provide backbone hyperparams + dim_hidden=8, + num_layers=2, + dropout=0.1, + # Base required args + total_steps=10, + max_lr=1e-3, + div_factor=25.0, + train_patients=[], + valid_patients=[], + ) + + # force model params to half precision to simulate mixed-precision checkpoints + model.model.half() + + feats = torch.rand((1, 4), dtype=torch.float32) + labels = torch.tensor([[0.0, 1.0]], dtype=torch.float32) + + # predict_step should cast feats to model dtype (half) and not raise + out = model.predict_step((feats, labels), 0) + assert isinstance(out, torch.Tensor) + assert out.dtype == next(model.model.parameters()).dtype + + +def test_log_missing_feature_filenames(caplog, tmp_path: Path) -> None: + missing = tmp_path / "missing.h5" + # slide_to_patient maps FeaturePath -> patient id + slide_to_patient = {FeaturePath(missing): "P1"} + patient_to_ground_truth = {"P1": "A"} + + with caplog.at_level("WARNING"): + _log_patient_slide_feature_inconsistencies( + patient_to_ground_truth=patient_to_ground_truth, + slide_to_patient=slide_to_patient, + ) + + assert "some feature files could not be found" in caplog.text + # ensure only filename is logged (not full path) + assert "missing.h5" in caplog.text + + +def test_model_registry_returns_classes() -> None: + LitClass, ModelClass = load_model_class("classification", "patient", ModelName.MLP) + assert callable(LitClass) + assert callable(ModelClass) + + +def test_stampconfig_training_task_default() -> None: + cfg = StampConfig.model_validate( + { + "training": { + "output_dir": "out", + "clini_table": "cl.csv", + "feature_dir": "feats", + "ground_truth_label": "gt", + } + } + ) + assert cfg.training is not None + assert ( + cfg.training.task + == TrainConfig( + output_dir=Path("out"), + clini_table=Path("cl.csv"), + feature_dir=Path("feats"), + ).task + ) diff --git a/tests/test_model.py b/tests/test_model.py index 1aa6d80a..f74e0a99 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,6 @@ import torch +from stamp.modeling.models.barspoon import EncDecTransformer from stamp.modeling.models.mlp import MLP from stamp.modeling.models.trans_mil import TransMIL from stamp.modeling.models.vision_tranformer import VisionTransformer @@ -162,3 +163,114 @@ def test_trans_mil_inference_reproducibility( ) assert logits1.allclose(logits2) + + +def test_enc_dec_transformer_dims( + batch_size: int = 6, + n_tiles: int = 75, + input_dim: int = 456, + d_model: int = 128, +) -> None: + target_n_outs = {"subtype": 3, "grade": 4} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=256, + positional_encoding=True, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + assert set(logits.keys()) == set(target_n_outs.keys()) + for target_label, n_out in target_n_outs.items(): + assert logits[target_label].shape == (batch_size, n_out) + + +def test_enc_dec_transformer_single_target( + batch_size: int = 4, + n_tiles: int = 50, + input_dim: int = 256, + d_model: int = 64, +) -> None: + target_n_outs = {"label": 5} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=1, + num_decoder_layers=1, + dim_feedforward=128, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + assert list(logits.keys()) == ["label"] + assert logits["label"].shape == (batch_size, 5) + + +def test_enc_dec_transformer_no_positional_encoding( + batch_size: int = 4, + n_tiles: int = 30, + input_dim: int = 128, + d_model: int = 64, +) -> None: + target_n_outs = {"a": 2, "b": 3} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=1, + num_decoder_layers=1, + dim_feedforward=128, + positional_encoding=False, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + for target_label, n_out in target_n_outs.items(): + assert logits[target_label].shape == (batch_size, n_out) + + +def test_enc_dec_transformer_inference_reproducibility( + batch_size: int = 5, + n_tiles: int = 40, + input_dim: int = 200, + d_model: int = 64, +) -> None: + target_n_outs = {"subtype": 3, "grade": 4} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=128, + ) + model = model.eval() + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + + with torch.inference_mode(): + logits1 = model.forward(bags, coords) + logits2 = model.forward(bags, coords) + + for target_label in target_n_outs: + assert logits1[target_label].allclose(logits2[target_label]) diff --git a/tests/test_statistics.py b/tests/test_statistics.py index 790b98ab..b600d7e7 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -2,6 +2,7 @@ from pathlib import Path import numpy as np +import pandas as pd import torch from random_data import random_patient_preds, random_string @@ -47,3 +48,153 @@ def test_statistics_integration( def test_statistics_integration_for_multiple_patient_preds(tmp_path: Path) -> None: return test_statistics_integration(tmp_path=tmp_path, n_patient_preds=5) + + +def test_statistics_survival_integration( + *, + tmp_path: Path, + n_folds: int = 1, + n_patients: int = 200, +) -> None: + """Check that survival statistics run without crashing.""" + random.seed(0) + np.random.seed(0) + + for fold_i in range(n_folds): + times = np.random.uniform(30, 2000, size=n_patients) + statuses = np.random.choice([0, 1], size=n_patients, p=[0.3, 0.7]) + risks = np.random.randn(n_patients) + df = pd.DataFrame( + { + "patient": [random_string(8) for _ in range(n_patients)], + "day": times, + "status": statuses, + "pred_score": risks, + } + ) + df.to_csv(tmp_path / f"survival-preds-{fold_i}.csv", index=False) + + compute_stats_( + task="survival", + output_dir=tmp_path / "output", + pred_csvs=[tmp_path / f"survival-preds-{i}.csv" for i in range(n_folds)], + time_label="day", + status_label="status", + ) + + assert (tmp_path / "output" / "survival-stats_individual.csv").is_file() + + +def test_statistics_survival_integration_multiple_folds(tmp_path: Path) -> None: + return test_statistics_survival_integration(tmp_path=tmp_path, n_folds=5) + + +def test_statistics_regression_integration( + *, + tmp_path: Path, + n_folds: int = 1, + n_patients: int = 200, +) -> None: + """Check that regression statistics run without crashing.""" + random.seed(0) + np.random.seed(0) + + for fold_i in range(n_folds): + y_true = np.random.uniform(0, 100, size=n_patients) + y_pred = y_true + np.random.randn(n_patients) * 10 # noisy predictions + df = pd.DataFrame( + { + "patient": [random_string(8) for _ in range(n_patients)], + "target": y_true, + "pred": y_pred, + } + ) + df.to_csv(tmp_path / f"regression-preds-{fold_i}.csv", index=False) + + compute_stats_( + task="regression", + output_dir=tmp_path / "output", + pred_csvs=[tmp_path / f"regression-preds-{i}.csv" for i in range(n_folds)], + ground_truth_label="target", + ) + + assert (tmp_path / "output" / "target_regression-stats_individual.csv").is_file() + assert (tmp_path / "output" / "target_regression-stats_aggregated.csv").is_file() + + +def test_statistics_regression_integration_multiple_folds(tmp_path: Path) -> None: + return test_statistics_regression_integration(tmp_path=tmp_path, n_folds=5) + + +def test_statistics_multi_target_classification_integration( + *, + tmp_path: Path, + n_patient_preds: int = 1, +) -> None: + """Check that multi-target classification statistics run without crashing. + + Multi-target predictions produce separate ground-truth columns per target. + We run compute_stats_ once per target, as the statistics pipeline handles + one target at a time. + """ + random.seed(0) + np.random.seed(0) + torch.random.manual_seed(0) + + categories_per_target = {"subtype": ["A", "B"], "grade": ["1", "2", "3"]} + + for pred_i in range(n_patient_preds): + n_patients = random.randint(100, 500) + data: dict[str, list] = { + "patient": [random_string(8) for _ in range(n_patients)], + } + + for target_label, cats in categories_per_target.items(): + data[target_label] = [random.choice(cats) for _ in range(n_patients)] + probs = torch.softmax(torch.rand(len(cats), n_patients), dim=0) + for j, cat in enumerate(cats): + data[f"{target_label}_{cat}"] = probs[j].tolist() + + pd.DataFrame(data).to_csv( + tmp_path / f"multi-target-preds-{pred_i}.csv", index=False + ) + + # Run statistics per target (as the pipeline would do) + for target_label, cats in categories_per_target.items(): + true_class = cats[0] + compute_stats_( + task="classification", + output_dir=tmp_path / "output" / target_label, + pred_csvs=[ + tmp_path / f"multi-target-preds-{i}.csv" for i in range(n_patient_preds) + ], + ground_truth_label=target_label, + true_class=true_class, + ) + + assert ( + tmp_path + / "output" + / target_label + / f"{target_label}_categorical-stats_aggregated.csv" + ).is_file() + assert ( + tmp_path + / "output" + / target_label + / f"roc-curve_{target_label}={true_class}.svg" + ).is_file() + assert ( + tmp_path + / "output" + / target_label + / f"pr-curve_{target_label}={true_class}.svg" + ).is_file() + + +def test_statistics_multi_target_classification_multiple_preds( + tmp_path: Path, +) -> None: + return test_statistics_multi_target_classification_integration( + tmp_path=tmp_path, n_patient_preds=3 + ) diff --git a/tests/test_train_deploy.py b/tests/test_train_deploy.py index 0180d171..a21af576 100644 --- a/tests/test_train_deploy.py +++ b/tests/test_train_deploy.py @@ -8,6 +8,7 @@ import torch from random_data import ( create_random_dataset, + create_random_multi_target_dataset, create_random_patient_level_dataset, create_random_patient_level_survival_dataset, create_random_regression_dataset, @@ -22,8 +23,9 @@ VitModelParams, ) from stamp.modeling.deploy import deploy_categorical_model_ +from stamp.modeling.registry import ModelName from stamp.modeling.train import train_categorical_model_ -from stamp.seed import Seed +from stamp.utils.seed import Seed @pytest.mark.slow @@ -52,20 +54,20 @@ def test_train_deploy_integration( create_random_dataset( dir=tmp_path / "train", n_categories=3, - n_patients=400, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = create_random_dataset( dir=tmp_path / "deploy", categories=categories, - n_patients=50, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) @@ -83,7 +85,7 @@ def test_train_deploy_integration( advanced = AdvancedConfig( # Dataset and -loader parameters - bag_size=500, + bag_size=32, num_workers=min(os.cpu_count() or 1, 16), # Training paramenters batch_size=8, @@ -123,6 +125,9 @@ def test_train_deploy_integration( pytest.param(False, True, id="use vary_precision_transform"), ], ) +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_integration( *, tmp_path: Path, @@ -137,7 +142,7 @@ def test_train_deploy_patient_level_integration( create_random_patient_level_dataset( dir=tmp_path / "train", n_categories=3, - n_patients=400, + n_patients=30, feat_dim=feat_dim, ) ) @@ -145,7 +150,7 @@ def test_train_deploy_patient_level_integration( create_random_patient_level_dataset( dir=tmp_path / "deploy", categories=categories, - n_patients=50, + n_patients=10, feat_dim=feat_dim, ) ) @@ -213,20 +218,20 @@ def test_train_deploy_regression_integration( train_clini_path, train_slide_path, train_feature_dir, _ = ( create_random_regression_dataset( dir=tmp_path / "train", - n_patients=400, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( create_random_regression_dataset( dir=tmp_path / "deploy", - n_patients=50, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) @@ -245,9 +250,9 @@ def test_train_deploy_regression_integration( ) advanced = AdvancedConfig( - bag_size=500, + bag_size=32, num_workers=min(os.cpu_count() or 1, 16), - batch_size=1, + batch_size=8, max_epochs=2, patience=1, accelerator="gpu" if torch.cuda.is_available() else "cpu", @@ -292,20 +297,20 @@ def test_train_deploy_survival_integration( train_clini_path, train_slide_path, train_feature_dir, _ = ( create_random_survival_dataset( dir=tmp_path / "train", - n_patients=400, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( create_random_survival_dataset( dir=tmp_path / "deploy", - n_patients=50, - max_slides_per_patient=3, - min_tiles_per_slide=20, - max_tiles_per_slide=600, + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, feat_dim=feat_dim, ) ) @@ -324,7 +329,7 @@ def test_train_deploy_survival_integration( ) advanced = AdvancedConfig( - bag_size=500, + bag_size=32, num_workers=min(os.cpu_count() or 1, 16), batch_size=8, max_epochs=2, @@ -356,6 +361,9 @@ def test_train_deploy_survival_integration( @pytest.mark.slow +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_regression_integration( *, tmp_path: Path, @@ -377,7 +385,7 @@ def test_train_deploy_patient_level_regression_integration( train_feat_dir.mkdir(parents=True, exist_ok=True) deploy_feat_dir.mkdir(parents=True, exist_ok=True) - n_train, n_deploy = 300, 60 + n_train, n_deploy = 30, 10 train_rows, deploy_rows = [], [] # --- Generate random patient-level features and numeric targets --- @@ -465,6 +473,9 @@ def test_train_deploy_patient_level_regression_integration( @pytest.mark.slow +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_survival_integration( *, tmp_path: Path, @@ -479,14 +490,14 @@ def test_train_deploy_patient_level_survival_integration( train_clini_path, train_slide_path, train_feature_dir, _ = ( create_random_patient_level_survival_dataset( dir=tmp_path / "train", - n_patients=300, + n_patients=30, feat_dim=feat_dim, ) ) deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( create_random_patient_level_survival_dataset( dir=tmp_path / "deploy", - n_patients=60, + n_patients=10, feat_dim=feat_dim, ) ) @@ -531,3 +542,89 @@ def test_train_deploy_patient_level_survival_integration( accelerator="gpu" if torch.cuda.is_available() else "cpu", num_workers=min(os.cpu_count() or 1, 16), ) + + +@pytest.mark.slow +@pytest.mark.filterwarnings("ignore:No positive samples in targets") +def test_train_deploy_multi_target_integration( + *, + tmp_path: Path, + feat_dim: int = 25, +) -> None: + """Integration test: train + deploy a multi-target tile-level classification model.""" + Seed.set(42) + + (tmp_path / "train").mkdir() + (tmp_path / "deploy").mkdir() + + # Define multi-target setup: subtype (2 categories) and grade (3 categories) + target_labels = ["subtype", "grade"] + categories_per_target = [["A", "B"], ["1", "2", "3"]] + + # Create random multi-target tile-level dataset + train_clini_path, train_slide_path, train_feature_dir, _ = ( + create_random_multi_target_dataset( + dir=tmp_path / "train", + n_patients=30, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, + feat_dim=feat_dim, + target_labels=target_labels, + categories_per_target=categories_per_target, + ) + ) + deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( + create_random_multi_target_dataset( + dir=tmp_path / "deploy", + n_patients=10, + max_slides_per_patient=2, + min_tiles_per_slide=8, + max_tiles_per_slide=32, + feat_dim=feat_dim, + target_labels=target_labels, + categories_per_target=categories_per_target, + ) + ) + + # Build config objects + config = TrainConfig( + task="classification", + clini_table=train_clini_path, + slide_table=train_slide_path, + feature_dir=train_feature_dir, + output_dir=tmp_path / "train_output", + patient_label="patient", + ground_truth_label=target_labels, + filename_label="slide_path", + categories=[cat for cats in categories_per_target for cat in cats], + ) + + advanced = AdvancedConfig( + bag_size=32, + num_workers=min(os.cpu_count() or 1, 16), + batch_size=8, + max_epochs=2, + patience=1, + accelerator="gpu" if torch.cuda.is_available() else "cpu", + model_params=ModelParams(), + model_name=ModelName.BARSPOON, + ) + + # Train + deploy multi-target model + train_categorical_model_(config=config, advanced=advanced) + + deploy_categorical_model_( + output_dir=tmp_path / "deploy_output", + checkpoint_paths=[tmp_path / "train_output" / "model.ckpt"], + clini_table=deploy_clini_path, + slide_table=deploy_slide_path, + feature_dir=deploy_feature_dir, + patient_label="patient", + ground_truth_label=target_labels, + time_label=None, + status_label=None, + filename_label="slide_path", + accelerator="gpu" if torch.cuda.is_available() else "cpu", + num_workers=min(os.cpu_count() or 1, 16), + ) diff --git a/uv.lock b/uv.lock index 96b4b73a..4bd15ef6 100644 --- a/uv.lock +++ b/uv.lock @@ -3699,7 +3699,7 @@ wheels = [ [[package]] name = "stamp" -version = "2.4.0" +version = "2.4.1" source = { editable = "." } dependencies = [ { name = "beartype" },