Skip to content
11 changes: 7 additions & 4 deletions src/stamp/preprocessing/extractor/ticon.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,6 @@ def load_ticon(device: str = "cuda") -> nn.Module:
class HOptimusTICON(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
self.device = device

# ----------------------------
# Stage 1: H-OptimUS
Expand Down Expand Up @@ -689,7 +688,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: [B, 3, 224, 224] (CPU or CUDA)
"""
x = x.to(self.device, non_blocking=True)
# Respect the current module device (it may be moved after construction).
device = next(self.parameters()).device
x = x.to(device, non_blocking=True)

# H-Optimus_1
emb = self.tile_encoder(x) # [B, 1536]
Expand All @@ -700,7 +701,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
emb.size(0),
1,
2,
device=self.device,
device=device,
dtype=torch.float32,
)

Expand All @@ -713,7 +714,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return out.squeeze(1) # [B, 1536]


def ticon(device: str = "cuda") -> Extractor[nn.Module]:
def ticon(device: str | None = None) -> Extractor[nn.Module]:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = HOptimusTICON(torch.device(device))

transform = transforms.Compose(
Expand Down
Loading