Pure JAX implementation of the Non-Uniform Fast Fourier Transform (NUFFT)
A JAX package for the NUFFT already exists: jax-finufft. However, it wraps the C++ FINUFFT library through a foreign function interface (FFI), exposing it to JAX via custom XLA calls. This approach can lead to:
- Kernel fusion issues on GPU — custom XLA calls act as optimization barriers, preventing XLA from fusing operations
- CUDA version matching — GPU support requires matching CUDA versions between JAX and the library
nufftax takes a different approach — a pure JAX implementation:
- Fully differentiable — gradients w.r.t. both values and sample locations
- Pure JAX — works with `jit`, `grad`, `vmap`, `jvp`, and `vjp`, with no FFI barriers
- GPU ready — runs on CPU/GPU without code changes and benefits from XLA fusion
- Pallas GPU kernels — fused Triton spreading kernels with 5-75x speedups on A100/H100
- All NUFFT types — Type 1, 2, 3 in 1D, 2D, 3D
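The `jit`/`vmap` composition above can be exercised end to end. A minimal sketch, using a direct-sum stand-in for the transform (plain `jax.numpy`, since the point here is the composition rather than the fast algorithm — `nufft1d1_direct` is not part of the nufftax API):

```python
import jax
import jax.numpy as jnp

def nufft1d1_direct(x, c, n_modes=16):
    # Direct-sum Type 1 stand-in: f[k] = sum_j c[j] * exp(1j * k * x[j])
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return jnp.exp(1j * k[:, None] * x[None, :]) @ c

x = jnp.array([0.1, 0.7, 1.3, 2.1])
batch_c = jnp.ones((5, 4), dtype=jnp.complex64)  # 5 independent strength sets

# jit and vmap compose freely: one compiled kernel over the whole batch,
# sharing the sample locations across batch entries
batched = jax.jit(jax.vmap(nufft1d1_direct, in_axes=(None, 0)))
f = batched(x, batch_c)  # shape (5, 16)
```

Because there are no custom XLA calls in the way, the same pattern applies to the real nufftax transforms.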
| Transform | `jit` | `grad`/`vjp` | `jvp` | `vmap` |
|---|---|---|---|---|
| Type 1 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
| Type 2 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
| Type 3 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
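For reference, the quantities these transforms compute can be written as direct sums. The sketch below (plain `jax.numpy`, not the nufftax API) evaluates Type 1 and its adjoint by brute force in O(MN) — useful for testing, but without the fast spreading/FFT machinery. Sign conventions follow FINUFFT's default `isign = +1` and may differ from nufftax's defaults:

```python
import jax.numpy as jnp

def nufft1d1_direct(x, c, n_modes):
    """Type 1 (nonuniform -> uniform) by direct summation:
    f[k] = sum_j c[j] * exp(1j * k * x[j]),  k = -n_modes//2 .. ceil(n_modes/2)-1.
    """
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return jnp.exp(1j * k[:, None] * x[None, :]) @ c

def nufft1d1_adjoint(x, f):
    """Adjoint of the Type 1 sum above (a Type 2 transform with the
    opposite sign): c[j] = sum_k f[k] * exp(-1j * k * x[j])."""
    n_modes = f.shape[0]
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return jnp.exp(-1j * k[None, :] * x[:, None]) @ f

# A single unit-strength point at x = 0 contributes 1 to every mode
f = nufft1d1_direct(jnp.array([0.0]), jnp.array([1.0 + 0j]), 8)
```

The fast algorithm produces the same values to within the requested tolerance `eps`, so a direct sum like this is a handy correctness oracle for small problems.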
Differentiable inputs:
- Type 1: `grad` w.r.t. `c` (strengths) and `x, y, z` (coordinates)
- Type 2: `grad` w.r.t. `f` (Fourier modes) and `x, y, z` (coordinates)
- Type 3: `grad` w.r.t. `c` (strengths), `x, y, z` (source coordinates), and `s, t, u` (target frequencies)
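Since the transform is a smooth function of the sample locations, coordinate gradients come for free from autodiff. A sketch using a direct O(MN) sum (plain `jax.numpy`, standing in for the nufftax fast transform, whose gradients behave the same way):

```python
import jax
import jax.numpy as jnp

def nufft1d1_direct(x, c, n_modes):
    # Direct-sum Type 1 reference: f[k] = sum_j c[j] * exp(1j * k * x[j])
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return jnp.exp(1j * k[:, None] * x[None, :]) @ c

def power(x, c):
    # Real-valued objective, so jax.grad applies directly
    return jnp.sum(jnp.abs(nufft1d1_direct(x, c, n_modes=16)) ** 2)

x = jnp.array([0.1, 0.7, 1.3])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j])
gx = jax.grad(power, argnums=0)(x, c)  # gradient w.r.t. sample locations
gc = jax.grad(power, argnums=1)(x, c)  # gradient w.r.t. complex strengths
```

Differentiability w.r.t. locations is what enables applications such as optimizing sampling trajectories (e.g. in MRI) by gradient descent.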
On GPU, nufftax automatically dispatches spreading and interpolation to fused Pallas (Triton) kernels when the problem is large enough. This avoids materializing O(M × nspread^d) intermediate tensors and uses atomic scatter-add for spreading.
| Operation | Backend | Speedup vs pure JAX |
|---|---|---|
| 1D spread | A100 | 5–67x (M ≥ 100K) |
| 1D spread | H100 | 4–75x (M ≥ 100K) |
| 2D spread | A100/H100 | 2–3x (M ≥ 100K) |
The dispatch is transparent — no code changes required. On CPU or for small problems, the pure JAX path is used.
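To illustrate what the spreading step does (the part the Pallas kernels fuse), here is a simplified pure-JAX sketch: each nonuniform point deposits its strength onto a handful of nearby fine-grid cells through a window function, via scatter-add. The Gaussian window and the `spread_1d` function below are illustrative placeholders, not the nufftax internals — FINUFFT-style codes use the "exponential of semicircle" kernel:

```python
import jax.numpy as jnp

def spread_1d(x, c, n_grid, nspread=8, sigma=1.0):
    # Grid spacing for an n_grid-point periodic grid on [-pi, pi)
    h = 2 * jnp.pi / n_grid
    # Index of the grid cell just left of (or at) each point
    i0 = jnp.floor((x + jnp.pi) / h).astype(jnp.int32)
    # The cells surrounding each point, wrapped periodically
    offsets = jnp.arange(-(nspread // 2), nspread // 2 + 1)
    idx = (i0[:, None] + offsets[None, :]) % n_grid
    # Signed distance from each point to each of those cells (unwrapped)
    d = x[:, None] - (-jnp.pi + (i0[:, None] + offsets[None, :]) * h)
    # Illustrative Gaussian window (real codes use exp-of-semicircle)
    w = jnp.exp(-((d / (sigma * h)) ** 2))
    # Scatter-add: lowers to atomic adds on GPU backends
    return jnp.zeros(n_grid, dtype=c.dtype).at[idx].add(c[:, None] * w)

grid = spread_1d(jnp.array([0.0]), jnp.array([1.0 + 0j]), n_grid=16)
```

The fused kernel avoids materializing the (M, nspread) weight tensor that this naive version builds, which is where the speedups in the table above come from.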
CPU only:

```bash
uv pip install nufftax
```

With CUDA 12 GPU support:

```bash
uv pip install "nufftax[cuda12]"
```

Development install (from source):

```bash
git clone https://github.com/GragasLab/nufftax.git
cd nufftax
uv pip install -e ".[dev]"
```

This installs the test dependencies (pytest, ruff, finufft for comparison testing, pre-commit).

Development install with CUDA 12:

```bash
uv pip install -e ".[dev,cuda12]"
```

With docs dependencies:

```bash
uv pip install -e ".[docs]"
```

Minimal example:

```python
import jax
import jax.numpy as jnp
from nufftax import nufft1d1

# Irregular sample locations in [-pi, pi)
x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])

# Compute Fourier modes
f = nufft1d1(x, c, n_modes=32, eps=1e-6)

# Differentiate through the transform
grad_c = jax.grad(lambda c: jnp.sum(jnp.abs(nufft1d1(x, c, n_modes=32)) ** 2))(c)
```

- Quickstart — get running in 5 minutes
- Concepts — understand the mathematics
- Tutorials — MRI reconstruction, spectral analysis, optimization
- API Reference — complete function reference
MIT licensed. The algorithm is based on FINUFFT by the Flatiron Institute.
If you use nufftax in your research, please cite:
```bibtex
@software{nufftax,
  author = {Gragas and Oudoumanessah, Geoffroy and Iollo, Jacopo},
  title  = {nufftax: Pure JAX implementation of the Non-Uniform Fast Fourier Transform},
  url    = {https://github.com/GragasLab/nufftax},
  year   = {2026}
}
```

```bibtex
@article{finufft,
  author  = {Barnett, Alexander H. and Magland, Jeremy F. and af Klinteberg, Ludvig},
  title   = {A parallel non-uniform fast Fourier transform library based on an ``exponential of semicircle'' kernel},
  journal = {SIAM J. Sci. Comput.},
  volume  = {41},
  number  = {5},
  pages   = {C479--C504},
  year    = {2019}
}
```
