GragasLab/nufftax

Pure JAX implementation of the Non-Uniform Fast Fourier Transform (NUFFT)


MRI reconstruction example

Why nufftax?

A JAX package for NUFFT already exists: jax-finufft. However, it wraps the C++ FINUFFT library via Foreign Function Interface (FFI), exposing it through custom XLA calls. This approach can lead to:

  • Kernel fusion issues on GPU — custom XLA calls act as optimization barriers, preventing XLA from fusing operations
  • CUDA version constraints — GPU support requires matching CUDA versions between JAX and the wrapped library

nufftax takes a different approach — pure JAX implementation:

  • Fully differentiable — gradients w.r.t. both values and sample locations
  • Pure JAX — works with jit, grad, vmap, jvp, vjp with no FFI barriers
  • GPU ready — runs on CPU/GPU without code changes, benefits from XLA fusion
  • Pallas GPU kernels — fused Triton spreading kernels with 5-75x speedups on A100/H100
  • All NUFFT types — Type 1, 2, 3 in 1D, 2D, 3D

JAX Transformation Support

Transform           jit   grad/vjp   jvp   vmap
Type 1 (1D/2D/3D)    ✓       ✓        ✓     ✓
Type 2 (1D/2D/3D)    ✓       ✓        ✓     ✓
Type 3 (1D/2D/3D)    ✓       ✓        ✓     ✓

Differentiable inputs:

  • Type 1: grad w.r.t. c (strengths) and x, y, z (coordinates)
  • Type 2: grad w.r.t. f (Fourier modes) and x, y, z (coordinates)
  • Type 3: grad w.r.t. c (strengths), x, y, z (source coordinates), and s, t, u (target frequencies)
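To make these gradient semantics concrete, here is a self-contained sketch in plain JAX of the naive direct sum that a Type 1 transform approximates. The `direct_nufft1d1` helper and the FINUFFT-style mode ordering `k = -n_modes//2 .. (n_modes-1)//2` are assumptions for illustration, not part of the nufftax API; `jax.grad` differentiates the sum with respect to both the complex strengths and the real sample locations:

```python
import jax
import jax.numpy as jnp

def direct_nufft1d1(x, c, n_modes):
    # Naive O(M*N) direct sum that the Type 1 NUFFT approximates:
    #   f_k = sum_j c_j * exp(i * k * x_j),  k = -n_modes//2 .. (n_modes-1)//2
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return jnp.exp(1j * k[:, None] * x[None, :]) @ c  # shape (n_modes,)

x = jnp.array([0.1, 0.7, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.5 - 0.3j])

def loss(x, c):
    # Real-valued scalar loss, so jax.grad works w.r.t. both real and complex inputs.
    return jnp.sum(jnp.abs(direct_nufft1d1(x, c, 8)) ** 2)

gx = jax.grad(loss, argnums=0)(x, c)  # gradient w.r.t. sample locations
gc = jax.grad(loss, argnums=1)(x, c)  # gradient w.r.t. complex strengths
```

The actual transforms compute the same quantities to the requested tolerance eps at FFT-like cost instead of the O(M·N) direct sum.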

GPU Acceleration

On GPU, nufftax automatically dispatches spreading and interpolation to fused Pallas (Triton) kernels when the problem is large enough. This avoids materializing O(M × nspread^d) intermediate tensors and uses atomic scatter-add for spreading.
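A minimal 1D sketch of what spreading does, using a Gaussian as a stand-in for the actual "exponential of semicircle" kernel and a hypothetical `spread_1d` helper (nufftax's internal names and defaults may differ). It materializes the (M, nspread) weight tensor explicitly — exactly the intermediate the fused Pallas kernel avoids — and then accumulates onto the grid with a scatter-add:

```python
import jax.numpy as jnp

def spread_1d(x, c, n_grid, nspread=4, tau=0.1):
    """Spread strengths c at locations x in [-pi, pi) onto a uniform grid.

    Builds the (M, nspread) tensor of per-point kernel weights explicitly
    (the O(M * nspread^d) intermediate), then accumulates with a
    scatter-add via jnp.ndarray.at[...].add.
    """
    h = 2 * jnp.pi / n_grid
    i0 = jnp.floor(x / h).astype(jnp.int32)              # nearest grid index below x
    offsets = jnp.arange(-(nspread // 2) + 1, nspread // 2 + 1)
    idx = (i0[:, None] + offsets[None, :]) % n_grid      # (M, nspread) wrapped indices
    dist = x[:, None] - (i0[:, None] + offsets[None, :]) * h
    w = jnp.exp(-dist**2 / (4 * tau))                    # Gaussian stand-in kernel
    grid = jnp.zeros(n_grid, dtype=c.dtype)
    return grid.at[idx.reshape(-1)].add((w * c[:, None]).reshape(-1))
```

A fused kernel instead computes each point's weights in registers and issues atomic adds directly, so the (M, nspread) tensor never hits memory.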

Operation   Backend     Speedup vs pure JAX
1D spread   A100        5–67x (M ≥ 100K)
1D spread   H100        4–75x (M ≥ 100K)
2D spread   A100/H100   2–3x (M ≥ 100K)

The dispatch is transparent — no code changes required. On CPU or for small problems, the pure JAX path is used.

Installation

CPU only:

uv pip install nufftax

With CUDA 12 GPU support:

uv pip install "nufftax[cuda12]"

Development install (from source):

git clone https://github.com/GragasLab/nufftax.git
cd nufftax
uv pip install -e ".[dev]"

This installs test dependencies (pytest, ruff, finufft for comparison testing, pre-commit).

Development install with CUDA 12:

uv pip install -e ".[dev,cuda12]"

With docs dependencies:

uv pip install -e ".[docs]"

Quick Example

import jax
import jax.numpy as jnp
from nufftax import nufft1d1

# Irregular sample locations in [-pi, pi)
x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0+0.5j, 0.3-0.2j, 0.8+0.1j, 0.2+0.4j, 0.5-0.3j])

# Compute Fourier modes
f = nufft1d1(x, c, n_modes=32, eps=1e-6)

# Differentiate through the transform
grad_c = jax.grad(lambda c: jnp.sum(jnp.abs(nufft1d1(x, c, n_modes=32)) ** 2))(c)
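The transforms also compose with vmap and jit. A self-contained sketch of the pattern, using a naive direct Type 1 sum (`direct_type1`, a hypothetical stand-in with the same call shape as nufft1d1) so the example runs without the package:

```python
import jax
import jax.numpy as jnp

def direct_type1(x, c, n_modes):
    # Naive O(M*N) Type 1 sum used as a stand-in for nufft1d1.
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return jnp.exp(1j * k[:, None] * x[None, :]) @ c

x = jnp.array([0.1, 0.7, 1.3])
C = jnp.stack([jnp.ones(3) + 0j, jnp.arange(3.0) + 0j])  # batch of 2 strength vectors

# vmap over the strengths axis, then jit the whole batched transform.
batched = jax.jit(jax.vmap(lambda c: direct_type1(x, c, 16)))
F = batched(C)  # shape (2, 16)
```

The same `jax.jit(jax.vmap(...))` composition applies unchanged to the real transforms, since they are pure JAX with no FFI barriers.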

Documentation

Read the full documentation →

License

MIT. Algorithm based on FINUFFT by the Flatiron Institute.

Citation

If you use nufftax in your research, please cite:

@software{nufftax,
  author = {Gragas and Oudoumanessah, Geoffroy and Iollo, Jacopo},
  title = {nufftax: Pure JAX implementation of the Non-Uniform Fast Fourier Transform},
  url = {https://github.com/GragasLab/nufftax},
  year = {2026}
}

@article{finufft,
  author = {Barnett, Alexander H. and Magland, Jeremy F. and af Klinteberg, Ludvig},
  title = {A parallel non-uniform fast Fourier transform library based on an ``exponential of semicircle'' kernel},
  journal = {SIAM J. Sci. Comput.},
  volume = {41},
  number = {5},
  pages = {C479--C504},
  year = {2019}
}
