diff --git a/qwix/_src/core/ragged_dot_qt.py b/qwix/_src/core/ragged_dot_qt.py index 98c06f6..4192086 100644 --- a/qwix/_src/core/ragged_dot_qt.py +++ b/qwix/_src/core/ragged_dot_qt.py @@ -18,6 +18,7 @@ import functools import jax from jax import numpy as jnp +from qwix._src import interception from qwix._src.core import qarray from qwix._src.core import ragged_dot @@ -90,6 +91,7 @@ def _ragged_dot_general( return out.astype(result_type) +@interception.disable_interceptions def ragged_dot_qt_fwd( lhs: jax.Array, rhs: jax.Array,