diff --git a/getting-started.md b/getting-started.md index 6d5bffec..d1e66688 100644 --- a/getting-started.md +++ b/getting-started.md @@ -58,7 +58,9 @@ Stamp currently supports the following feature extractors: - [mSTAR][mstar] - [MUSK][musk] - [PLIP][plip] + - [KEEP][keep] - [TICON][ticon] + - [RedDino][reddino] As some of the above require you to request access to the model on huggingface, @@ -154,12 +156,14 @@ meaning ignored that it was ignored during feature extraction. [mstar]: https://huggingface.co/Wangyh/mSTAR [musk]: https://huggingface.co/xiangjx/musk [plip]: https://github.com/PathologyFoundation/plip +[keep]: https://loiesun.github.io/keep/ "A Knowledge-enhanced Pathology Vision-language Foundation Model for Cancer Diagnosis" [TITAN]: https://huggingface.co/MahmoodLab/TITAN [COBRA2]: https://huggingface.co/KatherLab/COBRA [EAGLE]: https://github.com/KatherLab/EAGLE [MADELEINE]: https://huggingface.co/MahmoodLab/madeleine [PRISM]: https://huggingface.co/paige-ai/Prism [TICON]: https://cvlab-stonybrook.github.io/TICON/ "TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning" +[reddino]: https://github.com/Snarci/RedDino "RedDino: A Foundation Model for Red Blood Cell Analysis" @@ -178,8 +182,12 @@ either in excel or `.csv` format, with contents as described below. Finally, `ground_truth_label` needs to contain the column name of the data we want to train our model on. +For single-target classification, use one column name. +For multi-target classification, use a list of column names and set +`advanced_config.model_name: "barspoon"`. Stamp only can be used to train neural networks for categorical targets. -We recommend explicitly setting the possible classes using the `categories` field. +For single-target runs, we recommend explicitly setting the possible classes +using the `categories` field. ```yaml # stamp-test-experiment/config.yaml @@ -206,12 +214,15 @@ crossval: # Name of the column from the clini table to train on. ground_truth_label: "isMSIH" + # For multi-target classification with barspoon: + # ground_truth_label: ["subtype", "grade"] # Optional settings: # The categories occurring in the target label column of the clini table. # If unspecified, they will be inferred from the table itself. categories: ["yes", "no"] + # For multi-target classification, per-target categories are inferred. # Number of folds to split the data into for cross-validation #n_splits: 5 @@ -223,6 +234,7 @@ we can run it by invoking: stamp --config stamp-test-experiment/config.yaml crossval ``` + ## Generating Statistics After training and validating your model, you may want to generate statistics to evaluate its performance. @@ -497,7 +509,7 @@ advanced_config: max_lr: 1e-4 div_factor: 25. # Select a model regardless of task - # Available models are: vit, trans_mil, mlp + # Available models are: vit, trans_mil, mlp, linear, barspoon model_name: "vit" model_params: @@ -514,4 +526,5 @@ STAMP automatically adapts its **model architecture**, **loss function**, and ** **Regression** tasks only require `ground_truth_label`. **Survival analysis** tasks require `time_label` (follow-up time) and `status_label` (event indicator). -These requirements apply consistently across cross-validation, training, deployment, and statistics. \ No newline at end of file +**Multi-target classification** requires `ground_truth_label` as a list and `advanced_config.model_name: "barspoon"`. +These requirements apply consistently across cross-validation, training, deployment, and statistics. diff --git a/pyproject.toml b/pyproject.toml index fc79a26b..ce840af8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "stamp" -version = "2.4.0" +version = "2.5.0" authors = [ { name = "Omar El Nahhas", email = "omar.el_nahhas@tu-dresden.de" }, { name = "Marko van Treeck", email = "markovantreeck@gmail.com" }, diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index ffa98bae..0b252d6f 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -6,15 +6,6 @@ import yaml -from stamp.modeling.config import ( - AdvancedConfig, - MlpModelParams, - ModelParams, - VitModelParams, -) -from stamp.utils.config import StampConfig -from stamp.utils.seed import Seed - STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml") # Set up the logger @@ -41,23 +32,38 @@ def _create_config_file(config_file: Path) -> None: def _run_cli(args: argparse.Namespace) -> None: - # Handle init command + # Handle init command before any stamp-internal imports so that + # `stamp init` and `stamp --help` don't pay the full torch/pydantic + # import cost. if args.command == "init": _create_config_file(args.config_file_path) return + # Deferred imports: only reached for real commands, not --help / init. + from stamp.modeling.config import ( + AdvancedConfig, + MlpModelParams, + ModelParams, + VitModelParams, + ) + from stamp.utils.config import StampConfig + from stamp.utils.seed import Seed + # Load YAML configuration with open(args.config_file_path, "r") as config_yaml: config = StampConfig.model_validate(yaml.safe_load(config_yaml)) - # use default advanced config in case none is provided - if config.advanced_config is None: - config.advanced_config = AdvancedConfig( - model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()), - ) + # Only build a default AdvancedConfig (with model-params) for commands + # that actually use it. Preprocess / encode / statistics / heatmaps + # never touch config.advanced_config, so don't pay the construction cost. + if args.command in {"train", "crossval"}: + if config.advanced_config is None: + config.advanced_config = AdvancedConfig( + model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()), + ) - # Set global random seed - if config.advanced_config.seed is not None: + # Apply the global seed for any command that has one configured. + if config.advanced_config is not None and config.advanced_config.seed is not None: Seed.set(config.advanced_config.seed) match args.command: @@ -153,6 +159,7 @@ def _run_cli(args: argparse.Namespace) -> None: if config.training.task is None: raise ValueError("task must be set in training configuration") + assert config.advanced_config is not None # guaranteed above for "train" train_categorical_model_( config=config.training, advanced=config.advanced_config ) @@ -198,6 +205,7 @@ def _run_cli(args: argparse.Namespace) -> None: f"{yaml.dump(config.crossval.model_dump(mode='json', exclude_none=True))}" ) + assert config.advanced_config is not None # guaranteed above for "crossval" categorical_crossval_( config=config.crossval, advanced=config.advanced_config, diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 71c54b74..900b40de 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -4,7 +4,7 @@ preprocessing: # Extractor to use for feature extractor. Possible options are "ctranspath", # "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom", # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", - # "virchow-full", "musk", "mstar", "plip", "ticon" + # "virchow-full", "musk", "mstar", "plip", "ticon", "red-dino", "keep" # Some of them require requesting access to the respective authors beforehand. extractor: "chief-ctranspath" @@ -327,7 +327,7 @@ advanced_config: max_lr: 1e-4 div_factor: 25. # Select a model regardless of task - model_name: "vit" # or mlp, trans_mil, barspoon + model_name: "vit" # or mlp, linear, trans_mil, barspoon model_params: vit: # Vision Transformer @@ -357,4 +357,4 @@ advanced_config: num_encoder_layers: 2 num_decoder_layers: 2 dim_feedforward: 2048 - positional_encoding: true \ No newline at end of file + positional_encoding: true diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 4720ef9b..3b4c3ac4 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -62,7 +62,8 @@ def encode_slides_( if self.precision == torch.float16: self.model.half() - for tile_feats_filename in (progress := tqdm(os.listdir(feat_dir))): + h5_files = sorted(f for f in os.listdir(feat_dir) if f.endswith(".h5")) + for tile_feats_filename in (progress := tqdm(h5_files)): h5_path = os.path.join(feat_dir, tile_feats_filename) slide_name: str = Path(tile_feats_filename).stem progress.set_description(slide_name) @@ -185,7 +186,8 @@ def _read_h5( raise ValueError(f"File is not of type .h5: {os.path.basename(h5_path)}") with h5py.File(h5_path, "r") as f: feats_ds = cast(h5py.Dataset, f["feats"]) - feats: Tensor = torch.tensor(feats_ds[:], dtype=self.precision) + # torch.from_numpy avoids a redundant data copy vs torch.tensor(array) + feats: Tensor = torch.from_numpy(feats_ds[()]).to(dtype=self.precision) coords: CoordsInfo = get_coords(f) extractor: str = f.attrs.get("extractor", "") if extractor == "": diff --git a/src/stamp/encoding/encoder/gigapath.py b/src/stamp/encoding/encoder/gigapath.py index 4c0a2f6b..09688ad0 100644 --- a/src/stamp/encoding/encoder/gigapath.py +++ b/src/stamp/encoding/encoder/gigapath.py @@ -31,7 +31,10 @@ class Gigapath(Encoder): def __init__(self) -> None: try: model = slide_encoder.create_model( - "hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536 + "hf_hub:prov-gigapath/prov-gigapath", + "gigapath_slide_enc12l768d", + 1536, + global_pool=True, ) except AssertionError: raise ModuleNotFoundError( @@ -51,20 +54,9 @@ def _generate_slide_embedding( if not coords: raise ValueError("Tile coords are required for encoding") - # Calculate slide dimensions - slide_width = max(coords.coords_um[:, 0]) + coords.tile_size_um - slide_height = max(coords.coords_um[:, 1]) + coords.tile_size_um - - # Normalize coordinates to a [0, 1000] grid - n_grid = 1000 - norm_coords = self._convert_coords( - coords.coords_um, slide_width, slide_height, n_grid, current_x_offset=0 - ) + coords_px = coords.coords_um / coords.mpp norm_coords = ( - torch.tensor(norm_coords, dtype=torch.float32) - .unsqueeze(0) - .to(device) - .half() + torch.tensor(coords_px, dtype=torch.float32).unsqueeze(0).to(device).half() ) feats = feats.unsqueeze(0).half().to(device) @@ -119,8 +111,6 @@ def encode_patients_( all_feats_list = [] all_coords_list = [] - total_wsi_width = 0 - max_wsi_height = 0 slides_mpp = SlideMPP(-1) slide_info = [] @@ -151,31 +141,20 @@ def encode_patients_( ) wsi_width = max(coords.coords_um[:, 0]) + coords.tile_size_um - wsi_height = max(coords.coords_um[:, 1]) + coords.tile_size_um - - total_wsi_width += wsi_width # Sum the widths of all slides - max_wsi_height = max(max_wsi_height, wsi_height) # Track the max height - - slide_info.append((wsi_width, wsi_height, feats, coords)) + slide_info.append((wsi_width, feats, coords)) current_x_offset = 0 - for wsi_width, wsi_height, feats, coords in slide_info: - norm_coords = self._convert_coords( - coords=coords.coords_um, - total_wsi_width=total_wsi_width, - max_wsi_height=max_wsi_height, - n_grid=1000, - current_x_offset=current_x_offset, - ) + for wsi_width, feats, coords in slide_info: + offset_coords_um = coords.coords_um.copy() + offset_coords_um[:, 0] += current_x_offset - # Update x-coordinates by shifting them based on the current_x_offset - current_x_offset += ( - wsi_width # Move the x_offset forward for the next slide - ) + current_x_offset += wsi_width + + coords_px = offset_coords_um / coords.mpp norm_coords = ( - torch.tensor(norm_coords, dtype=torch.float32) + torch.tensor(coords_px, dtype=torch.float32) .unsqueeze(0) .to(device) .half() @@ -211,26 +190,3 @@ def _generate_patient_embedding( patient_embedding = torch.cat(patient_embedding, dim=0) return patient_embedding.detach().squeeze().cpu().numpy() - - def _convert_coords( - self, - coords, - total_wsi_width, - max_wsi_height, - n_grid, - current_x_offset, - ) -> np.ndarray: - """ - Normalize the x and y coordinates relative to the total WSI width and max height, using the same grid [0, 1000]. - Thanks Peter! - """ - # Normalize x-coordinates based on total WSI width (taking into account the current x offset) - normalized_x = (coords[:, 0] + current_x_offset) / total_wsi_width * n_grid - - # Normalize y-coordinates based on the maximum WSI height - normalized_y = coords[:, 1] / max_wsi_height * n_grid - - # Stack normalized x and y coordinates - converted_coords = np.stack([normalized_x, normalized_y], axis=-1) - - return np.array(converted_coords, dtype=np.float32) diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index fb704fe6..b903ef89 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -330,7 +330,7 @@ def heatmaps_( # TODO: Update version when a newer model logic breaks heatmaps. stamp_version = str(getattr(model, "stamp_version", "")) - if Version(stamp_version) < Version("2.4.0"): + if Version(stamp_version) < Version("2.5.0"): raise ValueError( f"model has been built with stamp version {stamp_version} " f"which is incompatible with the current version." diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 26196065..4ee71563 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -336,7 +336,7 @@ def categorical_crossval_( ).to_csv(split_dir / "patient-preds.csv", index=False) elif config.task == "regression": if config.ground_truth_label is None: - raise RuntimeError("Grounf truth label is required for regression") + raise RuntimeError("Ground truth label is required for regression") if isinstance(config.ground_truth_label, str): _to_regression_prediction_df( patient_to_ground_truth=cast( @@ -353,7 +353,7 @@ def categorical_crossval_( else: if config.ground_truth_label is None: raise RuntimeError( - "Grounf truth label is required for classification" + "Ground truth label is required for classification" ) _to_prediction_df( categories=categories_for_export, diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index c61a2512..e0444f15 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -280,6 +280,22 @@ def deploy_categorical_model_( Sequence[Category] | Mapping[str, Sequence[Category]] | None ) = cast(Sequence[Category] | Mapping[str, Sequence[Category]] | None, None) for model_i, model in enumerate(models): + # Check for data leakage: if the deployment patient set overlaps with + # the patients used during model training/validation, log a critical + # message. This check is intentionally performed at the deploy level + # (not inside `_predict`) so prediction helpers can be reused without + # side-effects in other contexts (e.g., cross-validation). + patients_used_for_training: set[PatientId] = set( + getattr(model, "train_patients", []) + ) | set(getattr(model, "valid_patients", [])) + if overlap := patients_used_for_training & set(patient_ids): + _logger.critical( + "DATA LEAKAGE DETECTED: %d patient(s) in deployment set were used " + "during training/validation. Overlapping IDs: %s", + len(overlap), + sorted(overlap), + ) + predictions = _predict( model=model, test_dl=test_dl, @@ -374,17 +390,7 @@ def _predict( model = model.eval() torch.set_float32_matmul_precision("medium") - # Check for data leakage - patients_used_for_training: set[PatientId] = set( - getattr(model, "train_patients", []) - ) | set(getattr(model, "valid_patients", [])) - if overlap := patients_used_for_training & set(patient_ids): - _logger.critical( - "DATA LEAKAGE DETECTED: %d patient(s) in deployment set were used " - "during training/validation. Overlapping IDs: %s", - len(overlap), - sorted(overlap), - ) + # Note: data-leakage check intentionally performed at deploy level. trainer = lightning.Trainer( accelerator=accelerator, @@ -659,21 +665,11 @@ def _to_survival_prediction_df( else: row["pred_score"] = pred.cpu().tolist() - # Ground truth: time + event - if gt is not None: - if isinstance(gt, str) and " " in gt: - time_str, status_str = gt.split(" ", 1) - row["time"] = float(time_str) if time_str.lower() != "nan" else None - if status_str.lower() in {"dead", "event", "1"}: - row["event"] = 1 - elif status_str.lower() in {"alive", "censored", "0"}: - row["event"] = 0 - else: - row["event"] = None - elif isinstance(gt, (tuple, list)) and len(gt) == 2: - row["time"], row["event"] = gt - else: - row["time"], row["event"] = None, None + # Ground truth: prefer structured tuple/list (time, event). Do not + # call .split on ground-truth values — assume structured input. If + # the value is not a 2-tuple/list, treat both fields as unknown. + if isinstance(gt, (tuple, list)) and len(gt) == 2: + row["time"], row["event"] = gt else: row["time"], row["event"] = None, None diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index b5a59b5f..217ef8b2 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -8,11 +8,11 @@ import lightning import numpy as np import torch -from lifelines.utils import concordance_index as lifelines_cindex # Use beartype.typing.Mapping to avoid PEP-585 deprecation warnings in beartype from beartype.typing import Mapping from jaxtyping import Bool, Float +from lifelines.utils import concordance_index as lifelines_cindex from packaging.version import Version from torch import Tensor, nn, optim from torchmetrics.classification import MulticlassAUROC @@ -89,7 +89,7 @@ def __init__( # This should only happen when the model is loaded, # otherwise the default value will make these checks pass. # TODO: Change this on version change - if stamp_version < Version("2.4.0"): + if stamp_version < Version("2.5.0"): # Update this as we change our model in incompatible ways! raise ValueError( f"model has been built with stamp version {stamp_version} " @@ -239,7 +239,7 @@ def forward( def _step( self, *, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], step_name: str, use_mask: bool, ) -> Loss: @@ -280,32 +280,36 @@ def _step( def training_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="training", use_mask=False) def validation_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="validation", use_mask=False) def test_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="test", use_mask=False) def predict_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Float[Tensor, "batch logit"]: bags, coords, bag_sizes, _ = batch # adding a mask here will *drastically* and *unbearably* increase memory usage + # Ensure input dtype matches model weights to avoid dtype-mismatch errors + param_dtype = next(self.model.parameters()).dtype + bags = bags.to(dtype=param_dtype) + coords = coords.to(dtype=param_dtype) return self.model(bags, coords=coords, mask=None) @@ -365,6 +369,9 @@ def predict_step( self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int ) -> Tensor: feats, _ = batch if isinstance(batch, tuple) else batch + # Cast inputs to model parameter dtype to avoid Half/Float mismatches + param_dtype = next(self.model.parameters()).dtype + feats = feats.to(dtype=param_dtype) return self.model(feats) @@ -437,7 +444,7 @@ def forward( def _step( self, *, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], step_name: str, use_mask: bool, ) -> Loss: @@ -477,28 +484,28 @@ def _step( def training_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="training", use_mask=False) def validation_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="validation", use_mask=False) def test_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Loss: return self._step(batch=batch, step_name="test", use_mask=False) def predict_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Float[Tensor, "batch 1"]: bags, coords, bag_sizes, _ = batch @@ -741,7 +748,11 @@ def forward( # (most ViT backbones accept coords/mask even if unused) return self.model(bags, coords=coords, mask=mask) - def training_step(self, batch, batch_idx): + def training_step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], + batch_idx: int, + ) -> Loss: bags, coords, bag_sizes, targets = batch preds = self.model(bags, coords=coords, mask=None) y = targets.to(preds.device, dtype=torch.float32) @@ -766,7 +777,7 @@ def training_step(self, batch, batch_idx): def validation_step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets], + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], batch_idx: int, ) -> Any: bags, coords, bag_sizes, targets = batch @@ -780,9 +791,13 @@ def validation_step( self._val_times.append(times.detach().cpu()) self._val_events.append(events.detach().cpu()) - def predict_step(self, batch, batch_idx): - feats, coords, n_tiles, survival_target = batch - return self.model(feats.float(), coords=coords, mask=None) + def predict_step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets] | list[Tensor], + batch_idx: int, + ) -> Float[Tensor, "batch 1"]: + bags, coords, bag_sizes, survival_target = batch + return self.model(bags, coords=coords, mask=None) class LitSlideSurvival(LitSurvivalBase): @@ -858,6 +873,10 @@ def __init__( positional_encoding: bool = True, # Other hparams learning_rate: float = 1e-4, + # Deployment metadata (optional) — keep parity with `Base` + train_patients: Iterable[PatientId] = (), + valid_patients: Iterable[PatientId] = (), + stamp_version: Version = Version(stamp.__version__), **hparams: Any, ) -> None: weights_dict: dict[TargetLabel, torch.Tensor] = dict(category_weights) @@ -905,6 +924,13 @@ def __init__( self.ground_truth_label = ground_truth_label self.categories = normalized_categories + # Deployment metadata — mirror `Base` behavior so checkpoints include + # train/valid patient lists and stamp version for leak-detection and + # compatibility checks. + self.train_patients = train_patients + self.valid_patients = valid_patients + self.stamp_version = str(stamp_version) + self.save_hyperparameters() def forward(self, *args): diff --git a/src/stamp/modeling/models/barspoon.py b/src/stamp/modeling/models/barspoon.py index f841bb3d..92c1a15d 100644 --- a/src/stamp/modeling/models/barspoon.py +++ b/src/stamp/modeling/models/barspoon.py @@ -78,12 +78,12 @@ class tokens, one per output label. Finally, we forward each of the decoded 2. Adding absolute positions to the feature vector, scaled down so the maximum value in the training dataset is 1. - Since neither reduced performance and the author percieves the first one to + Since neither reduced performance and the author perceives the first one to be more elegant (as the magnitude of the positional encodings is bounded), we opted to keep the positional encoding regardless in the hopes of it improving performance on future tasks. - The architecture _differs_ from the one descibed in [Attention Is All You + The architecture _differs_ from the one described in [Attention Is All You Need][1] as follows: 1. There is an initial projection stage to reduce the dimension of the @@ -223,7 +223,7 @@ def __init__( _ = hparams # So we don't get unused parameter warnings # Check if version is compatible. - if stamp_version < Version("2.4.0"): + if stamp_version < Version("2.5.0"): # Update this as we change our model in incompatible ways! raise ValueError( f"model has been built with stamp version {stamp_version} " @@ -261,7 +261,7 @@ def __init__( def step( self, - batch: tuple[Bags, CoordinatesBatch, BagSizes, dict[str, torch.Tensor]], + batch: tuple[Bags, CoordinatesBatch, BagSizes, dict[str, torch.Tensor]] | list, step_name=None, ): """Process a batch with structure (feats, coords, bag_sizes, targets). diff --git a/src/stamp/modeling/models/vision_tranformer.py b/src/stamp/modeling/models/vision_tranformer.py index b936c5c9..fcd60c12 100644 --- a/src/stamp/modeling/models/vision_tranformer.py +++ b/src/stamp/modeling/models/vision_tranformer.py @@ -56,9 +56,7 @@ def forward( Which query-key pairs to mask from ALiBi (i.e. don't apply ALiBi to). """ weight_logits = torch.einsum("bqf,bkf->bqk", q, k) * (k.size(-1) ** -0.5) - distances = torch.linalg.norm( - coords_q.unsqueeze(2) - coords_k.unsqueeze(1), dim=-1 - ) + distances = torch.cdist(coords_q, coords_k) scaled_distances = self.scale_distance(distances) * self.bias_scale if alibi_mask is not None: diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index 7031eed6..a55ec1e3 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -177,13 +177,13 @@ def setup_model_for_training( f"Model '{advanced.model_name.value}' does not support feature type '{feature_type}'. " f"Supported types are: {LitModelClass.supported_features}" ) - elif ( - feature_type in ("slide", "patient") - and advanced.model_name.value.lower() != "mlp" - ): + elif feature_type in ( + "slide", + "patient", + ) and advanced.model_name.value.lower() not in {"mlp", "linear"}: raise ValueError( - f"Feature type '{feature_type}' only supports MLP backbones. " - f"Got '{advanced.model_name.value}'. Please set model_name='mlp'." + f"Feature type '{feature_type}' only supports MLP or Linear. " + f"Got '{advanced.model_name.value}'. Please set model_name='mlp' or 'linear'." ) # 4. Get model-specific hyperparameters diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index 22f1d90f..b2daa386 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -85,16 +85,9 @@ def __init__( self.canny_cutoff = canny_cutoff self.default_slide_mpp = default_slide_mpp - # Already check if we can extract the MPP here. - # We don't want to kill our dataloader later, - # because that leads to _a lot_ of error messages which are difficult to read - if ( - get_slide_mpp_( - openslide.open_slide(slide_path), default_mpp=default_slide_mpp - ) - is None - ): - raise MPPExtractionError() + # MPP is validated by the caller (extract_()) before constructing this dataset, + # so we no longer open the slide here for a redundant MPP check. + # This removes one openslide.open_slide() call per WSI. def __iter__(self) -> Iterator[tuple[Tensor, Microns, Microns]]: return ( @@ -177,6 +170,11 @@ def extract_( extractor = dino_bloom() + case ExtractorName.RED_DINO: + from stamp.preprocessing.extractor.reddino import red_dino + + extractor = red_dino() + case ExtractorName.VIRCHOW: from stamp.preprocessing.extractor.virchow import virchow @@ -222,6 +220,10 @@ def extract_( extractor = plip() + case ExtractorName.KEEP: + from stamp.preprocessing.extractor.keep import keep + + extractor = keep() case ExtractorName.TICON: from stamp.preprocessing.extractor.ticon import ticon @@ -286,6 +288,15 @@ def extract_( feature_output_path.parent.mkdir(parents=True, exist_ok=True) try: + # Validate MPP here once (avoids a second openslide.open_slide inside _TileDataset.__init__). + if ( + get_slide_mpp_( + openslide.open_slide(slide_path), default_mpp=default_slide_mpp + ) + is None + ): + raise MPPExtractionError() + ds = _TileDataset( slide_path=slide_path, cache_dir=cache_dir, @@ -300,7 +311,15 @@ def extract_( default_slide_mpp=default_slide_mpp, ) # Parallelism is implemented in the dataset iterator already, so one worker is enough! - dl = DataLoader(ds, batch_size=64, num_workers=1, drop_last=False) + # pin_memory speeds up CPU→GPU DMA for tile batches. + # num_workers=1 is intentional: WSI read parallelism is inside _supertiles. + dl = DataLoader( + ds, + batch_size=64, + num_workers=1, + drop_last=False, + pin_memory=torch.cuda.is_available(), + ) feats, xs_um, ys_um = [], [], [] for tiles, xs, ys in tqdm(dl, leave=False): @@ -384,8 +403,9 @@ def _get_rejection_thumb( dtype=bool, ) - for y, x in np.floor(coords_um / tile_size_um).astype(np.uint32): - inclusion_map[y, x] = True + # Vectorized: set all tile positions at once instead of a Python loop. + tile_indices = np.floor(coords_um / tile_size_um).astype(np.uint32) + inclusion_map[tile_indices[:, 0], tile_indices[:, 1]] = True thumb = slide.get_thumbnail(size).convert("RGBA") discarded_im = Image.fromarray( diff --git a/src/stamp/preprocessing/config.py b/src/stamp/preprocessing/config.py index 5eca41dd..e017daf8 100644 --- a/src/stamp/preprocessing/config.py +++ b/src/stamp/preprocessing/config.py @@ -1,7 +1,6 @@ from enum import StrEnum from pathlib import Path -import torch from pydantic import BaseModel, ConfigDict, Field from stamp.types import ImageExtension, Microns, SlideMPP, TilePixels @@ -28,8 +27,10 @@ class ExtractorName(StrEnum): MUSK = "musk" MSTAR = "mstar" PLIP = "plip" + KEEP = "keep" TICON = "ticon" EMPTY = "empty" + RED_DINO = "red-dino" class PreprocessingConfig(BaseModel, arbitrary_types_allowed=True): @@ -46,7 +47,11 @@ class PreprocessingConfig(BaseModel, arbitrary_types_allowed=True): tile_size_px: TilePixels = TilePixels(224) extractor: ExtractorName max_workers: int = 8 - device: str = "cuda" if torch.cuda.is_available() else "cpu" + device: str = Field( + default_factory=lambda: ( + "cuda" if __import__("torch").cuda.is_available() else "cpu" + ) + ) generate_hash: bool = True default_slide_mpp: SlideMPP | None = None diff --git a/src/stamp/preprocessing/extractor/keep.py b/src/stamp/preprocessing/extractor/keep.py new file mode 100644 index 00000000..4adc964e --- /dev/null +++ b/src/stamp/preprocessing/extractor/keep.py @@ -0,0 +1,49 @@ +""" +Adopted from https://github.com/MAGIC-AI4Med/KEEP +KEEP (KnowledgE-Enhanced Pathology) +""" + +try: + import torch + from torchvision import transforms + from transformers import AutoModel +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "keep dependencies not installed." + " Please reinstall stamp using `pip install 'stamp[keep]'`" + ) from e + +from stamp.preprocessing.config import ExtractorName +from stamp.preprocessing.extractor import Extractor + + +class KEEPWrapper(torch.nn.Module): + def __init__(self, model) -> None: + super().__init__() + self.model = model + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + return self.model.encode_image(batch) + + +def keep() -> Extractor[KEEPWrapper]: + """Extracts features from slide tiles using the KEEP tile encoder.""" + model = AutoModel.from_pretrained("Astaxanthin/KEEP", trust_remote_code=True) + model.eval() + + transform = transforms.Compose( + [ + transforms.Resize( + size=224, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.CenterCrop(size=(224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] + ) + + return Extractor( + model=KEEPWrapper(model), + transform=transform, + identifier=ExtractorName.KEEP, + ) diff --git a/src/stamp/preprocessing/extractor/reddino.py b/src/stamp/preprocessing/extractor/reddino.py new file mode 100644 index 00000000..b370ea2d --- /dev/null +++ b/src/stamp/preprocessing/extractor/reddino.py @@ -0,0 +1,64 @@ +""" +Port from https://github.com/Snarci/RedDino +RedDino: A Foundation Model for Red Blood Cell Analysis +""" + +from typing import Callable, cast + +try: + import timm + import torch + from PIL import Image + from torchvision import transforms +except ModuleNotFoundError as e: + raise ModuleNotFoundError("red-dino dependencies not installed.") from e + +from stamp.preprocessing.config import ExtractorName +from stamp.preprocessing.extractor import Extractor + +__license__ = "MIT" + + +class RedDinoClsOnly(torch.nn.Module): + def __init__(self, model) -> None: + super().__init__() + self.model = model + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + out = self.model(batch) + if isinstance(out, tuple): + out = out[0] + # if model returns tokens, return class token + if getattr(out, "ndim", 0) >= 2 and out.shape[1] > 1: + return out[:, 0] + return out + + +def red_dino() -> Extractor[RedDinoClsOnly]: + """Extracts features from single image using RedDino encoder.""" + + model = timm.create_model( + "hf-hub:Snarcy/RedDino-large", + pretrained=True, + num_classes=0, + pretrained_strict=False, + ) + + transform = cast( + Callable[[Image.Image], torch.Tensor], + transforms.Compose( + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ), + ) + + return Extractor( + model=RedDinoClsOnly(model), + transform=transform, + identifier=ExtractorName.RED_DINO, + ) diff --git a/src/stamp/preprocessing/extractor/ticon.py b/src/stamp/preprocessing/extractor/ticon.py index ab7eb829..02aac13e 100644 --- a/src/stamp/preprocessing/extractor/ticon.py +++ b/src/stamp/preprocessing/extractor/ticon.py @@ -624,7 +624,6 @@ def load_ticon(device: str = "cuda") -> nn.Module: class HOptimusTICON(nn.Module): def __init__(self, device: torch.device): super().__init__() - self.device = device # ---------------------------- # Stage 1: H-OptimUS @@ -634,6 +633,8 @@ def __init__(self, device: torch.device): pretrained=True, init_values=1e-5, dynamic_img_size=False, + num_classes=0, + pretrained_strict=False, ) # ---------------------------- @@ -689,7 +690,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: [B, 3, 224, 224] (CPU or CUDA) """ - x = x.to(self.device, non_blocking=True) + # Respect the current module device (it may be moved after construction). + device = next(self.parameters()).device + x = x.to(device, non_blocking=True) # H-Optimus_1 emb = self.tile_encoder(x) # [B, 1536] @@ -700,7 +703,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: emb.size(0), 1, 2, - device=self.device, + device=device, dtype=torch.float32, ) @@ -713,7 +716,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out.squeeze(1) # [B, 1536] -def ticon(device: str = "cuda") -> Extractor[nn.Module]: +def ticon(device: str | None = None) -> Extractor[nn.Module]: + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" model = HOptimusTICON(torch.device(device)) transform = transforms.Compose( diff --git a/src/stamp/preprocessing/tiling.py b/src/stamp/preprocessing/tiling.py index ce684ba4..1f560981 100644 --- a/src/stamp/preprocessing/tiling.py +++ b/src/stamp/preprocessing/tiling.py @@ -397,8 +397,10 @@ def _tiles_from_cache_file(cache_file_path: Path) -> Iterator[_Tile]: x_um, y_um = Microns(float(x_um_str)), Microns(float(y_um_str)) with zip_fp.open(name, "r") as tile_fp: + img = Image.open(tile_fp) + img.load() # force eager pixel decode while tile_fp is still open yield _Tile( - image=Image.open(tile_fp), + image=img, coordinates=_XYCoords(x_um, y_um), size=tiler_params["tile_size_um"], ) diff --git a/src/stamp/statistics/categorical.py b/src/stamp/statistics/categorical.py index 9d6c4c12..e19b1659 100755 --- a/src/stamp/statistics/categorical.py +++ b/src/stamp/statistics/categorical.py @@ -53,7 +53,9 @@ def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: """ categories = preds_df[target_label].unique() y_true = preds_df[target_label] - y_pred = preds_df[[f"{target_label}_{cat}" for cat in categories]].map(float).values + y_pred = ( + preds_df[[f"{target_label}_{cat}" for cat in categories]].astype(float).values + ) stats_df = pd.DataFrame(index=categories) @@ -156,15 +158,19 @@ def categorical_aggregated_multitarget_( all_target_stats = {} + # Read each CSV once and cache so we don't re-read N×M times. + csv_cache: dict[str, pd.DataFrame] = { + Path(p).parent.name: pd.read_csv(p, dtype=str) for p in preds_csvs + } + for target_label in target_labels: # Process each target separately preds_dfs = {} - for p in preds_csvs: - df = pd.read_csv(p, dtype=str) + for fold_name, df in csv_cache.items(): # Drop rows where this target's ground truth is missing df_clean = df.dropna(subset=[target_label]) if len(df_clean) > 0: - preds_dfs[Path(p).parent.name] = _categorical(df_clean, target_label) + preds_dfs[fold_name] = _categorical(df_clean, target_label) if not preds_dfs: continue diff --git a/src/stamp/statistics/prc.py b/src/stamp/statistics/prc.py index 867885e9..dd58ea2e 100755 --- a/src/stamp/statistics/prc.py +++ b/src/stamp/statistics/prc.py @@ -6,7 +6,6 @@ import scipy.stats as st from jaxtyping import Bool, Float from matplotlib.axes import Axes -from scipy.interpolate import interp1d from sklearn.metrics import ( auc, average_precision_score, @@ -56,15 +55,9 @@ def _plot_bootstrapped_pr_curve( continue precision, recall, _ = precision_recall_curve(sample_y_true, sample_y_pred) - # Create an interpolation function with decreasing values - interp_func = interp1d( - recall[::-1], - precision[::-1], - kind="linear", - fill_value=np.nan, - bounds_error=False, - ) - interp_prc = interp_func(interp_recall) + # np.interp requires increasing x; precision_recall_curve returns + # decreasing recall, so reverse both arrays. + interp_prc = np.interp(interp_recall, recall[::-1], precision[::-1]) interp_prcs[i] = interp_prc bootstrapped_auprc = auc(interp_recall, interp_prc) bootstrap_auprcs.append(bootstrapped_auprc) diff --git a/src/stamp/statistics/roc.py b/src/stamp/statistics/roc.py index d42413a4..338cf876 100755 --- a/src/stamp/statistics/roc.py +++ b/src/stamp/statistics/roc.py @@ -180,9 +180,11 @@ def _plot_bootstrapped_roc_curve( # and then sample the bottom 0.025 / top 0.975 quantile point # for each sampled fpr-position rng = np.random.default_rng() - interp_rocs = [] interp_fpr = np.linspace(0, 1, num=1000) + # Pre-allocate; rows that correspond to skipped samples stay NaN. + interp_rocs = np.full((n_bootstrap_samples, len(interp_fpr)), np.nan) bootstrap_aucs: list[float] = [] + valid_row = 0 for _ in trange(n_bootstrap_samples, desc="Bootstrapping ROC curves", leave=False): sample_idxs = rng.choice(len(y_true), len(y_true)) sample_y_true = y_true[sample_idxs] @@ -190,15 +192,17 @@ def _plot_bootstrapped_roc_curve( if len(np.unique(sample_y_true)) != 2: continue fpr, tpr, thresh = roc_curve(sample_y_true, sample_y_score) - interp_rocs.append(np.interp(interp_fpr, fpr, tpr)) + interp_rocs[valid_row] = np.interp(interp_fpr, fpr, tpr) + valid_row += 1 bootstrap_aucs.append(float(roc_auc_score(sample_y_true, sample_y_score))) + interp_rocs = interp_rocs[:valid_row] # trim unused rows roc_lower, roc_upper = cast( tuple[ Float[np.ndarray, "fpr"], # noqa: F821 Float[np.ndarray, "fpr"], # noqa: F821 ], - np.quantile(interp_rocs, [0.025, 0.975], axis=0), + np.nanquantile(interp_rocs, [0.025, 0.975], axis=0), ) ax.fill_between(interp_fpr, roc_lower, roc_upper, alpha=0.5) diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index 87415a7a..8fbf5d63 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -4,6 +4,7 @@ from pathlib import Path +import lifelines.plotting as lifelines_plotting import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -142,6 +143,10 @@ def _plot_km( if len(high_df) > 0: fitters.append(kmf_high) + # add at-risk table for fitted curves + if len(fitters) > 0: + lifelines_plotting.add_at_risk_counts(*fitters, ax=ax) + # log-rank and c-index res = logrank_test( low_df[time_label], diff --git a/src/stamp/utils/cache.py b/src/stamp/utils/cache.py index d65c7dcd..c0b00a51 100644 --- a/src/stamp/utils/cache.py +++ b/src/stamp/utils/cache.py @@ -9,28 +9,28 @@ STAMP_CACHE_DIR: Final[Path] = ( Path(os.environ.get("XDG_CACHE_HOME") or (Path.home() / ".cache")) / "stamp" ) - -# If we imported this, we probably want to use it, -# so it's okay creating the directory now -STAMP_CACHE_DIR.mkdir(exist_ok=True, parents=True) +# Directory is created on demand (inside functions that write to it) +# so that a bare import of this module does not cause filesystem I/O. def download_file(*, url: str, file_name: str, sha256sum: str) -> Path: - """Downloads a file, or loads it from cache if it has been downloaded before""" + """Downloads a file, or loads it from cache if it has been downloaded before. + + The checksum is only verified on the initial download. Once the file + exists in the cache it is trusted as-is to avoid re-reading large weight + files (which can be ~1 GB) on every run. + """ + STAMP_CACHE_DIR.mkdir(exist_ok=True, parents=True) outfile_path = STAMP_CACHE_DIR / file_name if outfile_path.is_file(): - with open(outfile_path, "rb") as weight_file: - digest = hashlib.file_digest(weight_file, "sha256") - assert digest.hexdigest() == sha256sum, ( - f"{outfile_path} has the wrong checksum. Try deleting it and rerunning this script." - ) - else: - filename, _ = urllib.request.urlretrieve(url) - with open(filename, "rb") as weight_file: - digest = hashlib.file_digest(weight_file, "sha256") - assert digest.hexdigest() == sha256sum, "hash of downloaded file did not match" - shutil.move(filename, outfile_path) - + # File already cached and verified on first download — skip re-hash. + return outfile_path + + filename, _ = urllib.request.urlretrieve(url) + with open(filename, "rb") as weight_file: + digest = hashlib.file_digest(weight_file, "sha256") + assert digest.hexdigest() == sha256sum, "hash of downloaded file did not match" + shutil.move(filename, outfile_path) return outfile_path @@ -40,14 +40,16 @@ def file_digest(file: str | Path) -> str: @cache -def get_processing_code_hash(file_path) -> str: +def get_processing_code_hash(file_path: Path) -> str: """The hash of the entire process codebase. - It is used to assure that features extracted with different versions of this code base - can be identified as such after the fact. + It is used to assure that features extracted with different versions of + this code base can be identified as such after the fact. """ hasher = hashlib.sha256() - for file_path in sorted(file_path.parent.glob("*.py")): - with open(file_path, "rb") as fp: - hasher.update(fp.read()) + for py_file in sorted(file_path.parent.glob("*.py")): + # Use file_digest to stream the file in chunks instead of reading + # the entire source into memory at once. + with open(py_file, "rb") as fp: + hasher.update(hashlib.file_digest(fp, "sha256").digest()) return hasher.hexdigest() diff --git a/uv.lock b/uv.lock index 96b4b73a..beb0a251 100644 --- a/uv.lock +++ b/uv.lock @@ -3699,7 +3699,7 @@ wheels = [ [[package]] name = "stamp" -version = "2.4.0" +version = "2.5.0" source = { editable = "." } dependencies = [ { name = "beartype" },