Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions fireants/registration/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def save_moved_images(self, moved_images: Union[BatchedImages, FakeBatchedImages
moved_images_save.write_image(filenames)


def evaluate_inverse(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_images: Union[BatchedImages, torch.Tensor], shape=None, **kwargs):
def evaluate_inverse(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_images: Union[BatchedImages, torch.Tensor], shape=None, fixed_moved_coords=None, **kwargs):
''' Apply the inverse of the learned transformation to new images.

This method is useful to analyse the effect of how the moving coordinates (fixed images) are transformed
Expand All @@ -342,12 +342,13 @@ def evaluate_inverse(self, fixed_images: Union[BatchedImages, torch.Tensor], mov
moving_images = FakeBatchedImages(moving_images, self.moving_images)

fixed_arrays = moving_images()
fixed_moved_coords = self.get_inverse_warp_parameters(fixed_images, moving_images, shape=shape, **kwargs)
if fixed_moved_coords is None:
fixed_moved_coords = self.get_inverse_warp_parameters(fixed_images, moving_images, shape=shape, **kwargs)
fixed_moved_image = fireants_interpolator(fixed_arrays, **fixed_moved_coords, mode='bilinear', align_corners=True) # [N, C, H, W, [D]]
return fixed_moved_image


def evaluate(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_images: Union[BatchedImages, torch.Tensor], shape=None):
def evaluate(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_images: Union[BatchedImages, torch.Tensor], shape=None, moved_coords=None):
'''Apply the learned transformation to new images.

This method applies the registration transformation learned during optimization
Expand All @@ -364,6 +365,7 @@ def evaluate(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_imag
moving_images (BatchedImages): Moving images to be transformed
shape (Optional[Tuple[int, ...]]): Optional output shape for the transformed image.
If None, uses the shape of the fixed image.
moved_coords (Optional[Dict]): Optional dictionary of moved coordinates. If None, the coordinates are computed using the `get_warp_parameters` method.

Returns:
torch.Tensor: The transformed moving image in the space of the fixed image.
Expand All @@ -382,7 +384,8 @@ def evaluate(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_imag
moving_images = FakeBatchedImages(moving_images, self.moving_images)

moving_arrays = moving_images()
moved_coords = self.get_warp_parameters(fixed_images, moving_images, shape=shape)
if moved_coords is None:
moved_coords = self.get_warp_parameters(fixed_images, moving_images, shape=shape)
interpolate_mode = moving_images.get_interpolator_type()
moved_image = fireants_interpolator(moving_arrays, **moved_coords, mode=interpolate_mode, align_corners=True) # [N, C, H, W, [D]]
return moved_image
Expand Down
6 changes: 3 additions & 3 deletions fireants/registration/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from fireants.io.image import BatchedImages, FakeBatchedImages
from torch.optim import SGD, Adam
from torch.nn import functional as F
from fireants.utils.globals import MIN_IMG_SIZE
from fireants.registration.helpers import downsample_size
from tqdm import tqdm
import numpy as np
from fireants.losses.cc import gaussian_1d, separable_filtering
Expand Down Expand Up @@ -239,8 +239,8 @@ def optimize(self):
if hasattr(self.loss_fn, 'set_current_scale_and_iterations'):
self.loss_fn.set_current_scale_and_iterations(scale, iters)
# downsample fixed array and retrieve coords
size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size]
mov_size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in moving_arrays.shape[2:]]
size_down = downsample_size(fixed_size, scale)
mov_size_down = downsample_size(list(moving_arrays.shape[2:]), scale)
# downsample
if self.blur and scale > 1:
sigmas = 0.5 * torch.tensor(
Expand Down
25 changes: 24 additions & 1 deletion fireants/registration/deformablemixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from fireants.interpolator import fireants_interpolator
from fireants.utils.globals import MIN_IMG_SIZE
from fireants.io.image import BatchedImages
from fireants.io.image import BatchedImages, FakeBatchedImages
from fireants.registration.abstract import AbstractRegistration
from fireants.registration.deformation.svf import StationaryVelocity
from fireants.registration.deformation.compositive import CompositiveWarp
Expand All @@ -55,6 +55,29 @@ class DeformableMixin:
- get_warped_coordinates(): Method to get transformed coordinates
"""

@torch.no_grad()
def get_partial_warped_parameters(reg, fixed_images: Union[BatchedImages, FakeBatchedImages], moving_images: Union[BatchedImages, FakeBatchedImages], fraction: float, shape=None):
"""Return warp parameters with the grid scaled by a fraction in [0, 1].

Calls get_warp_parameters then multiplies the returned grid by fraction.
Same signature as get_warp_parameters with an additional fraction parameter.

Args:
fixed_images: Fixed/reference images.
moving_images: Moving images.
fraction: Scale factor for the grid, must be in [0, 1].
shape: Optional output shape (passed to get_warp_parameters).

Returns:
Dict with 'affine' and 'grid' keys; 'grid' is the original grid multiplied by fraction.
"""
if not (0 <= fraction <= 1):
raise ValueError("fraction must be in [0, 1], got %s" % fraction)
params = reg.get_warp_parameters(fixed_images, moving_images, shape=shape)
out = dict(params)
out["grid"] = params["grid"] * fraction
return out

@torch.no_grad()
def save_as_ants_transforms(reg, filenames: Union[str, List[str]], save_inverse=False):
"""Save deformation fields in ANTs-compatible format.
Expand Down
1 change: 0 additions & 1 deletion fireants/registration/distributed/ring_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ def distributed_grid_sampler_3d_bwd(grad_output, grad_image, grad_affine, grad_g
# Compute gradients for our rank's portion
grad_affine_buf = zeros_like_or_none(grad_affine)

# breakpoint()
fused_grid_sampler_3d_backward(
grad_output,
grad_image, grad_affine_buf, grad_grid,
Expand Down
7 changes: 4 additions & 3 deletions fireants/registration/greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tqdm import tqdm
import SimpleITK as sitk

from fireants.utils.globals import MIN_IMG_SIZE
from fireants.registration.helpers import downsample_size
from fireants.io.image import BatchedImages, FakeBatchedImages
from fireants.registration.abstract import AbstractRegistration
from fireants.registration.deformation.svf import StationaryVelocity
Expand Down Expand Up @@ -281,8 +281,9 @@ def optimize(self):
if hasattr(self.loss_fn, 'set_current_scale_and_iterations'):
self.loss_fn.set_current_scale_and_iterations(scale, iters)
# resize images
size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size]
moving_size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in moving_size]
size_down = downsample_size(fixed_size, scale)
moving_size_down = downsample_size(moving_size, scale)

if self.blur and scale > 1:
sigmas = 0.5 * torch.tensor(
[sz / szdown for sz, szdown in zip(fixed_size, size_down)],
Expand Down
41 changes: 41 additions & 0 deletions fireants/registration/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2026 Rohit Jena. All rights reserved.
#
# This file is part of FireANTs, distributed under the terms of
# the FireANTs License version 1.0. A copy of the license can be found
# in the LICENSE file at the root of this repository.
#
# IMPORTANT: This code is part of FireANTs and its use, reproduction, or
# distribution must comply with the full license terms, including:
# - Maintaining all copyright notices and bibliography references
# - Using only approved (re)-distribution channels
# - Proper attribution in derivative works
#
# For full license details, see: https://github.com/rohitrango/FireANTs/blob/main/LICENSE


from typing import List, Optional, Sequence

from fireants.utils.globals import MIN_IMG_SIZE


def downsample_size(
    size: Sequence[int],
    scale: float,
    min_img_size: Optional[int] = None,
) -> List[int]:
    """Compute spatial size after downsampling by scale, with a minimum size per dimension.

    When scale > 1, each dimension is divided by scale and floored, but not below
    the minimum size. When scale <= 1, the original size is returned unchanged
    (as a new list).

    Args:
        size: Spatial dimensions (e.g. [H, W] or [D, H, W]).
        scale: Downsampling factor (e.g. 2 for half resolution).
        min_img_size: Minimum value per dimension when downsampling.
            Defaults to ``MIN_IMG_SIZE``, resolved at call time rather than at
            definition time so runtime changes to the global are honored.

    Returns:
        List of downsampled spatial dimensions.
    """
    if min_img_size is None:
        # Late-bind the global: capturing it in the def-time default would
        # freeze the value at import and ignore later reconfiguration.
        min_img_size = MIN_IMG_SIZE
    if scale > 1:
        return [max(int(s / scale), min_img_size) for s in size]
    return list(size)
14 changes: 14 additions & 0 deletions fireants/registration/optimizers/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(self, warp, lr,
smoothing_gaussians=None,
grad_gaussians=None,
freeform=False,
restrict_deformations=None,
offload=False, # try offloading to CPU
reset=False,
# distributed params
Expand Down Expand Up @@ -138,6 +139,13 @@ def __init__(self, warp, lr,
self.padding_smoothing = 0
# get wrapper around smoothing for distributed / not distributed
self.smoothing_wrapper = _get_smoothing_wrapper(self)
# gradient restriction (e.g. to restrict deformations along certain dimensions)
if restrict_deformations is None:
self.gradient_restriction = lambda x: x
else:
logger.info(f"Setting restriction to: {restrict_deformations}")
self._restrict_deformations = torch.as_tensor(restrict_deformations)
self.gradient_restriction = lambda x: x * self._restrict_deformations.to(x.device).to(x.dtype)

def cleanup(self):
# manually clean up
Expand Down Expand Up @@ -232,6 +240,9 @@ def step(self):
# adam_update_fused(grad, self.exp_avg, self.exp_avg_sq, beta_correction1, beta_correction2, self.eps)
self.adam_update_kernel(grad, self.exp_avg, self.exp_avg_sq, beta_correction1, beta_correction2, self.eps)

# apply gradient restriction (e.g. restrict deformations along certain dims)
grad = self.gradient_restriction(grad)

# we offload this to CPU
if self.offload:
# torch.cuda.synchronize()
Expand Down Expand Up @@ -278,4 +289,7 @@ def step(self):
# smooth result if asked for
if self.smoothing_gaussians is not None:
grad = self.smoothing_wrapper(grad, self.smoothing_gaussians, self.padding_smoothing)

grad = self.gradient_restriction(grad)

self.warp.data.copy_(grad)
10 changes: 10 additions & 0 deletions fireants/registration/optimizers/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, warp, lr,
momentum=0, dampening=0, weight_decay=0, nesterov=False, scaledown=False, multiply_jacobian=False,
smoothing_gaussians=None, grad_gaussians=None,
freeform=False,
restrict_deformations=None,
# distributed params
rank: int = 0,
dim_to_shard: int = 0,
Expand Down Expand Up @@ -85,6 +86,12 @@ def __init__(self, warp, lr,
self.padding_smoothing = 0
# get wrapper around smoothing for distributed / not distributed
self.smoothing_wrapper = _get_smoothing_wrapper(self)
# gradient restriction (e.g. to restrict deformations along certain dimensions)
if restrict_deformations is None:
self.gradient_restriction = lambda x: x
else:
self._restrict_deformations = torch.as_tensor(restrict_deformations)
self.gradient_restriction = lambda x: x * self._restrict_deformations.to(x.device).to(x.dtype)

def cleanup(self):
# manually clean up
Expand Down Expand Up @@ -151,6 +158,9 @@ def step(self):
else:
# grad = buf
grad.copy_(buf)
# apply gradient restriction (e.g. restrict deformations along certain dims)
grad = self.gradient_restriction(grad)

## renormalize and update warp (per pixel)
gradmax = self.eps + grad.norm(p=2, dim=-1, keepdim=True)
# gradmean = gradmax.flatten(1).mean(1) # [B,]
Expand Down
6 changes: 3 additions & 3 deletions fireants/registration/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import numpy as np
from fireants.losses.cc import gaussian_1d, separable_filtering
from fireants.utils.imageutils import downsample
from fireants.utils.globals import MIN_IMG_SIZE
from fireants.registration.helpers import downsample_size
from fireants.interpolator import fireants_interpolator
import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -294,8 +294,8 @@ def optimize(self):
if hasattr(self.loss_fn, 'set_current_scale_and_iterations'):
self.loss_fn.set_current_scale_and_iterations(scale, iters)
# downsample fixed array and retrieve coords
size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size]
mov_size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in moving_arrays.shape[2:]]
size_down = downsample_size(fixed_size, scale)
mov_size_down = downsample_size(list(moving_arrays.shape[2:]), scale)
# downsample
if self.blur and scale > 1:
sigmas = 0.5 * torch.tensor(
Expand Down
8 changes: 5 additions & 3 deletions fireants/registration/syn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from fireants.io.image import BatchedImages, FakeBatchedImages
from torch.optim import SGD, Adam
from torch.nn import functional as F
from fireants.utils.globals import MIN_IMG_SIZE
from fireants.registration.helpers import downsample_size
from fireants.registration.abstract import AbstractRegistration
from fireants.registration.deformation.compositive import CompositiveWarp
from fireants.registration.deformation.svf import StationaryVelocity
Expand Down Expand Up @@ -256,8 +256,9 @@ def optimize(self):
# notify loss function of scale change if it supports it
if hasattr(self.loss_fn, 'set_current_scale_and_iterations'):
self.loss_fn.set_current_scale_and_iterations(scale, iters)
# resize images
size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size]
# resize images
size_down = downsample_size(fixed_size, scale)

if self.blur and scale > 1:
sigmas = 0.5 * torch.tensor(
[sz / szdown for sz, szdown in zip(fixed_size, size_down)],
Expand Down Expand Up @@ -291,6 +292,7 @@ def optimize(self):
scale_factor = 1
else:
scale_factor = np.prod(fixed_image_down.shape)

for i in pbar:
# set zero grads
self.fwd_warp.set_zero_grad()
Expand Down
Loading