diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3e61f9c4..eb834ee3 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -21,6 +21,7 @@ jobs: "uni", "uni2", "dino-bloom", + "red-dino", "gigapath", "h-optimus-0", "h-optimus-1", diff --git a/getting-started.md b/getting-started.md index 6d5bffec..2bb81ee6 100644 --- a/getting-started.md +++ b/getting-started.md @@ -59,6 +59,7 @@ Stamp currently supports the following feature extractors: - [MUSK][musk] - [PLIP][plip] - [TICON][ticon] + - [RedDino][reddino] As some of the above require you to request access to the model on huggingface, @@ -160,6 +161,7 @@ meaning ignored that it was ignored during feature extraction. [MADELEINE]: https://huggingface.co/MahmoodLab/madeleine [PRISM]: https://huggingface.co/paige-ai/Prism [TICON]: https://cvlab-stonybrook.github.io/TICON/ "TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning" +[reddino]: https://github.com/Snarci/RedDino "RedDino: A Foundation Model for Red Blood Cell Analysis" @@ -514,4 +516,4 @@ STAMP automatically adapts its **model architecture**, **loss function**, and ** **Regression** tasks only require `ground_truth_label`. **Survival analysis** tasks require `time_label` (follow-up time) and `status_label` (event indicator). -These requirements apply consistently across cross-validation, training, deployment, and statistics. \ No newline at end of file +These requirements apply consistently across cross-validation, training, deployment, and statistics. diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 8440560b..1c0965d2 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -4,7 +4,7 @@ preprocessing: # Extractor to use for feature extractor. 
Possible options are "ctranspath", # "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom", # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", - # "virchow-full", "musk", "mstar", "plip" + # "virchow-full", "musk", "mstar", "plip", "ticon", "red-dino" # Some of them require requesting access to the respective authors beforehand. extractor: "chief-ctranspath" diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index ab3ff0d2..3b747d67 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -177,6 +177,11 @@ def extract_( extractor = dino_bloom() + case ExtractorName.RED_DINO: + from stamp.preprocessing.extractor.reddino import red_dino + + extractor = red_dino() + case ExtractorName.VIRCHOW: from stamp.preprocessing.extractor.virchow import virchow diff --git a/src/stamp/preprocessing/config.py b/src/stamp/preprocessing/config.py index 5eca41dd..072253e4 100644 --- a/src/stamp/preprocessing/config.py +++ b/src/stamp/preprocessing/config.py @@ -19,6 +19,7 @@ class ExtractorName(StrEnum): UNI = "uni" UNI2 = "uni2" DINO_BLOOM = "dino-bloom" + RED_DINO = "red-dino" GIGAPATH = "gigapath" H_OPTIMUS_0 = "h-optimus-0" H_OPTIMUS_1 = "h-optimus-1" diff --git a/src/stamp/preprocessing/extractor/reddino.py b/src/stamp/preprocessing/extractor/reddino.py new file mode 100644 index 00000000..ef92c551 --- /dev/null +++ b/src/stamp/preprocessing/extractor/reddino.py @@ -0,0 +1,62 @@ +""" +Port from https://github.com/Snarci/RedDino +RedDino: A Foundation Model for Red Blood Cell Analysis +""" + +from typing import Callable, cast + +try: + import timm + import torch + from PIL import Image + from torchvision import transforms +except ModuleNotFoundError as e: + raise ModuleNotFoundError("red-dino dependencies not installed.") from e + +from stamp.preprocessing.config import ExtractorName +from stamp.preprocessing.extractor import Extractor + +__license__ = "MIT" + + +class 
RedDinoClsOnly(torch.nn.Module): + def __init__(self, model) -> None: + super().__init__() + self.model = model + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + out = self.model(batch) + if isinstance(out, tuple): + out = out[0] + # if model returns tokens, return class token + if getattr(out, "ndim", 0) >= 2 and out.shape[1] > 1: + return out[:, 0] + return out + + +def red_dino() -> Extractor[RedDinoClsOnly]: + """Extracts features from single image using RedDino encoder.""" + + model = timm.create_model( + "hf-hub:Snarcy/RedDino-large", + pretrained=True, + ) + + transform = cast( + Callable[[Image.Image], torch.Tensor], + transforms.Compose( + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ), + ) + + return Extractor( + model=RedDinoClsOnly(model), + transform=transform, + identifier=ExtractorName.RED_DINO, + ) diff --git a/src/stamp/preprocessing/extractor/uni2.py b/src/stamp/preprocessing/extractor/uni2.py index 459eff95..f187bf3c 100644 --- a/src/stamp/preprocessing/extractor/uni2.py +++ b/src/stamp/preprocessing/extractor/uni2.py @@ -1,8 +1,12 @@ +from typing import Callable, cast + try: import timm import torch - from timm.data import resolve_data_config # type: ignore + from PIL import Image + from timm.data.config import resolve_data_config from timm.data.transforms_factory import create_transform + from timm.layers.mlp import SwiGLUPacked except ModuleNotFoundError as e: raise ModuleNotFoundError( "uni2 dependencies not installed." 
@@ -13,7 +17,7 @@ from stamp.preprocessing.extractor import Extractor -def uni2() -> Extractor: +def uni2() -> Extractor[torch.nn.Module]: # pretrained=True needed to load UNI2-h weights (and download weights for the first time) timm_kwargs = { "img_size": 224, @@ -25,7 +29,7 @@ def uni2() -> Extractor: "mlp_ratio": 2.66667 * 2, "num_classes": 0, "no_embed_class": True, - "mlp_layer": timm.layers.SwiGLUPacked, + "mlp_layer": SwiGLUPacked, "act_layer": torch.nn.SiLU, "reg_tokens": 8, "dynamic_img_size": True, @@ -33,13 +37,14 @@ def uni2() -> Extractor: model = timm.create_model( "hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs ) - transform = create_transform( - **resolve_data_config(model.pretrained_cfg, model=model) + transform = cast( + Callable[[Image.Image], torch.Tensor], + create_transform(**resolve_data_config(model.pretrained_cfg, model=model)), ) model.eval() return Extractor( model=model, transform=transform, - identifier=ExtractorName.UNI2, # type: ignore + identifier=ExtractorName.UNI2, )