Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ Stamp currently supports the following feature extractors:
- [mSTAR][mstar]
- [MUSK][musk]
- [PLIP][plip]
- [KEEP][keep]
- [TICON][ticon]
- [RedDino][reddino]


As some of the above require you to request access to the model on huggingface,
Expand Down Expand Up @@ -154,12 +156,14 @@ meaning that it was ignored during feature extraction.
[mstar]: https://huggingface.co/Wangyh/mSTAR
[musk]: https://huggingface.co/xiangjx/musk
[plip]: https://github.com/PathologyFoundation/plip
[keep]: https://loiesun.github.io/keep/ "A Knowledge-enhanced Pathology Vision-language Foundation Model for Cancer Diagnosis"
[TITAN]: https://huggingface.co/MahmoodLab/TITAN
[COBRA2]: https://huggingface.co/KatherLab/COBRA
[EAGLE]: https://github.com/KatherLab/EAGLE
[MADELEINE]: https://huggingface.co/MahmoodLab/madeleine
[PRISM]: https://huggingface.co/paige-ai/Prism
[TICON]: https://cvlab-stonybrook.github.io/TICON/ "TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning"
[reddino]: https://github.com/Snarci/RedDino "RedDino: A Foundation Model for Red Blood Cell Analysis"



Expand All @@ -178,8 +182,12 @@ either in Excel or `.csv` format,
with contents as described below.
Finally, `ground_truth_label` needs to contain the column name
of the data we want to train our model on.
For single-target classification, use one column name.
For multi-target classification, use a list of column names and set
`advanced_config.model_name: "barspoon"`.
Stamp can only be used to train neural networks for categorical targets.
We recommend explicitly setting the possible classes using the `categories` field.
For single-target runs, we recommend explicitly setting the possible classes
using the `categories` field.

```yaml
# stamp-test-experiment/config.yaml
Expand All @@ -206,12 +214,15 @@ crossval:

# Name of the column from the clini table to train on.
ground_truth_label: "isMSIH"
# For multi-target classification with barspoon:
# ground_truth_label: ["subtype", "grade"]

# Optional settings:

# The categories occurring in the target label column of the clini table.
# If unspecified, they will be inferred from the table itself.
categories: ["yes", "no"]
# For multi-target classification, per-target categories are inferred.

# Number of folds to split the data into for cross-validation
#n_splits: 5
Expand All @@ -223,6 +234,7 @@ we can run it by invoking:
stamp --config stamp-test-experiment/config.yaml crossval
```


## Generating Statistics

After training and validating your model, you may want to generate statistics to evaluate its performance.
Expand Down Expand Up @@ -497,7 +509,7 @@ advanced_config:
max_lr: 1e-4
div_factor: 25.
# Select a model regardless of task
# Available models are: vit, trans_mil, mlp
# Available models are: vit, trans_mil, mlp, linear, barspoon
model_name: "vit"

model_params:
Expand All @@ -514,4 +526,5 @@ STAMP automatically adapts its **model architecture**, **loss function**, and **

**Regression** tasks only require `ground_truth_label`.
**Survival analysis** tasks require `time_label` (follow-up time) and `status_label` (event indicator).
These requirements apply consistently across cross-validation, training, deployment, and statistics.
**Multi-target classification** requires `ground_truth_label` as a list and `advanced_config.model_name: "barspoon"`.
These requirements apply consistently across cross-validation, training, deployment, and statistics.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "stamp"
version = "2.4.0"
version = "2.5.0"
authors = [
{ name = "Omar El Nahhas", email = "omar.el_nahhas@tu-dresden.de" },
{ name = "Marko van Treeck", email = "markovantreeck@gmail.com" },
Expand Down
42 changes: 25 additions & 17 deletions src/stamp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,6 @@

import yaml

from stamp.modeling.config import (
AdvancedConfig,
MlpModelParams,
ModelParams,
VitModelParams,
)
from stamp.utils.config import StampConfig
from stamp.utils.seed import Seed

STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml")

# Set up the logger
Expand All @@ -41,23 +32,38 @@ def _create_config_file(config_file: Path) -> None:


def _run_cli(args: argparse.Namespace) -> None:
# Handle init command
# Handle init command before any stamp-internal imports so that
# `stamp init` and `stamp --help` don't pay the full torch/pydantic
# import cost.
if args.command == "init":
_create_config_file(args.config_file_path)
return

# Deferred imports: only reached for real commands, not --help / init.
from stamp.modeling.config import (
AdvancedConfig,
MlpModelParams,
ModelParams,
VitModelParams,
)
from stamp.utils.config import StampConfig
from stamp.utils.seed import Seed

# Load YAML configuration
with open(args.config_file_path, "r") as config_yaml:
config = StampConfig.model_validate(yaml.safe_load(config_yaml))

# use default advanced config in case none is provided
if config.advanced_config is None:
config.advanced_config = AdvancedConfig(
model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()),
)
# Only build a default AdvancedConfig (with model-params) for commands
# that actually use it. Preprocess / encode / statistics / heatmaps
# never touch config.advanced_config, so don't pay the construction cost.
if args.command in {"train", "crossval"}:
if config.advanced_config is None:
config.advanced_config = AdvancedConfig(
model_params=ModelParams(vit=VitModelParams(), mlp=MlpModelParams()),
)

