diff --git a/src/stamp/preprocessing/extractor/ticon.py b/src/stamp/preprocessing/extractor/ticon.py
index ab7eb829..5780afc4 100644
--- a/src/stamp/preprocessing/extractor/ticon.py
+++ b/src/stamp/preprocessing/extractor/ticon.py
@@ -624,7 +624,6 @@ def load_ticon(device: str = "cuda") -> nn.Module:
 class HOptimusTICON(nn.Module):
     def __init__(self, device: torch.device):
         super().__init__()
-        self.device = device
 
         # ----------------------------
         # Stage 1: H-OptimUS
@@ -689,7 +688,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         x: [B, 3, 224, 224]  (CPU or CUDA)
         """
-        x = x.to(self.device, non_blocking=True)
+        # Respect the current module device (it may be moved after construction).
+        device = next(self.parameters()).device
+        x = x.to(device, non_blocking=True)
 
         # H-Optimus_1
         emb = self.tile_encoder(x)  # [B, 1536]
@@ -700,7 +701,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             emb.size(0),
             1,
             2,
-            device=self.device,
+            device=device,
             dtype=torch.float32,
         )
 
@@ -713,7 +714,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return out.squeeze(1)  # [B, 1536]
 
 
-def ticon(device: str = "cuda") -> Extractor[nn.Module]:
+def ticon(device: str | None = None) -> Extractor[nn.Module]:
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
     model = HOptimusTICON(torch.device(device))
     transform = transforms.Compose(