From fed6f310e4e97dbc477e04b2cb2690746caaedb9 Mon Sep 17 00:00:00 2001 From: rohitrango Date: Sun, 15 Feb 2026 15:46:27 -0500 Subject: [PATCH 1/3] small tweaks for scale Signed-off-by: rohitrango --- fireants/registration/distributed/ring_sampler.py | 1 - fireants/registration/greedy.py | 9 +++++++-- fireants/registration/optimizers/adam.py | 13 +++++++++++++ fireants/registration/optimizers/sgd.py | 9 +++++++++ fireants/registration/syn.py | 7 ++++++- fireants/utils/imageutils.py | 3 +++ 6 files changed, 38 insertions(+), 4 deletions(-) diff --git a/fireants/registration/distributed/ring_sampler.py b/fireants/registration/distributed/ring_sampler.py index 0be0ecf..d3661b2 100644 --- a/fireants/registration/distributed/ring_sampler.py +++ b/fireants/registration/distributed/ring_sampler.py @@ -235,7 +235,6 @@ def distributed_grid_sampler_3d_bwd(grad_output, grad_image, grad_affine, grad_g # Compute gradients for our rank's portion grad_affine_buf = zeros_like_or_none(grad_affine) - # breakpoint() fused_grid_sampler_3d_backward( grad_output, grad_image, grad_affine_buf, grad_grid, diff --git a/fireants/registration/greedy.py b/fireants/registration/greedy.py index 594db4f..c118e0a 100644 --- a/fireants/registration/greedy.py +++ b/fireants/registration/greedy.py @@ -281,8 +281,13 @@ def optimize(self): if hasattr(self.loss_fn, 'set_current_scale_and_iterations'): self.loss_fn.set_current_scale_and_iterations(scale, iters) # resize images - size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size] - moving_size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in moving_size] + if scale > 1: + size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size] + moving_size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in moving_size] + else: + size_down = fixed_size + moving_size_down = moving_size + if self.blur and scale > 1: sigmas = 0.5 * torch.tensor( [sz / szdown for sz, szdown in zip(fixed_size, size_down)], diff --git 
a/fireants/registration/optimizers/adam.py b/fireants/registration/optimizers/adam.py index b538dea..ef61378 100644 --- a/fireants/registration/optimizers/adam.py +++ b/fireants/registration/optimizers/adam.py @@ -76,6 +76,7 @@ def __init__(self, warp, lr, smoothing_gaussians=None, grad_gaussians=None, freeform=False, + restrict_deformations=None, offload=False, # try offloading to CPU reset=False, # distributed params @@ -138,6 +139,13 @@ def __init__(self, warp, lr, self.padding_smoothing = 0 # get wrapper around smoothing for distributed / not distributed self.smoothing_wrapper = _get_smoothing_wrapper(self) + # gradient restriction (e.g. to restrict deformations along certain dimensions) + if restrict_deformations is None: + self.gradient_restriction = lambda x: x + else: + logger.info(f"Setting restriction to: {restrict_deformations}") + self._restrict_deformations = torch.as_tensor(restrict_deformations) + self.gradient_restriction = lambda x: x * self._restrict_deformations.to(x.device).to(x.dtype) def cleanup(self): # manually clean up @@ -213,6 +221,8 @@ def step(self): # add weight decay term if self.weight_decay > 0: grad.add_(self.warp.data, alpha=self.weight_decay) + # apply gradient restriction (e.g. 
restrict deformations along certain dims) + grad = self.gradient_restriction(grad) # compute moments self.step_t += 1 @@ -278,4 +288,7 @@ def step(self): # smooth result if asked for if self.smoothing_gaussians is not None: grad = self.smoothing_wrapper(grad, self.smoothing_gaussians, self.padding_smoothing) + + grad = self.gradient_restriction(grad) + self.warp.data.copy_(grad) diff --git a/fireants/registration/optimizers/sgd.py b/fireants/registration/optimizers/sgd.py index ff7a632..42f986a 100644 --- a/fireants/registration/optimizers/sgd.py +++ b/fireants/registration/optimizers/sgd.py @@ -32,6 +32,7 @@ def __init__(self, warp, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False, scaledown=False, multiply_jacobian=False, smoothing_gaussians=None, grad_gaussians=None, freeform=False, + restrict_deformations=None, # distributed params rank: int = 0, dim_to_shard: int = 0, @@ -85,6 +86,12 @@ def __init__(self, warp, lr, self.padding_smoothing = 0 # get wrapper around smoothing for distributed / not distributed self.smoothing_wrapper = _get_smoothing_wrapper(self) + # gradient restriction (e.g. to restrict deformations along certain dimensions) + if restrict_deformations is None: + self.gradient_restriction = lambda x: x + else: + self._restrict_deformations = torch.as_tensor(restrict_deformations) + self.gradient_restriction = lambda x: x * self._restrict_deformations.to(x.device).to(x.dtype) def cleanup(self): # manually clean up @@ -137,6 +144,8 @@ def step(self): # add weight decay term if self.weight_decay > 0: grad.add_(self.warp.data, alpha=self.weight_decay) + # apply gradient restriction (e.g. 
restrict deformations along certain dims) + grad = self.gradient_restriction(grad) # add momentum if self.momentum > 0: if self.velocity is None: diff --git a/fireants/registration/syn.py b/fireants/registration/syn.py index 81fcf8e..462a279 100644 --- a/fireants/registration/syn.py +++ b/fireants/registration/syn.py @@ -257,7 +257,11 @@ def optimize(self): if hasattr(self.loss_fn, 'set_current_scale_and_iterations'): self.loss_fn.set_current_scale_and_iterations(scale, iters) # resize images - size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size] + if scale > 1: + size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size] + else: + size_down = fixed_size + if self.blur and scale > 1: sigmas = 0.5 * torch.tensor( [sz / szdown for sz, szdown in zip(fixed_size, size_down)], @@ -291,6 +295,7 @@ def optimize(self): scale_factor = 1 else: scale_factor = np.prod(fixed_image_down.shape) + for i in pbar: # set zero grads self.fwd_warp.set_zero_grad() diff --git a/fireants/utils/imageutils.py b/fireants/utils/imageutils.py index 4d6ad03..0282e80 100644 --- a/fireants/utils/imageutils.py +++ b/fireants/utils/imageutils.py @@ -136,6 +136,9 @@ def downsample(image: torch.Tensor, size: List[int], mode: str, sigma: Optional[ if image.device.type == 'cpu': use_fft = False + if not all([x <= y for x, y in zip(size, image.shape[2:])]): + use_fft = False + if use_fft: return downsample_fft(image.to(torch.float32), size).to(image.dtype) From 9b7396eb4092aaaf5c4e53809b041061296b5f2b Mon Sep 17 00:00:00 2001 From: rohitrango Date: Sun, 15 Feb 2026 15:52:53 -0500 Subject: [PATCH 2/3] refactor downsample size Signed-off-by: rohitrango --- fireants/registration/affine.py | 6 ++--- fireants/registration/greedy.py | 10 +++----- fireants/registration/helpers.py | 41 ++++++++++++++++++++++++++++++++ fireants/registration/rigid.py | 6 ++--- fireants/registration/syn.py | 9 +++---- 5 files changed, 53 insertions(+), 19 deletions(-) create mode 100644 
fireants/registration/helpers.py diff --git a/fireants/registration/affine.py b/fireants/registration/affine.py index fde58a9..493f903 100644 --- a/fireants/registration/affine.py +++ b/fireants/registration/affine.py @@ -20,7 +20,7 @@ from fireants.io.image import BatchedImages, FakeBatchedImages from torch.optim import SGD, Adam from torch.nn import functional as F -from fireants.utils.globals import MIN_IMG_SIZE +from fireants.registration.helpers import downsample_size from tqdm import tqdm import numpy as np from fireants.losses.cc import gaussian_1d, separable_filtering @@ -239,8 +239,8 @@ def optimize(self): if hasattr(self.loss_fn, 'set_current_scale_and_iterations'): self.loss_fn.set_current_scale_and_iterations(scale, iters) # downsample fixed array and retrieve coords - size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size] - mov_size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in moving_arrays.shape[2:]] + size_down = downsample_size(fixed_size, scale) + mov_size_down = downsample_size(list(moving_arrays.shape[2:]), scale) # downsample if self.blur and scale > 1: sigmas = 0.5 * torch.tensor( diff --git a/fireants/registration/greedy.py b/fireants/registration/greedy.py index c118e0a..66e51b3 100644 --- a/fireants/registration/greedy.py +++ b/fireants/registration/greedy.py @@ -22,7 +22,7 @@ from tqdm import tqdm import SimpleITK as sitk -from fireants.utils.globals import MIN_IMG_SIZE +from fireants.registration.helpers import downsample_size from fireants.io.image import BatchedImages, FakeBatchedImages from fireants.registration.abstract import AbstractRegistration from fireants.registration.deformation.svf import StationaryVelocity @@ -281,12 +281,8 @@ def optimize(self): if hasattr(self.loss_fn, 'set_current_scale_and_iterations'): self.loss_fn.set_current_scale_and_iterations(scale, iters) # resize images - if scale > 1: - size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size] - moving_size_down = [max(int(s / scale), 
MIN_IMG_SIZE) for s in moving_size] - else: - size_down = fixed_size - moving_size_down = moving_size + size_down = downsample_size(fixed_size, scale) + moving_size_down = downsample_size(moving_size, scale) if self.blur and scale > 1: sigmas = 0.5 * torch.tensor( diff --git a/fireants/registration/helpers.py b/fireants/registration/helpers.py new file mode 100644 index 0000000..dd4247a --- /dev/null +++ b/fireants/registration/helpers.py @@ -0,0 +1,41 @@ +# Copyright (c) 2026 Rohit Jena. All rights reserved. +# +# This file is part of FireANTs, distributed under the terms of +# the FireANTs License version 1.0. A copy of the license can be found +# in the LICENSE file at the root of this repository. +# +# IMPORTANT: This code is part of FireANTs and its use, reproduction, or +# distribution must comply with the full license terms, including: +# - Maintaining all copyright notices and bibliography references +# - Using only approved (re)-distribution channels +# - Proper attribution in derivative works +# +# For full license details, see: https://github.com/rohitrango/FireANTs/blob/main/LICENSE + + +from typing import List, Sequence + +from fireants.utils.globals import MIN_IMG_SIZE + + +def downsample_size( + size: Sequence[int], + scale: float, + min_img_size: int = MIN_IMG_SIZE, +) -> List[int]: + """Compute spatial size after downsampling by scale, with a minimum size per dimension. + + When scale > 1, each dimension is divided by scale and floored, but not below min_img_size. + When scale <= 1, the original size is returned unchanged. + + Args: + size: Spatial dimensions (e.g. [H, W] or [D, H, W]). + scale: Downsampling factor (e.g. 2 for half resolution). + min_img_size: Minimum value per dimension when downsampling. + + Returns: + List of downsampled spatial dimensions. 
+ """ + if scale > 1: + return [max(int(s / scale), min_img_size) for s in size] + return list(size) diff --git a/fireants/registration/rigid.py b/fireants/registration/rigid.py index fc53283..96022a9 100644 --- a/fireants/registration/rigid.py +++ b/fireants/registration/rigid.py @@ -28,7 +28,7 @@ import numpy as np from fireants.losses.cc import gaussian_1d, separable_filtering from fireants.utils.imageutils import downsample -from fireants.utils.globals import MIN_IMG_SIZE +from fireants.registration.helpers import downsample_size from fireants.interpolator import fireants_interpolator import logging logger = logging.getLogger(__name__) @@ -294,8 +294,8 @@ def optimize(self): if hasattr(self.loss_fn, 'set_current_scale_and_iterations'): self.loss_fn.set_current_scale_and_iterations(scale, iters) # downsample fixed array and retrieve coords - size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size] - mov_size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in moving_arrays.shape[2:]] + size_down = downsample_size(fixed_size, scale) + mov_size_down = downsample_size(list(moving_arrays.shape[2:]), scale) # downsample if self.blur and scale > 1: sigmas = 0.5 * torch.tensor( diff --git a/fireants/registration/syn.py b/fireants/registration/syn.py index 462a279..aa5b7d7 100644 --- a/fireants/registration/syn.py +++ b/fireants/registration/syn.py @@ -21,7 +21,7 @@ from fireants.io.image import BatchedImages, FakeBatchedImages from torch.optim import SGD, Adam from torch.nn import functional as F -from fireants.utils.globals import MIN_IMG_SIZE +from fireants.registration.helpers import downsample_size from fireants.registration.abstract import AbstractRegistration from fireants.registration.deformation.compositive import CompositiveWarp from fireants.registration.deformation.svf import StationaryVelocity @@ -256,11 +256,8 @@ def optimize(self): # notify loss function of scale change if it supports it if hasattr(self.loss_fn, 
'set_current_scale_and_iterations'): self.loss_fn.set_current_scale_and_iterations(scale, iters) - # resize images - if scale > 1: - size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size] - else: - size_down = fixed_size + # resize images + size_down = downsample_size(fixed_size, scale) if self.blur and scale > 1: sigmas = 0.5 * torch.tensor( From 38866bc21ece7b2c051dbe5c9c896d4fb0b7a513 Mon Sep 17 00:00:00 2001 From: rohitrango Date: Wed, 18 Feb 2026 15:36:40 -0500 Subject: [PATCH 3/3] added partial warp addition Signed-off-by: rohitrango --- fireants/registration/abstract.py | 11 +- fireants/registration/deformablemixin.py | 25 +- fireants/registration/optimizers/adam.py | 5 +- fireants/registration/optimizers/sgd.py | 5 +- .../restricted_deformations.py | 261 ++++++++++++++++++ 5 files changed, 298 insertions(+), 9 deletions(-) create mode 100644 fireants/scripts/restricted_deformations/restricted_deformations.py diff --git a/fireants/registration/abstract.py b/fireants/registration/abstract.py index acc23e4..95a3ea3 100644 --- a/fireants/registration/abstract.py +++ b/fireants/registration/abstract.py @@ -331,7 +331,7 @@ def save_moved_images(self, moved_images: Union[BatchedImages, FakeBatchedImages moved_images_save.write_image(filenames) - def evaluate_inverse(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_images: Union[BatchedImages, torch.Tensor], shape=None, **kwargs): + def evaluate_inverse(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_images: Union[BatchedImages, torch.Tensor], shape=None, fixed_moved_coords=None, **kwargs): ''' Apply the inverse of the learned transformation to new images. 
This method is useful to analyse the effect of how the moving coordinates (fixed images) are transformed @@ -342,12 +342,13 @@ def evaluate_inverse(self, fixed_images: Union[BatchedImages, torch.Tensor], mov moving_images = FakeBatchedImages(moving_images, self.moving_images) fixed_arrays = moving_images() - fixed_moved_coords = self.get_inverse_warp_parameters(fixed_images, moving_images, shape=shape, **kwargs) + if fixed_moved_coords is None: + fixed_moved_coords = self.get_inverse_warp_parameters(fixed_images, moving_images, shape=shape, **kwargs) fixed_moved_image = fireants_interpolator(fixed_arrays, **fixed_moved_coords, mode='bilinear', align_corners=True) # [N, C, H, W, [D]] return fixed_moved_image - def evaluate(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_images: Union[BatchedImages, torch.Tensor], shape=None): + def evaluate(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_images: Union[BatchedImages, torch.Tensor], shape=None, moved_coords=None): '''Apply the learned transformation to new images. This method applies the registration transformation learned during optimization @@ -364,6 +365,7 @@ def evaluate(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_imag moving_images (BatchedImages): Moving images to be transformed shape (Optional[Tuple[int, ...]]): Optional output shape for the transformed image. If None, uses the shape of the fixed image. + moved_coords (Optional[Dict]): Optional dictionary of moved coordinates. If None, the coordinates are computed using the `get_warp_parameters` method. Returns: torch.Tensor: The transformed moving image in the space of the fixed image. 
@@ -382,7 +384,8 @@ def evaluate(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_imag moving_images = FakeBatchedImages(moving_images, self.moving_images) moving_arrays = moving_images() - moved_coords = self.get_warp_parameters(fixed_images, moving_images, shape=shape) + if moved_coords is None: + moved_coords = self.get_warp_parameters(fixed_images, moving_images, shape=shape) interpolate_mode = moving_images.get_interpolator_type() moved_image = fireants_interpolator(moving_arrays, **moved_coords, mode=interpolate_mode, align_corners=True) # [N, C, H, W, [D]] return moved_image diff --git a/fireants/registration/deformablemixin.py b/fireants/registration/deformablemixin.py index 36d7302..0e44d8b 100644 --- a/fireants/registration/deformablemixin.py +++ b/fireants/registration/deformablemixin.py @@ -30,7 +30,7 @@ from fireants.interpolator import fireants_interpolator from fireants.utils.globals import MIN_IMG_SIZE -from fireants.io.image import BatchedImages +from fireants.io.image import BatchedImages, FakeBatchedImages from fireants.registration.abstract import AbstractRegistration from fireants.registration.deformation.svf import StationaryVelocity from fireants.registration.deformation.compositive import CompositiveWarp @@ -55,6 +55,29 @@ class DeformableMixin: - get_warped_coordinates(): Method to get transformed coordinates """ + @torch.no_grad() + def get_partial_warped_parameters(reg, fixed_images: Union[BatchedImages, FakeBatchedImages], moving_images: Union[BatchedImages, FakeBatchedImages], fraction: float, shape=None): + """Return warp parameters with the grid scaled by a fraction in [0, 1]. + + Calls get_warp_parameters then multiplies the returned grid by fraction. + Same signature as get_warp_parameters with an additional fraction parameter. + + Args: + fixed_images: Fixed/reference images. + moving_images: Moving images. + fraction: Scale factor for the grid, must be in [0, 1]. 
+ shape: Optional output shape (passed to get_warp_parameters). + + Returns: + Dict with 'affine' and 'grid' keys; 'grid' is the original grid multiplied by fraction. + """ + if not (0 <= fraction <= 1): + raise ValueError("fraction must be in [0, 1], got %s" % fraction) + params = reg.get_warp_parameters(fixed_images, moving_images, shape=shape) + out = dict(params) + out["grid"] = params["grid"] * fraction + return out + @torch.no_grad() def save_as_ants_transforms(reg, filenames: Union[str, List[str]], save_inverse=False): """Save deformation fields in ANTs-compatible format. diff --git a/fireants/registration/optimizers/adam.py b/fireants/registration/optimizers/adam.py index ef61378..71c6618 100644 --- a/fireants/registration/optimizers/adam.py +++ b/fireants/registration/optimizers/adam.py @@ -221,8 +221,6 @@ def step(self): # add weight decay term if self.weight_decay > 0: grad.add_(self.warp.data, alpha=self.weight_decay) - # apply gradient restriction (e.g. restrict deformations along certain dims) - grad = self.gradient_restriction(grad) # compute moments self.step_t += 1 @@ -242,6 +240,9 @@ def step(self): # adam_update_fused(grad, self.exp_avg, self.exp_avg_sq, beta_correction1, beta_correction2, self.eps) self.adam_update_kernel(grad, self.exp_avg, self.exp_avg_sq, beta_correction1, beta_correction2, self.eps) + # apply gradient restriction (e.g. restrict deformations along certain dims) + grad = self.gradient_restriction(grad) + # we offload this to CPU if self.offload: # torch.cuda.synchronize() diff --git a/fireants/registration/optimizers/sgd.py b/fireants/registration/optimizers/sgd.py index 42f986a..ad45fb5 100644 --- a/fireants/registration/optimizers/sgd.py +++ b/fireants/registration/optimizers/sgd.py @@ -144,8 +144,6 @@ def step(self): # add weight decay term if self.weight_decay > 0: grad.add_(self.warp.data, alpha=self.weight_decay) - # apply gradient restriction (e.g. 
restrict deformations along certain dims) - grad = self.gradient_restriction(grad) # add momentum if self.momentum > 0: if self.velocity is None: @@ -160,6 +158,9 @@ def step(self): else: # grad = buf grad.copy_(buf) + # apply gradient restriction (e.g. restrict deformations along certain dims) + grad = self.gradient_restriction(grad) + ## renormalize and update warp (per pixel) gradmax = self.eps + grad.norm(p=2, dim=-1, keepdim=True) # gradmean = gradmax.flatten(1).mean(1) # [B,] diff --git a/fireants/scripts/restricted_deformations/restricted_deformations.py b/fireants/scripts/restricted_deformations/restricted_deformations.py new file mode 100644 index 0000000..720bcb9 --- /dev/null +++ b/fireants/scripts/restricted_deformations/restricted_deformations.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +""" +Run FireANTs deformable registration on reverse phase-encoding pairs from notepad/data. +Uses restrict_deformations based on folder name (AP/PA -> y only, RL/LR -> x only). +Saves fixed, moved, and moving slices with matplotlib. 
+"""
+from __future__ import annotations
+
+import os
+import sys
+from pathlib import Path
+
+# Ensure fireants is importable when run from anywhere in the repo.
+# This file lives at fireants/scripts/restricted_deformations/, so the
+# repo root is three levels up (parents[3]), not parents[2].
+_repo_root = Path(__file__).resolve().parents[3]
+if str(_repo_root) not in sys.path:
+    sys.path.insert(0, str(_repo_root))
+
+import torch
+import matplotlib.pyplot as plt
+from time import time
+
+from fireants.io import Image, BatchedImages, FakeBatchedImages
+from fireants.registration.affine import AffineRegistration
+from fireants.registration.greedy import GreedyRegistration
+
+
+# ---------------------------------------------------------------------------
+# Data and restriction mapping
+# ---------------------------------------------------------------------------
+
+SCRIPT_DIR = Path(__file__).resolve().parent
+DATA_DIR = SCRIPT_DIR / "data"
+OUTPUT_DIR = SCRIPT_DIR / "outputs"
+
+# Subdirectory names that are example/output (skip)
+SKIP_SUBDIRS = {"padded_ants_example"}
+
+# Phase-encoding direction -> which dimensions to allow deformation (1 = allow, 0 = restrict).
+# ANTs --restrict-deformation 0x1x0 means allow only y (AP/PA distortion). We use the same idea.
+# 3D: [x, y, z]. AP/PA = y; RL/LR = x.
+def get_restrict_deformations(subdir_name: str) -> list[float]: + name_upper = subdir_name.upper() + if "AP" in name_upper or "PA" in name_upper: + return [0.0, 1.0, 0.0] # allow deformation only in y + if "RL" in name_upper or "LR" in name_upper: + return [1.0, 0.0, 0.0] # allow deformation only in x + # default: allow all (no restriction) + return None + + +def discover_pairs(data_dir: Path): + """Yield (subdir_name, fixed_path, moving_path) for each pair directory.""" + if not data_dir.is_dir(): + return + for subdir in sorted(data_dir.iterdir()): + if not subdir.is_dir() or subdir.name in SKIP_SUBDIRS: + continue + nii = sorted(subdir.glob("*.nii.gz")) + if len(nii) < 2: + continue + # Use first as fixed, second as moving (arbitrary) + yield subdir.name, str(nii[0]), str(nii[1]) + + +def load_pair_ras(fixed_path: str, moving_path: str, device: str = "cuda"): + """Load fixed and moving images in RAS orientation; optional winsorize like ANTs.""" + fixed = Image.load_file( + fixed_path, + device=device, + orientation="RAS", + winsorize=True, + winsorize_percentile=(0.5, 99.5), + ) + moving = Image.load_file( + moving_path, + device=device, + orientation="RAS", + winsorize=True, + winsorize_percentile=(0.5, 99.5), + ) + batch_fixed = BatchedImages([fixed]) + batch_moving = BatchedImages([moving]) + return batch_fixed, batch_moving + + +# --------------------------------------------------------------------------- +# Registration (affine + greedy with restrict_deformations) +# --------------------------------------------------------------------------- + +def run_registration( + batch_fixed: BatchedImages, + batch_moving: BatchedImages, + restrict_deformations: list[float] | None, + scales: list[int] = (8, 4, 2, 1), + iterations: list[int] = (200, 200, 100, 50), + cc_kernel_size: int = 7, + smooth_grad_sigma: float = 1.0, + smooth_warp_sigma: float = 0.5, + optimizer: str = "adam", + optimizer_lr: float = 0.5, +): + """Run affine then greedy deformable registration. 
Optionally restrict deformation directions.""" + # Affine + # affine = AffineRegistration( + # list(scales), + # list(iterations), + # batch_fixed, + # batch_moving, + # optimizer=optimizer, + # optimizer_lr=3e-3, + # cc_kernel_size=cc_kernel_size, + # ) + # t0 = time() + # affine.optimize() + # if batch_fixed().is_cuda: + # torch.cuda.synchronize() + # print(f" Affine done in {time() - t0:.1f}s") + + # Greedy deformable with optional direction restriction + optimizer_params = {} + if restrict_deformations is not None: + optimizer_params["restrict_deformations"] = restrict_deformations + + reg = GreedyRegistration( + scales=list(scales), + iterations=list(iterations), + fixed_images=batch_fixed, + moving_images=batch_moving, + cc_kernel_size=cc_kernel_size, + deformation_type="compositive", + smooth_grad_sigma=smooth_grad_sigma, + smooth_warp_sigma=smooth_warp_sigma, + optimizer=optimizer, + optimizer_lr=optimizer_lr, + optimizer_params=optimizer_params, + ) + t0 = time() + reg.optimize() + if batch_fixed().is_cuda: + torch.cuda.synchronize() + print(f" Greedy done in {time() - t0:.1f}s") + + return reg + + +# --------------------------------------------------------------------------- +# Visualization +# --------------------------------------------------------------------------- + +def save_slice_figure( + batch_fixed: BatchedImages, + batch_moving: BatchedImages, + moved: torch.Tensor, + subdir_name: str, + output_dir: Path, + slice_idx: int | None = None, +): + """Save a matplotlib figure with fixed, moved, moving (and optional slice index for 3D).""" + fixed_np = batch_fixed()[0, 0].detach().cpu().numpy() + moving_np = batch_moving()[0, 0].detach().cpu().numpy() + moved_np = moved[0, 0].detach().cpu().numpy() + + ndim = fixed_np.ndim + if ndim == 3: + d, h, w = fixed_np.shape + if slice_idx is None: + slice_idx = d // 2 + fixed_sl = fixed_np[slice_idx, :, :] + moved_sl = moved_np[slice_idx, :, :] + moving_sl = moving_np[slice_idx, :, :] + else: + fixed_sl = 
fixed_np + moved_sl = moved_np + moving_sl = moving_np + + fig, axes = plt.subplots(1, 3, figsize=(14, 5)) + axes[0].imshow(fixed_sl, cmap="gray") + axes[0].set_title("Fixed") + axes[0].axis("off") + axes[0].invert_yaxis() + + axes[1].imshow(moved_sl, cmap="gray") + axes[1].set_title("Moved (registered)") + axes[1].axis("off") + axes[1].invert_yaxis() + + axes[2].imshow(moving_sl, cmap="gray") + axes[2].set_title("Moving") + axes[2].axis("off") + axes[2].invert_yaxis() + + fig.suptitle(f"FireANTs — {subdir_name}", fontsize=14) + plt.tight_layout() + output_dir.mkdir(parents=True, exist_ok=True) + out_path = output_dir / f"{subdir_name}.png" + plt.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved {out_path}") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + data_dir = DATA_DIR.resolve() + if not data_dir.exists(): + print(f"Data directory not found: {data_dir}") + return + + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + for subdir_name, fixed_path, moving_path in discover_pairs(data_dir): + + print(f"\n--- {subdir_name} ---") + print(f" Fixed: {Path(fixed_path).name}") + print(f" Moving: {Path(moving_path).name}") + + restrict = get_restrict_deformations(subdir_name) + print(f" restrict_deformations: {restrict}") + + batch_fixed, batch_moving = load_pair_ras(fixed_path, moving_path, device=device) + print(f" batch_fixed: {batch_fixed.shape}") + print(f" batch_moving: {batch_moving.shape}") + + reg = run_registration( + batch_fixed, + batch_moving, + restrict_deformations=restrict, + ) + + moved = reg.evaluate(batch_fixed, batch_moving) + save_slice_figure(batch_fixed, batch_moving, moved, subdir_name, OUTPUT_DIR) + + # Construct output naming + moved_img_path = OUTPUT_DIR / f"{subdir_name}_moved.nii.gz" + moved_img_path_partial = OUTPUT_DIR / 
f"{subdir_name}_moved_partial.nii.gz" + warp_path = OUTPUT_DIR / f"{subdir_name}_warp.nii.gz" + + # Construct partial warp + partial_grid_params = reg.get_partial_warped_parameters(batch_fixed, batch_moving, 0.5) + moved_partial = reg.evaluate(batch_fixed, batch_moving, moved_coords=partial_grid_params) + batch_moved_partial = FakeBatchedImages(moved_partial, batch_fixed) + batch_moved_partial.write_image(str(moved_img_path_partial)) + print(f" Saved moved image at {moved_img_path_partial}") + + # Save the moved image as NIfTI (use .save_nifti from BatchedImages) + batch_moved = FakeBatchedImages(moved, batch_fixed) + batch_moved.write_image(str(moved_img_path)) + print(f" Saved moved image at {moved_img_path}") + + # Save the deformation field in ANTs-compatible format (using DeformableMixin) + reg.save_as_ants_transforms([str(warp_path)]) + print(f" Saved deformation field at {warp_path}") + + print("\nDone.") + + +if __name__ == "__main__": + main()