Skip to content

Add transforms: grad, value_and_grad, vjp, jvp, vmap, compile, checkpoint#590

Merged
sydneyrenee merged 1 commit intomainfrom
feature/transforms
Mar 25, 2026
Merged

Add transforms: grad, value_and_grad, vjp, jvp, vmap, compile, checkpoint#590
sydneyrenee merged 1 commit intomainfrom
feature/transforms

Conversation

@sydneyrenee
Copy link
Copy Markdown
Contributor

Summary

Implements all MLX function transforms, crossing the N-API boundary with JS function wrapping.

Function transformers (return new callable functions):

  • grad(fn, argnums?) - compute gradients of scalar-valued functions
  • value_and_grad(fn, argnums?) - returns both loss value and gradients
  • vmap(fn, in_axes?, out_axes?) - vectorize over batch dimensions
  • compile(fn, shapeless?) - JIT compile computation graphs
  • checkpoint(fn) - recompute intermediates during backprop

Immediate transforms (compute and return results):

  • vjp(fn, primals, cotangents) - vector-Jacobian product
  • jvp(fn, primals, tangents) - Jacobian-vector product

Global flags:

  • enable_compile() / disable_compile()

Architecture

JS functions are wrapped as C++ std::function via persistent Napi references, passed to MLX transform APIs (grad, compile, etc.), and the resulting transformed functions are wrapped back as callable JS functions. Closure data is stored with GC-safe cleanup via Napi::External attached to the function object.

Issues addressed

Test plan

  • 9 new tests in batch6-transforms.test.ts, all passing
  • Full suite: 239 passing, 9 pending (pre-existing)
  • grad of x^2 returns 2x (single and multi-element)
  • value_and_grad returns both loss and gradient
  • Multi-arg grad with argnums selection
  • vjp and jvp produce correct derivatives
  • compile and checkpoint wrap functions correctly

…oint

Function-wrapping transforms that cross the N-API boundary: JS functions
are wrapped as C++ std::function via persistent references, passed to
MLX transform APIs, and the resulting transformed functions are wrapped
back as callable JS functions with GC-safe closure data.

grad/value_and_grad: compute gradients of scalar-valued functions
vjp/jvp: vector-Jacobian and Jacobian-vector products (immediate)
vmap: vectorize a function over batch dimensions
compile: JIT compile computation graphs
checkpoint: recompute intermediates during backprop
enable_compile/disable_compile: global compilation flags
@sydneyrenee sydneyrenee merged commit f0e507f into main Mar 25, 2026
1 check failed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

1 participant