# Set global random seed
if config.advanced_config.seed is not None:
# Apply the global seed for any command that has one configured.
if config.advanced_config is not None and config.advanced_config.seed is not None:
Seed.set(config.advanced_config.seed)

match args.command:
Expand Down Expand Up @@ -153,6 +159,7 @@ def _run_cli(args: argparse.Namespace) -> None:
if config.training.task is None:
raise ValueError("task must be set in training configuration")

assert config.advanced_config is not None # guaranteed above for "train"
train_categorical_model_(
config=config.training, advanced=config.advanced_config
)
Expand Down Expand Up @@ -198,6 +205,7 @@ def _run_cli(args: argparse.Namespace) -> None:
f"{yaml.dump(config.crossval.model_dump(mode='json', exclude_none=True))}"
)

assert config.advanced_config is not None # guaranteed above for "crossval"
categorical_crossval_(
config=config.crossval,
advanced=config.advanced_config,
Expand Down
6 changes: 3 additions & 3 deletions src/stamp/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ preprocessing:
# Extractor to use for feature extractor. Possible options are "ctranspath",
# "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom",
# "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow",
# "virchow-full", "musk", "mstar", "plip", "ticon"
# "virchow-full", "musk", "mstar", "plip", "ticon", "red-dino", "keep"
# Some of them require requesting access to the respective authors beforehand.
extractor: "chief-ctranspath"

Expand Down Expand Up @@ -327,7 +327,7 @@ advanced_config:
max_lr: 1e-4
div_factor: 25.
# Select a model regardless of task
model_name: "vit" # or mlp, trans_mil, barspoon
model_name: "vit" # or mlp, linear, trans_mil, barspoon

model_params:
vit: # Vision Transformer
Expand Down Expand Up @@ -357,4 +357,4 @@ advanced_config:
num_encoder_layers: 2
num_decoder_layers: 2
dim_feedforward: 2048
positional_encoding: true
positional_encoding: true
6 changes: 4 additions & 2 deletions src/stamp/encoding/encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def encode_slides_(
if self.precision == torch.float16:
self.model.half()

for tile_feats_filename in (progress := tqdm(os.listdir(feat_dir))):
h5_files = sorted(f for f in os.listdir(feat_dir) if f.endswith(".h5"))
for tile_feats_filename in (progress := tqdm(h5_files)):
h5_path = os.path.join(feat_dir, tile_feats_filename)
slide_name: str = Path(tile_feats_filename).stem
progress.set_description(slide_name)
Expand Down Expand Up @@ -185,7 +186,8 @@ def _read_h5(
raise ValueError(f"File is not of type .h5: {os.path.basename(h5_path)}")
with h5py.File(h5_path, "r") as f:
feats_ds = cast(h5py.Dataset, f["feats"])
feats: Tensor = torch.tensor(feats_ds[:], dtype=self.precision)
# torch.from_numpy avoids a redundant data copy vs torch.tensor(array)
feats: Tensor = torch.from_numpy(feats_ds[()]).to(dtype=self.precision)
coords: CoordsInfo = get_coords(f)
extractor: str = f.attrs.get("extractor", "")
if extractor == "":
Expand Down
72 changes: 14 additions & 58 deletions src/stamp/encoding/encoder/gigapath.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ class Gigapath(Encoder):
def __init__(self) -> None:
try:
model = slide_encoder.create_model(
"hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536
"hf_hub:prov-gigapath/prov-gigapath",
"gigapath_slide_enc12l768d",
1536,
global_pool=True,
)
except AssertionError:
raise ModuleNotFoundError(
Expand All @@ -51,20 +54,9 @@ def _generate_slide_embedding(
if not coords:
raise ValueError("Tile coords are required for encoding")

# Calculate slide dimensions
slide_width = max(coords.coords_um[:, 0]) + coords.tile_size_um
slide_height = max(coords.coords_um[:, 1]) + coords.tile_size_um

# Normalize coordinates to a [0, 1000] grid
n_grid = 1000
norm_coords = self._convert_coords(
coords.coords_um, slide_width, slide_height, n_grid, current_x_offset=0
)
coords_px = coords.coords_um / coords.mpp
norm_coords = (
torch.tensor(norm_coords, dtype=torch.float32)
.unsqueeze(0)
.to(device)
.half()
torch.tensor(coords_px, dtype=torch.float32).unsqueeze(0).to(device).half()
)
feats = feats.unsqueeze(0).half().to(device)

Expand Down Expand Up @@ -119,8 +111,6 @@ def encode_patients_(

all_feats_list = []
all_coords_list = []
total_wsi_width = 0
max_wsi_height = 0
slides_mpp = SlideMPP(-1)

slide_info = []
Expand Down Expand Up @@ -151,31 +141,20 @@ def encode_patients_(
)

wsi_width = max(coords.coords_um[:, 0]) + coords.tile_size_um
wsi_height = max(coords.coords_um[:, 1]) + coords.tile_size_um

total_wsi_width += wsi_width # Sum the widths of all slides
max_wsi_height = max(max_wsi_height, wsi_height) # Track the max height

slide_info.append((wsi_width, wsi_height, feats, coords))
slide_info.append((wsi_width, feats, coords))

current_x_offset = 0

for wsi_width, wsi_height, feats, coords in slide_info:
norm_coords = self._convert_coords(
coords=coords.coords_um,
total_wsi_width=total_wsi_width,
max_wsi_height=max_wsi_height,
n_grid=1000,
current_x_offset=current_x_offset,
)
for wsi_width, feats, coords in slide_info:
offset_coords_um = coords.coords_um.copy()
offset_coords_um[:, 0] += current_x_offset

# Update x-coordinates by shifting them based on the current_x_offset
current_x_offset += (
wsi_width # Move the x_offset forward for the next slide
)
current_x_offset += wsi_width

coords_px = offset_coords_um / coords.mpp

norm_coords = (
torch.tensor(norm_coords, dtype=torch.float32)
torch.tensor(coords_px, dtype=torch.float32)
.unsqueeze(0)
.to(device)
.half()
Expand Down Expand Up @@ -211,26 +190,3 @@ def _generate_patient_embedding(
patient_embedding = torch.cat(patient_embedding, dim=0)

return patient_embedding.detach().squeeze().cpu().numpy()

def _convert_coords(
self,
coords,
total_wsi_width,
max_wsi_height,
n_grid,
current_x_offset,
) -> np.ndarray:
"""
Normalize the x and y coordinates relative to the total WSI width and max height, using the same grid [0, 1000].
Thanks Peter!
"""
# Normalize x-coordinates based on total WSI width (taking into account the current x offset)
normalized_x = (coords[:, 0] + current_x_offset) / total_wsi_width * n_grid

# Normalize y-coordinates based on the maximum WSI height
normalized_y = coords[:, 1] / max_wsi_height * n_grid

# Stack normalized x and y coordinates
converted_coords = np.stack([normalized_x, normalized_y], axis=-1)

return np.array(converted_coords, dtype=np.float32)
2 changes: 1 addition & 1 deletion src/stamp/heatmaps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def heatmaps_(

# TODO: Update version when a newer model logic breaks heatmaps.
stamp_version = str(getattr(model, "stamp_version", ""))
if Version(stamp_version) < Version("2.4.0"):
if Version(stamp_version) < Version("2.5.0"):
raise ValueError(
f"model has been built with stamp version {stamp_version} "
f"which is incompatible with the current version."
Expand Down
4 changes: 2 additions & 2 deletions src/stamp/modeling/crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def categorical_crossval_(
).to_csv(split_dir / "patient-preds.csv", index=False)
elif config.task == "regression":
if config.ground_truth_label is None:
raise RuntimeError("Grounf truth label is required for regression")
raise RuntimeError("Ground truth label is required for regression")
if isinstance(config.ground_truth_label, str):
_to_regression_prediction_df(
patient_to_ground_truth=cast(
Expand All @@ -353,7 +353,7 @@ def categorical_crossval_(
else:
if config.ground_truth_label is None:
raise RuntimeError(
"Grounf truth label is required for classification"
"Ground truth label is required for classification"
)
_to_prediction_df(
categories=categories_for_export,
Expand Down
Loading
Loading