Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
"uni",
"uni2",
"dino-bloom",
"red_dino",
"gigapath",
"h-optimus-0",
"h-optimus-1",
Expand Down
4 changes: 3 additions & 1 deletion getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Stamp currently supports the following feature extractors:
- [MUSK][musk]
- [PLIP][plip]
- [TICON][ticon]
- [RedDino][reddino]


As some of the above require you to request access to the model on huggingface,
Expand Down Expand Up @@ -160,6 +161,7 @@ meaning ignored that it was ignored during feature extraction.
[MADELEINE]: https://huggingface.co/MahmoodLab/madeleine
[PRISM]: https://huggingface.co/paige-ai/Prism
[TICON]: https://cvlab-stonybrook.github.io/TICON/ "TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning"
[reddino]: https://github.com/Snarci/RedDino "RedDino: A Foundation Model for Red Blood Cell Analysis"



Expand Down Expand Up @@ -514,4 +516,4 @@ STAMP automatically adapts its **model architecture**, **loss function**, and **

**Regression** tasks only require `ground_truth_label`.
**Survival analysis** tasks require `time_label` (follow-up time) and `status_label` (event indicator).
These requirements apply consistently across cross-validation, training, deployment, and statistics.
These requirements apply consistently across cross-validation, training, deployment, and statistics.
2 changes: 1 addition & 1 deletion src/stamp/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ preprocessing:
# Extractor to use for feature extractor. Possible options are "ctranspath",
# "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom",
# "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow",
# "virchow-full", "musk", "mstar", "plip"
# "virchow-full", "musk", "mstar", "plip", "ticon", "red-dino"
# Some of them require requesting access to the respective authors beforehand.
extractor: "chief-ctranspath"

Expand Down
5 changes: 5 additions & 0 deletions src/stamp/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ def extract_(

extractor = dino_bloom()

case ExtractorName.RED_DINO:
from stamp.preprocessing.extractor.reddino import red_dino

extractor = red_dino()

case ExtractorName.VIRCHOW:
from stamp.preprocessing.extractor.virchow import virchow

Expand Down
1 change: 1 addition & 0 deletions src/stamp/preprocessing/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ExtractorName(StrEnum):
UNI = "uni"
UNI2 = "uni2"
DINO_BLOOM = "dino-bloom"
RED_DINO = "red-dino"
GIGAPATH = "gigapath"
H_OPTIMUS_0 = "h-optimus-0"
H_OPTIMUS_1 = "h-optimus-1"
Expand Down
62 changes: 62 additions & 0 deletions src/stamp/preprocessing/extractor/reddino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Port from https://github.com/Snarci/RedDino
RedDino: A Foundation Model for Red Blood Cell Analysis
"""

from typing import Callable, cast

try:
import timm
import torch
from PIL import Image
from torchvision import transforms
except ModuleNotFoundError as e:
raise ModuleNotFoundError("red-dino dependencies not installed.") from e

from stamp.preprocessing.config import ExtractorName
from stamp.preprocessing.extractor import Extractor

__license__ = "MIT"


class RedDinoClsOnly(torch.nn.Module):
def __init__(self, model) -> None:
super().__init__()
self.model = model

def forward(self, batch: torch.Tensor) -> torch.Tensor:
out = self.model(batch)
if isinstance(out, tuple):
out = out[0]
# if model returns tokens, return class token
if getattr(out, "ndim", 0) >= 2 and out.shape[1] > 1:
return out[:, 0]
return out


def red_dino() -> Extractor[RedDinoClsOnly]:
"""Extracts features from single image using RedDino encoder."""

model = timm.create_model(
"hf-hub:Snarcy/RedDino-large",
pretrained=True,
)

transform = cast(
Callable[[Image.Image], torch.Tensor],
transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
),
)

return Extractor(
model=RedDinoClsOnly(model),
transform=transform,
identifier=ExtractorName.RED_DINO,
)
17 changes: 11 additions & 6 deletions src/stamp/preprocessing/extractor/uni2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Callable, cast

try:
import timm
import torch
from timm.data import resolve_data_config # type: ignore
from PIL import Image
from timm.data.config import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers.mlp import SwiGLUPacked
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"uni2 dependencies not installed."
Expand All @@ -13,7 +17,7 @@
from stamp.preprocessing.extractor import Extractor


def uni2() -> Extractor:
def uni2() -> Extractor[torch.nn.Module]:
# pretrained=True needed to load UNI2-h weights (and download weights for the first time)
timm_kwargs = {
"img_size": 224,
Expand All @@ -25,21 +29,22 @@ def uni2() -> Extractor:
"mlp_ratio": 2.66667 * 2,
"num_classes": 0,
"no_embed_class": True,
"mlp_layer": timm.layers.SwiGLUPacked,
"mlp_layer": SwiGLUPacked,
"act_layer": torch.nn.SiLU,
"reg_tokens": 8,
"dynamic_img_size": True,
}
model = timm.create_model(
"hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs
)
transform = create_transform(
**resolve_data_config(model.pretrained_cfg, model=model)
transform = cast(
Callable[[Image.Image], torch.Tensor],
create_transform(**resolve_data_config(model.pretrained_cfg, model=model)),
)
model.eval()

return Extractor(
model=model,
transform=transform,
identifier=ExtractorName.UNI2, # type: ignore
identifier=ExtractorName.UNI2,
)
Loading