Skip to content

Expression Support#215

Open
hfytr wants to merge 16 commits intoexanauts:mainfrom
hfytr:optimized
Open

Expression Support#215
hfytr wants to merge 16 commits intoexanauts:mainfrom
hfytr:optimized

Conversation

@hfytr
Copy link
Copy Markdown
Collaborator

@hfytr hfytr commented Feb 5, 2026

This PR changes expressions to no longer re-evaluate on every reference, and instead pre-computes once for every call to jac_coord / hess_coord / etc.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Feb 5, 2026

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/ext/ExaModelsMOI.jl b/ext/ExaModelsMOI.jl
index 174b8f4..cb10a32 100644
--- a/ext/ExaModelsMOI.jl
+++ b/ext/ExaModelsMOI.jl
@@ -315,8 +315,8 @@ function exafy_con(
         set = MOI.get(moim, MOI.ConstraintSet(), ci)
         con_to_idx[ci] = offset + i
         start = if MOI.supports(
-            moim, MOI.ConstraintPrimalStart(), typeof(ci)
-        )
+                moim, MOI.ConstraintPrimalStart(), typeof(ci)
+            )
             MOI.get(moim, MOI.ConstraintPrimalStart(), ci)
         else
             nothing
diff --git a/src/gradient.jl b/src/gradient.jl
index 03986e0..e124a56 100644
--- a/src/gradient.jl
+++ b/src/gradient.jl
@@ -8,25 +8,25 @@ Performs dense gradient evaluation via the reverse pass on the computation (sub)
 - `y`: result vector
 - `adj`: adjoint propagated up to the current node
 """
-@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNull}
+@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D <: AdjointNull}
     nothing
 end
-@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNode1}
+@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D <: AdjointNode1}
     drpass(e, e_starts, e_cnts, d.inner, y, adj * d.y)
     nothing
 end
-@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNode2}
+@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D <: AdjointNode2}
     drpass(e, e_starts, e_cnts, d.inner1, y, adj * d.y1)
     drpass(e, e_starts, e_cnts, d.inner2, y, adj * d.y2)
     nothing
 end
-@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNodeVar}
+@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D <: AdjointNodeVar}
     @inbounds y[d.i] += adj
     nothing
 end
-@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNodeExpr}
+@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D <: AdjointNodeExpr}
     y[d.i] += e[e_starts[d.i][2]]
-    nothing
+    return nothing
 end
 
 """
diff --git a/src/graph.jl b/src/graph.jl
index d19f56e..8b469b9 100644
--- a/src/graph.jl
+++ b/src/graph.jl
@@ -134,7 +134,7 @@ struct Identity end
 @inline (v::Var{I})(i, x, θ) where {I<:AbstractNode} = @inbounds x[v.i(i, x, θ)]
 @inline (v::Var{I})(i, x, θ) where {I} = @inbounds x[v.i]
 
-@inline (e::Exp{I})(i, x, θ) where {I<:AbstractNode} = @inbounds x[e.i(i, x, θ)]
+@inline (e::Exp{I})(i, x, θ) where {I <: AbstractNode} = @inbounds x[e.i(i, x, θ)]
 @inline (e::Exp{I})(i, x, θ) where {I} = @inbounds x[e.i]
 
 @inline (v::ParameterNode{I})(i, x, θ) where {I<:AbstractNode} = @inbounds θ[v.i(i, x, θ)]
@@ -200,7 +200,7 @@ struct AdjointNodeVar{I,T} <: AbstractAdjointNode
     x::T
 end
 
-struct AdjointNodeExpr{I,T} <: AbstractAdjointNode
+struct AdjointNodeExpr{I, T} <: AbstractAdjointNode
     i::I
     x::T
 end
@@ -213,7 +213,7 @@ A source of `AdjointNode`. `adjoint_node_source[i]` returns an `AdjointNodeVar`
 # Fields:
 - `inner::VT`: variable vector
 """
-struct AdjointNodeSource{VT,OE}
+struct AdjointNodeSource{VT, OE}
     inner::VT
     offset_exps::OE
 end
@@ -223,23 +223,23 @@ end
 @inline AdjointNode2(f::F, x::T, y1, y2, inner1::I1, inner2::I2) where {F,T,I1,I2} =
     AdjointNode2{F,T,I1,I2}(x, y1, y2, inner1, inner2)
 
-@inline function Base.getindex(x::I, i) where {I<:AdjointNodeSource{Nothing,Nothing}}
+@inline function Base.getindex(x::I, i) where {I <: AdjointNodeSource{Nothing, Nothing}}
     (i, isexp, theta) = i
-    @inbounds isexp ? AdjointNodeExpr(i, NaN) : AdjointNodeVar(i, NaN)
+    return @inbounds isexp ? AdjointNodeExpr(i, NaN) : AdjointNodeVar(i, NaN)
 end
-@inline function Base.getindex(x::I, i) where {I<:AdjointNodeSource{Nothing}}
+@inline function Base.getindex(x::I, i) where {I <: AdjointNodeSource{Nothing}}
     (i, isexp, theta) = i
-    if isexp
+    return if isexp
         dump(i)
-        offset = typeof(i) <: Node2{typeof(+),T,Int} where T ? i.inner2 : 0
+        offset = typeof(i) <: Node2{typeof(+), T, Int} where {T} ? i.inner2 : 0
         x.offset_exps[offset].f.f(Identity(), x, theta, i)
     else
         AdjointNodeVar(i, NaN)
     end
 end
-@inline function Base.getindex(x::I, i) where {I<:AdjointNodeSource}
+@inline function Base.getindex(x::I, i) where {I <: AdjointNodeSource}
     (i, isexp, theta) = i
-    @inbounds isexp ? AdjointNodeExpr(i, x.inner[i]) : AdjointNodeVar(i, x.inner[i])
+    return @inbounds isexp ? AdjointNodeExpr(i, x.inner[i]) : AdjointNodeVar(i, x.inner[i])
 end
 
 """
@@ -299,7 +299,7 @@ struct SecondAdjointNodeVar{I,T} <: AbstractSecondAdjointNode
     x::T
 end
 
-struct SecondAdjointNodeExpr{I,T} <: AbstractSecondAdjointNode
+struct SecondAdjointNodeExpr{I, T} <: AbstractSecondAdjointNode
     i::I
     x::T
 end
@@ -313,7 +313,7 @@ A source of `AdjointNode`. `adjoint_node_source[i]` returns an `AdjointNodeVar`
 - `inner::VT`: variable vector
 - 'isexp::VTI': expression vector
 """
-struct SecondAdjointNodeSource{VT,OE}
+struct SecondAdjointNodeSource{VT, OE}
     inner::VT
     offset_exps::OE
 end
@@ -333,22 +333,22 @@ end
 ) where {F,T,I1,I2} =
     SecondAdjointNode2{F,T,I1,I2}(x, y1, y2, h11, h12, h22, inner1, inner2)
 
-@inline function Base.getindex(x::I, i) where {I<:SecondAdjointNodeSource{Nothing,Nothing}}
+@inline function Base.getindex(x::I, i) where {I <: SecondAdjointNodeSource{Nothing, Nothing}}
     (i, isexp, theta) = i
-    @inbounds isexp ? SecondAdjointNodeExpr(i, NaN) : SecondAdjointNodeVar(i, NaN)
+    return @inbounds isexp ? SecondAdjointNodeExpr(i, NaN) : SecondAdjointNodeVar(i, NaN)
 end
-@inline function Base.getindex(x::I, i) where {I<:SecondAdjointNodeSource{Nothing}}
+@inline function Base.getindex(x::I, i) where {I <: SecondAdjointNodeSource{Nothing}}
     (i, isexp, theta) = i
-    if isexp
-        offset = typeof(i) <: Node2{typeof(+),T,Int} where T ? i.inner2 : 0
+    return if isexp
+        offset = typeof(i) <: Node2{typeof(+), T, Int} where {T} ? i.inner2 : 0
         x.offset_exps[offset].f.f(Identity(), x, theta, i)
     else
         SecondAdjointNodeVar(i, NaN)
     end
 end
-@inline function Base.getindex(x::I, i) where {I<:SecondAdjointNodeSource}
+@inline function Base.getindex(x::I, i) where {I <: SecondAdjointNodeSource}
     (i, isexp, theta) = i
-    @inbounds isexp ? SecondAdjointNodeExpr(i, x.inner[i]) : SecondAdjointNodeVar(i, x.inner[i])
+    return @inbounds isexp ? SecondAdjointNodeExpr(i, x.inner[i]) : SecondAdjointNodeVar(i, x.inner[i])
 end
 
 @inline (v::Null{Nothing})(i, x::V, θ) where {T,V<:AbstractVector{T}} = zero(T)
@@ -356,29 +356,29 @@ end
 @inline (v::Null{N})(i, x::AdjointNodeSource{T}, θ) where {N,T} = AdjointNull()
 @inline (v::Null{N})(i, x::SecondAdjointNodeSource{T}, θ) where {N,T} = SecondAdjointNull()
 
-const NodeSource = Union{AdjointNodeSource,SecondAdjointNodeSource}
+const NodeSource = Union{AdjointNodeSource, SecondAdjointNodeSource}
 
-@inline (v::Var{I})(i, x::AdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds AdjointNodeVar(v.i(i, x, θ), NaN)
-@inline (v::Var{I})(i, x::SecondAdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds SecondAdjointNodeVar(v.i(i, x, θ), NaN)
+@inline (v::Var{I})(i, x::AdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds AdjointNodeVar(v.i(i, x, θ), NaN)
+@inline (v::Var{I})(i, x::SecondAdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds SecondAdjointNodeVar(v.i(i, x, θ), NaN)
 
 @inline (v::Var{I})(i, x::AdjointNodeSource, θ) where {I} = @inbounds AdjointNodeVar(i, NaN)
 @inline (v::Var{I})(i, x::SecondAdjointNodeSource, θ) where {I} = @inbounds SecondAdjointNodeVar(i, NaN)
 
-@inline (v::Var{I})(i::Identity, x::AdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds AdjointNodeVar(v.i, NaN)
-@inline (v::Var{I})(i::Identity, x::SecondAdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds SecondAdjointNodeVar(v.i, NaN)
+@inline (v::Var{I})(i::Identity, x::AdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds AdjointNodeVar(v.i, NaN)
+@inline (v::Var{I})(i::Identity, x::SecondAdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds SecondAdjointNodeVar(v.i, NaN)
 
-@inline (v::Exp{I})(i, x::AdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds AdjointNodeExpr(v.i(i, x, θ), NaN)
-@inline (v::Exp{I})(i, x::SecondAdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds SecondAdjointNodeExpr(v.i(i, x, θ), NaN)
+@inline (v::Exp{I})(i, x::AdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds AdjointNodeExpr(v.i(i, x, θ), NaN)
+@inline (v::Exp{I})(i, x::SecondAdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds SecondAdjointNodeExpr(v.i(i, x, θ), NaN)
 
-@inline function (e::Exp{I})(i::Identity, x::AdjointNodeSource, θ) where {I<:AbstractNode}
-    offset = typeof(e.i) <: Node2{typeof(+),T,Int} where T ? e.i.inner2 : 0
-    x.offset_exps[offset].f.f(e.i(i, x, θ), x, θ)
+@inline function (e::Exp{I})(i::Identity, x::AdjointNodeSource, θ) where {I <: AbstractNode}
+    offset = typeof(e.i) <: Node2{typeof(+), T, Int} where {T} ? e.i.inner2 : 0
+    return x.offset_exps[offset].f.f(e.i(i, x, θ), x, θ)
 end
 
-@inline function (e::Exp{I})(i::Identity, x::SecondAdjointNodeSource, θ) where {I<:AbstractNode}
-    offset = typeof(e.i) <: Node2{typeof(+),T,Int} where T ? e.i.inner2 : 0
-    x.offset_exps[offset].f.f(e.i(i, x, θ), x, θ)
+@inline function (e::Exp{I})(i::Identity, x::SecondAdjointNodeSource, θ) where {I <: AbstractNode}
+    offset = typeof(e.i) <: Node2{typeof(+), T, Int} where {T} ? e.i.inner2 : 0
+    return x.offset_exps[offset].f.f(e.i(i, x, θ), x, θ)
 end
 
-@inline (v::Exp{I})(i, x::X, θ) where {I, X<:AdjointNodeSource} = @inbounds AdjointNodeExpr(i, NaN)
-@inline (v::Exp{I})(i, x::X, θ) where {I, X<:SecondAdjointNodeSource} = @inbounds SecondAdjointNodeExpr(i, NaN)
+@inline (v::Exp{I})(i, x::X, θ) where {I, X <: AdjointNodeSource} = @inbounds AdjointNodeExpr(i, NaN)
+@inline (v::Exp{I})(i, x::X, θ) where {I, X <: SecondAdjointNodeSource} = @inbounds SecondAdjointNodeExpr(i, NaN)
diff --git a/src/hessian.jl b/src/hessian.jl
index 396beee..d76c97b 100644
--- a/src/hessian.jl
+++ b/src/hessian.jl
@@ -18,9 +18,9 @@ Performs sparse hessian evaluation (`(df1/dx)(df2/dx)'` portion) via the reverse
 - `adj`: second adjoint propagated up to the current node
 """
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -34,9 +34,9 @@ Performs sparse hessian evaluation (`(df1/dx)(df2/dx)'` portion) via the reverse
     cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::SecondAdjointNode1,
     t2::SecondAdjointNode1,
     comp::Nothing,
@@ -52,9 +52,9 @@ end
 
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -63,14 +63,14 @@ end
     o2,
     cnt,
     adj,
-) where {T1<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr},T2<:SecondAdjointNode1}
+    ) where {T1 <: Union{SecondAdjointNodeVar, SecondAdjointNodeExpr}, T2 <: SecondAdjointNode1}
     cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner, comp, y1, y2, o2, cnt, adj * t2.y)
     cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::SecondAdjointNodeVar,
     t2::SecondAdjointNode1,
     comp::Nothing,
@@ -81,29 +81,29 @@ function hdrpass(
     adj,
 )  # despecialized
     cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner, comp, y1, y2, o2, cnt, adj * t2.y)
-    cnt
+    return cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
-    t1::SecondAdjointNodeExpr,
-    t2::SecondAdjointNode1,
-    comp::Nothing,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-)  # despecialized
+        e,
+        e_starts,
+        e_cnts,
+        t1::SecondAdjointNodeExpr,
+        t2::SecondAdjointNode1,
+        comp::Nothing,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+    )  # despecialized
     cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner, comp, y1, y2, o2, cnt, adj * t2.y)
     cnt
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -112,14 +112,14 @@ end
     o2,
     cnt,
     adj,
-) where {T1<:SecondAdjointNode1,T2<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr}}
+    ) where {T1 <: SecondAdjointNode1, T2 <: Union{SecondAdjointNodeVar, SecondAdjointNodeExpr}}
     cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2, comp, y1, y2, o2, cnt, adj * t1.y)
     cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::SecondAdjointNode1,
     t2::SecondAdjointNodeVar,
     comp::Nothing,
@@ -130,30 +130,30 @@ function hdrpass(
     adj,
 )  # despecialized
     cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2, comp, y1, y2, o2, cnt, adj * t1.y)
-    cnt
+    return cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
-    t1::SecondAdjointNode1,
-    t2::SecondAdjointNodeExpr,
-    comp::Nothing,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-)  # despecialized
+        e,
+        e_starts,
+        e_cnts,
+        t1::SecondAdjointNode1,
+        t2::SecondAdjointNodeExpr,
+        comp::Nothing,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+    )  # despecialized
     cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2, comp, y1, y2, o2, cnt, adj * t1.y)
     cnt
 end
 
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -170,9 +170,9 @@ end
     cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::SecondAdjointNode2,
     t2::SecondAdjointNode2,
     comp::Nothing,
@@ -191,9 +191,9 @@ end
 
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -208,9 +208,9 @@ end
     cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::SecondAdjointNode1,
     t2::SecondAdjointNode2,
     comp::Nothing,
@@ -226,9 +226,9 @@ function hdrpass(
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -243,9 +243,9 @@ end
     cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::SecondAdjointNode2,
     t2::SecondAdjointNode1,
     comp::Nothing,
@@ -261,9 +261,9 @@ function hdrpass(
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -272,15 +272,15 @@ end
     o2,
     cnt,
     adj,
-) where {T1<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr},T2<:SecondAdjointNode2}
+    ) where {T1 <: Union{SecondAdjointNodeVar, SecondAdjointNodeExpr}, T2 <: SecondAdjointNode2}
     cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner1, comp, y1, y2, o2, cnt, adj * t2.y1)
     cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner2, comp, y1, y2, o2, cnt, adj * t2.y2)
     cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::SecondAdjointNodeVar,
     t2::SecondAdjointNode2,
     comp::Nothing,
@@ -292,30 +292,30 @@ function hdrpass(
 ) # despecialized
     cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner1, comp, y1, y2, o2, cnt, adj * t2.y1)
     cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner2, comp, y1, y2, o2, cnt, adj * t2.y2)
-    cnt
+    return cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
-    t1::SecondAdjointNodeExpr,
-    t2::SecondAdjointNode2,
-    comp::Nothing,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-) # despecialized
+        e,
+        e_starts,
+        e_cnts,
+        t1::SecondAdjointNodeExpr,
+        t2::SecondAdjointNode2,
+        comp::Nothing,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+    ) # despecialized
     cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner1, comp, y1, y2, o2, cnt, adj * t2.y1)
     cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner2, comp, y1, y2, o2, cnt, adj * t2.y2)
     cnt
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -324,15 +324,15 @@ end
     o2,
     cnt,
     adj,
-) where {T1<:SecondAdjointNode2,T2<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr}}
+    ) where {T1 <: SecondAdjointNode2, T2 <: Union{SecondAdjointNodeVar, SecondAdjointNodeExpr}}
     cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2, comp, y1, y2, o2, cnt, adj * t1.y1)
     cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2, comp, y1, y2, o2, cnt, adj * t1.y2)
     cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::SecondAdjointNode2,
     t2::SecondAdjointNodeVar,
     comp::Nothing,
@@ -344,30 +344,30 @@ function hdrpass(
 ) # despecialized
     cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2, comp, y1, y2, o2, cnt, adj * t1.y1)
     cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2, comp, y1, y2, o2, cnt, adj * t1.y2)
-    cnt
+    return cnt
 end
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
-    t1::SecondAdjointNode2,
-    t2::SecondAdjointNodeExpr,
-    comp::Nothing,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-) # despecialized
+        e,
+        e_starts,
+        e_cnts,
+        t1::SecondAdjointNode2,
+        t2::SecondAdjointNodeExpr,
+        comp::Nothing,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+    ) # despecialized
     cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2, comp, y1, y2, o2, cnt, adj * t1.y1)
     cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2, comp, y1, y2, o2, cnt, adj * t1.y2)
     cnt
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -379,72 +379,72 @@ end
 ) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeVar}
     i, j = t1.i, t2.i
     @inbounds if i == j
-        y1[o2+comp(cnt += 1)] += 2 * adj
+        y1[o2 + comp(cnt += 1)] += 2 * adj
     else
-        y1[o2+comp(cnt += 1)] += adj
+        y1[o2 + comp(cnt += 1)] += adj
     end
     cnt
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
-    t1::T1,
-    t2::T2,
-    comp,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeVar}
+        e,
+        e_starts,
+        e_cnts,
+        t1::T1,
+        t2::T2,
+        comp,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+    ) where {T1 <: SecondAdjointNodeExpr, T2 <: SecondAdjointNodeVar}
     (cnt_start, e_start) = e_starts[t1.i]
     len = e_cnts[cnt_start]
     cnt += 1
     for i in 1:len
-        @inbounds y1[o2+comp(cnt)] += e[e_start+i-1] * adj
-        cnt += e_cnts[cnt_start+i]
+        @inbounds y1[o2 + comp(cnt)] += e[e_start + i - 1] * adj
+        cnt += e_cnts[cnt_start + i]
     end
     return cnt
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
-    t1::T1,
-    t2::T2,
-    comp,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeExpr}
+        e,
+        e_starts,
+        e_cnts,
+        t1::T1,
+        t2::T2,
+        comp,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+    ) where {T1 <: SecondAdjointNodeVar, T2 <: SecondAdjointNodeExpr}
     (cnt_start, e_start) = e_starts[t2.i]
     len = e_cnts[cnt_start]
     cnt += 1
     for i in 1:len
-        @inbounds y1[o2+comp(cnt)] += e[e_start+i-1] * adj
-        cnt += e_cnts[cnt_start+i]
+        @inbounds y1[o2 + comp(cnt)] += e[e_start + i - 1] * adj
+        cnt += e_cnts[cnt_start + i]
     end
     return cnt
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
-    t1::T1,
-    t2::T2,
-    comp,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeExpr}
+        e,
+        e_starts,
+        e_cnts,
+        t1::T1,
+        t2::T2,
+        comp,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+    ) where {T1 <: SecondAdjointNodeExpr, T2 <: SecondAdjointNodeExpr}
     (cnt_start1, e_start1) = e_starts[t1.i]
     len1 = e_cnts[cnt_start1]
     (cnt_start2, e_start2) = e_starts[t2.i]
@@ -452,27 +452,27 @@ end
 
     cnt += 1
     for i in 1:len1
-        val1 = e[e_start1+i-1]
+        val1 = e[e_start1 + i - 1]
         for j in 1:len2
-            val2 = e[e_start2+j-1]
+            val2 = e[e_start2 + j - 1]
             ind = o2 + comp(cnt)
             @inbounds if t1.i == t2.i && i == j
                 y1[ind] += 2 * val1 * val2 * adj
             else
                 y1[ind] += val1 * val2 * adj
             end
-            cnt += e_cnts[cnt_start2+j]
+            cnt += e_cnts[cnt_start2 + j]
         end
-        cnt += e_cnts[cnt_start1+i]
+        cnt += e_cnts[cnt_start1 + i]
     end
     return cnt
 end
 
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -521,12 +521,12 @@ Performs sparse hessian evaluation (`d²f/dx²` portion) via the reverse pass on
 """
 
 @inline function hrpass(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::D,
     comp,
     y1,
@@ -540,12 +540,12 @@ Performs sparse hessian evaluation (`d²f/dx²` portion) via the reverse pass on
 end
 
 @inline function hrpass(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::D,
     comp,
     y1,
@@ -554,39 +554,39 @@ end
     cnt,
     adj,
     adj2,
-) where {D<:SecondAdjointNodeExpr}
+    ) where {D <: SecondAdjointNodeExpr}
     (cnt_start2, e_start2) = e2_starts[t.i]
     len2 = e2_cnts[cnt_start2]
     cnt += 1
     for i in 1:len2
-        @inbounds y1[o2+comp(cnt)] += adj * e2[e_start2+i-1]
-        cnt += e2_cnts[cnt_start2+i]
+        @inbounds y1[o2 + comp(cnt)] += adj * e2[e_start2 + i - 1]
+        cnt += e2_cnts[cnt_start2 + i]
     end
     return cnt
 end
 
 @inline function hrpass(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
-    t::D,
-    comp,
-    y1::V,
-    y2::V,
-    o2,
-    cnt,
-    adj,
-    adj2,
-) where {D<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
+        t::D,
+        comp,
+        y1::V,
+        y2::V,
+        o2,
+        cnt,
+        adj,
+        adj2,
+    ) where {D <: SecondAdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
     (cnt_start2, e_start2) = e2_starts[t.i]
     len2 = e2_cnts[cnt_start2]
     cnt += 1
     for i in 1:len2
         ind = o2 + comp(cnt)
-        val = e2[e_start2+i-1]
+        val = e2[e_start2 + i - 1]
         r = unpack_row(val)
         c = unpack_col(val)
         if y1 === y2
@@ -599,38 +599,38 @@ end
                 @inbounds y2[ind] = c
             end
         end
-        cnt += e2_cnts[cnt_start2+i]
+        cnt += e2_cnts[cnt_start2 + i]
     end
     return cnt
 end
 
 @inline function hrpass(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
-    t::D,
-    comp,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-    adj2,
-) where {D<:SecondAdjointNode1}
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
+        t::D,
+        comp,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+        adj2,
+    ) where {D <: SecondAdjointNode1}
     cnt = hrpass(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner, comp, y1, y2, o2, cnt, adj * t.y, adj2 * (t.y)^2 + adj * t.h)
     cnt
 end
 
 @inline function hrpass(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::D,
     comp,
     y1,
@@ -639,7 +639,7 @@ end
     cnt,
     adj,
     adj2,
-) where {D<:SecondAdjointNode2}
+    ) where {D <: SecondAdjointNode2}
     adj2y1y2 = adj2 * t.y1 * t.y2
     adjh12 = adj * t.h12
     cnt = hrpass(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner1, comp, y1, y2, o2, cnt, adj * t.y1, adj2 * (t.y1)^2 + adj * t.h11)
@@ -651,12 +651,12 @@ end
 @inline hrpass0(args...) = hrpass(args...)
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::D,
     comp,
     y1,
@@ -671,12 +671,12 @@ end
 end
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::D,
     comp,
     y1,
@@ -691,12 +691,12 @@ end
 end
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::D,
     comp,
     y1,
@@ -711,12 +711,12 @@ end
 end
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::D,
     comp,
     y1,
@@ -731,12 +731,12 @@ end
 end
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::D,
     comp,
     y1,
@@ -751,12 +751,12 @@ end
 end
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::D,
     comp,
     y1,
@@ -771,12 +771,12 @@ end
 end
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::D,
     comp,
     y1,
@@ -792,12 +792,12 @@ end
 end
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::D,
     comp,
     y1,
@@ -813,12 +813,12 @@ end
 end
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::T,
     comp,
     y1,
@@ -832,12 +832,12 @@ end
 end
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::T,
     comp::Nothing,
     y1,
@@ -851,27 +851,27 @@ end
 end
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
-    t::T,
-    comp,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-    adj2,
-) where {T<:SecondAdjointNodeExpr}
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
+        t::T,
+        comp,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+        adj2,
+    ) where {T <: SecondAdjointNodeExpr}
     (cnt_start2, e_start2) = e2_starts[t.i]
     len2 = e2_cnts[cnt_start2]
     cnt += 1
     for i in 1:len2
-        @inbounds y1[o2+comp(cnt)] += adj * e2[e_start2+i-1]
-        cnt += e2_cnts[cnt_start2+i]
+        @inbounds y1[o2 + comp(cnt)] += adj * e2[e_start2 + i - 1]
+        cnt += e2_cnts[cnt_start2 + i]
     end
 
 
@@ -879,51 +879,51 @@ end
 end
 
 function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
-    t1::T1,
-    t2::T2,
+        e,
+        e_starts,
+        e_cnts,
+        t1::T1,
+        t2::T2,
     comp::Nothing,
     y1,
     y2,
     o2,
     cnt,
     adj,
-) where {T1<:SecondAdjointNodeVar, T2<:SecondAdjointNodeVar}
+    ) where {T1 <: SecondAdjointNodeVar, T2 <: SecondAdjointNodeVar}
     cnt += 1
     push!(y1, (t1.i, t2.i))
     cnt
 end
 
 function hrpass(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
-    t::T,
-    comp::Nothing,
-    y1,
-    y2,
-    o2,
-    cnt,
-    adj,
-    adj2
-) where {T<:SecondAdjointNodeVar}
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
+        t::T,
+        comp::Nothing,
+        y1,
+        y2,
+        o2,
+        cnt,
+        adj,
+        adj2
+    ) where {T <: SecondAdjointNodeVar}
     cnt += 1
     push!(y1, (t.i, t.i))
     cnt
 end
 
 @inline function hrpass(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::T,
     comp,
     y1::Tuple{V1,V2},
@@ -939,12 +939,12 @@ end
 end
 
 @inline function hrpass(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::T,
     comp,
     y1,
@@ -954,17 +954,17 @@ end
     adj,
     adj2,
 ) where {T<:SecondAdjointNodeVar}
-    @inbounds y1[o2+comp(cnt += 1)] += adj2
+    @inbounds y1[o2 + comp(cnt += 1)] += adj2
     cnt
 end
 
 @inline function hrpass(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::T,
     comp,
     y1::V,
@@ -989,12 +989,12 @@ end
 end
 
 @inline function hrpass(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
     t::T,
     comp,
     y1::V,
@@ -1014,9 +1014,9 @@ end
 @inline unpack_col(v) = Int(v & 0xFFFFFFFF)
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -1052,9 +1052,9 @@ end
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     t1::T1,
     t2::T2,
     comp,
@@ -1080,27 +1080,27 @@ end
 end
 
 @inline function hrpass0(
-    e,
-    e_starts,
-    e_cnts,
-    e2,
-    e2_starts,
-    e2_cnts,
-    t::T,
-    comp,
-    y1::V,
-    y2::V,
-    o2,
-    cnt,
-    adj,
-    adj2,
-) where {T<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+        e,
+        e_starts,
+        e_cnts,
+        e2,
+        e2_starts,
+        e2_cnts,
+        t::T,
+        comp,
+        y1::V,
+        y2::V,
+        o2,
+        cnt,
+        adj,
+        adj2,
+    ) where {T <: SecondAdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
     (cnt_start2, e_start2) = e2_starts[t.i]
     len2 = e2_cnts[cnt_start2]
     cnt += 1
     for i in 1:len2
         ind = o2 + comp(cnt)
-        val = e2[e_start2+i-1]
+        val = e2[e_start2 + i - 1]
         r = unpack_row(val)
         c = unpack_col(val)
         if y1 === y2
@@ -1113,31 +1113,31 @@ end
                 @inbounds y2[ind] = c
             end
         end
-        cnt += e2_cnts[cnt_start2+i]
+        cnt += e2_cnts[cnt_start2 + i]
     end
     return cnt
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
-    t1::T1,
-    t2::T2,
-    comp,
-    y1::V,
-    y2::V,
-    o2,
-    cnt,
-    adj,
-) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeVar,I<:Integer,V<:AbstractVector{I}}
+        e,
+        e_starts,
+        e_cnts,
+        t1::T1,
+        t2::T2,
+        comp,
+        y1::V,
+        y2::V,
+        o2,
+        cnt,
+        adj,
+    ) where {T1 <: SecondAdjointNodeExpr, T2 <: SecondAdjointNodeVar, I <: Integer, V <: AbstractVector{I}}
     (cnt_start, e_start) = e_starts[t1.i]
     len = e_cnts[cnt_start]
     j = t2.i
     cnt += 1
     for i in 1:len
         ind = o2 + comp(cnt)
-        idx = e[e_start+i-1]
+        idx = e[e_start + i - 1]
         if y1 === y2
             if idx != 0 || j != 0
                 @inbounds if idx >= j
@@ -1157,31 +1157,31 @@ end
                 end
             end
         end
-        cnt += e_cnts[cnt_start+i]
+        cnt += e_cnts[cnt_start + i]
     end
     return cnt
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
-    t1::T1,
-    t2::T2,
-    comp,
-    y1::V,
-    y2::V,
-    o2,
-    cnt,
-    adj,
-) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+        e,
+        e_starts,
+        e_cnts,
+        t1::T1,
+        t2::T2,
+        comp,
+        y1::V,
+        y2::V,
+        o2,
+        cnt,
+        adj,
+    ) where {T1 <: SecondAdjointNodeVar, T2 <: SecondAdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
     i = t1.i
     (cnt_start, e_start) = e_starts[t2.i]
     len = e_cnts[cnt_start]
     cnt += 1
     for k in 1:len
         ind = o2 + comp(cnt)
-        idx = e[e_start+k-1]
+        idx = e[e_start + k - 1]
         if y1 === y2
             if i != 0 || idx != 0
                 @inbounds if i >= idx
@@ -1201,24 +1201,24 @@ end
                 end
             end
         end
-        cnt += e_cnts[cnt_start+k]
+        cnt += e_cnts[cnt_start + k]
     end
     return cnt
 end
 
 @inline function hdrpass(
-    e,
-    e_starts,
-    e_cnts,
-    t1::T1,
-    t2::T2,
-    comp,
-    y1::V,
-    y2::V,
-    o2,
-    cnt,
-    adj,
-) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+        e,
+        e_starts,
+        e_cnts,
+        t1::T1,
+        t2::T2,
+        comp,
+        y1::V,
+        y2::V,
+        o2,
+        cnt,
+        adj,
+    ) where {T1 <: SecondAdjointNodeExpr, T2 <: SecondAdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
     (cnt_start1, e_start1) = e_starts[t1.i]
     len1 = e_cnts[cnt_start1]
     (cnt_start2, e_start2) = e_starts[t2.i]
@@ -1226,9 +1226,9 @@ end
 
     cnt += 1
     for i in 1:len1
-        idx1 = e[e_start1+i-1]
+        idx1 = e[e_start1 + i - 1]
         for j in 1:len2
-            idx2 = e[e_start2+j-1]
+            idx2 = e[e_start2 + j - 1]
             ind = o2 + comp(cnt)
             if y1 === y2
                 if idx1 != 0 || idx2 != 0
@@ -1249,9 +1249,9 @@ end
                     end
                 end
             end
-            cnt += e_cnts[cnt_start2+j]
+            cnt += e_cnts[cnt_start2 + j]
         end
-        cnt += e_cnts[cnt_start1+i]
+        cnt += e_cnts[cnt_start1 + i]
     end
     return cnt
 end
@@ -1297,7 +1297,7 @@ function shessian!(y1, y2, f, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_c
     end
 end
 
-function shessian!(y1, y2, f, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, adj1s::V, adj2, isexp) where {V<:AbstractVector}
+function shessian!(y1, y2, f, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, adj1s::V, adj2, isexp) where {V <: AbstractVector}
     @simd for k in eachindex(f.itr)
         @inbounds shessian!(
             y1,
@@ -1319,5 +1319,5 @@ end
 
 function shessian!(y1, y2, f, p, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, comp, o2, adj1, adj2, isexp)
     graph = f(p, SecondAdjointNodeSource(x, nothing), θ)
-    hrpass0(e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, graph, comp, y1, y2, o2, 0, adj1, adj2)
+    return hrpass0(e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, graph, comp, y1, y2, o2, 0, adj1, adj2)
 end
diff --git a/src/jacobian.jl b/src/jacobian.jl
index 21ab3d0..41f8df5 100644
--- a/src/jacobian.jl
+++ b/src/jacobian.jl
@@ -15,9 +15,9 @@ Performs sparse jacobian evaluation via the reverse pass on the computation (sub
 """
 @inline function jrpass(
     d::D,
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     comp,
     i,
     y1,
@@ -28,26 +28,26 @@ Performs sparse jacobian evaluation via the reverse pass on the computation (sub
 ) where {D<:Union{AdjointNull,Real}}
     return cnt
 end
-@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNode1}
+@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D <: AdjointNode1}
     cnt = jrpass(d.inner, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj * d.y)
     return cnt
 end
-@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNode2}
+@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D <: AdjointNode2}
     cnt = jrpass(d.inner1, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj * d.y1)
     cnt = jrpass(d.inner2, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj * d.y2)
     return cnt
 end
 # jac_coord
-@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNodeVar}
+@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D <: AdjointNodeVar}
     @inbounds y1[o1+comp(cnt+=1)] += adj
     return cnt
 end
-@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNodeExpr}
+@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D <: AdjointNodeExpr}
     (cnt_start, e_start) = e_starts[d.i]
     len = e_cnts[cnt_start]
     cnt += 1
     for i in 1:len
-        @inbounds y1[o1+comp(cnt)] += adj * e[e_start + i - 1]
+        @inbounds y1[o1 + comp(cnt)] += adj * e[e_start + i - 1]
         cnt += e_cnts[cnt_start + i]
     end
     return cnt
@@ -55,13 +55,13 @@ end
 # jprod_nln
 @inline function jrpass(
     d::D,
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     comp,
-    o0,
+        o0,
     y1::Tuple{V1,V2},
-    y2::Nothing,
+        y2::Nothing,
     o1,
     cnt,
     adj,
@@ -74,12 +74,12 @@ end
 # jtprod_nln
 @inline function jrpass(
     d::D,
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     comp,
-    o0,
-    y1::Nothing,
+        o0,
+        y1::Nothing,
     y2::Tuple{V1,V2},
     o1,
     cnt,
@@ -93,11 +93,11 @@ end
 # jac_structure
 @inline function jrpass(
     d::D,
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     comp,
-    o0,
+        o0,
     y1::V,
     y2::V,
     o1,
@@ -111,17 +111,17 @@ end
 end
 @inline function jrpass(
     d::D,
-    e,
-    e_starts,
-    e_cnts,
+        e,
+        e_starts,
+        e_cnts,
     comp,
-    o0,
-    y1::V,
-    y2::V,
-    o1,
-    cnt,
-    adj,
-) where {D<:AdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+        o0,
+        y1::V,
+        y2::V,
+        o1,
+        cnt,
+        adj,
+    ) where {D <: AdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
     (cnt_start, e_start) = e_starts[d.i]
     len = e_cnts[cnt_start]
     cnt += 1
@@ -135,35 +135,35 @@ end
 end
 # no rows when precomputing expressions
 @inline function jrpass(
-    d::D,
-    e,
-    e_starts,
-    e_cnts,
-    comp,
-    o0,
-    y1::Nothing,
-    y2::V,
-    o1,
-    cnt,
-    adj,
-) where {D<:AdjointNodeVar,I<:Integer,V<:AbstractVector{I}}
+        d::D,
+        e,
+        e_starts,
+        e_cnts,
+        comp,
+        o0,
+        y1::Nothing,
+        y2::V,
+        o1,
+        cnt,
+        adj,
+    ) where {D <: AdjointNodeVar, I <: Integer, V <: AbstractVector{I}}
     ind = o1 + comp(cnt += 1)
     @inbounds y2[ind] = d.i
     return cnt
 end
 @inline function jrpass(
-    d::D,
-    e,
-    e_starts,
-    e_cnts,
-    comp,
-    o0,
-    y1::Nothing,
-    y2::V,
-    o1,
-    cnt,
-    adj,
-) where {D<:AdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+        d::D,
+        e,
+        e_starts,
+        e_cnts,
+        comp,
+        o0,
+        y1::Nothing,
+        y2::V,
+        o1,
+        cnt,
+        adj,
+    ) where {D <: AdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
     (cnt_start, e_start) = e_starts[d.i]
     len = e_cnts[cnt_start]
     cnt += 1
@@ -175,12 +175,12 @@ end
     return cnt
 end
 @inline function jrpass(
-    d::D,
-    e,
-    e_starts,
-    e_cnts,
-    comp,
-    o0,
+        d::D,
+        e,
+        e_starts,
+        e_cnts,
+        comp,
+        o0,
     y1::V,
     y2,
     o1,
@@ -228,5 +228,5 @@ end
 function sjacobian!(isexp, y1, y2, f, e, e_starts, e_cnts, p, x, θ, comp, o0, o1, adj)
     s = AdjointNodeSource(x, nothing)
     graph = f(p, s, θ)
-    jrpass(graph, e, e_starts, e_cnts, comp, o0, y1, y2, o1, 0, adj)
+    return jrpass(graph, e, e_starts, e_cnts, comp, o0, y1, y2, o1, 0, adj)
 end
diff --git a/src/nlp.jl b/src/nlp.jl
index b20e8d7..3034733 100644
--- a/src/nlp.jl
+++ b/src/nlp.jl
@@ -54,7 +54,7 @@ Objective
 )
 
 
-struct Expression{R,F,I,O,S} <: AbstractExpression
+struct Expression{R, F, I, O, S} <: AbstractExpression
     inner::R
     f::F
     itr::I
@@ -64,13 +64,13 @@ end
 Base.show(io::IO, v::Expression) = print(
     io,
     """
-Expression
+    Expression
 
-  s.t. (...)
-       g♭ ≤ [g(x,θ,p)]_{p ∈ P} ≤ g♯
+      s.t. (...)
+           g♭ ≤ [g(x,θ,p)]_{p ∈ P} ≤ g♯
 
-  where |P| = $(length(v.itr))
-""",
+      where |P| = $(length(v.itr))
+    """,
 )
 
 
@@ -92,7 +92,7 @@ Constraint
 """,
 )
 
-struct ExpressionAug{R,F,I} <: AbstractConstraint
+struct ExpressionAug{R, F, I} <: AbstractConstraint
     inner::R
     f::F
     itr::I
@@ -161,21 +161,21 @@ An ExaCore

"""
Base.@kwdef mutable struct ExaCore{

  • T,
  • VT<:AbstractVector{T},
  • VI<:AbstractVector{UInt},
  • DI<:Dict{Int,Expression},
  • B,
  • VII<:AbstractVector{Tuple{UInt,UInt}}
    -}
  •    T,
    
  •    VT <: AbstractVector{T},
    
  •    VI <: AbstractVector{UInt},
    
  •    DI <: Dict{Int, Expression},
    
  •    B,
    
  •    VII <: AbstractVector{Tuple{UInt, UInt}},
    
  • }
    backend::B = nothing
    obj::AbstractObjective = ObjectiveNull()
    con::AbstractConstraint = ConstraintNull()
    exp::AbstractExpression = ExpressionNull()

    corresponds to y1 and y2 in _simdfunction()

  • offset_exps::DI = Dict{Int,Expression}()
  • e1_starts::VII = convert_array(Tuple{UInt,UInt}[], backend)
  • e2_starts::VII = convert_array(Tuple{UInt,UInt}[], backend)
  • offset_exps::DI = Dict{Int, Expression}()
  • e1_starts::VII = convert_array(Tuple{UInt, UInt}[], backend)
  • e2_starts::VII = convert_array(Tuple{UInt, UInt}[], backend)
    e1_cnts::VI = convert_array(UInt[], backend)
    e2_cnts::VI = convert_array(UInt[], backend)
    e1_len::Int = 0
    @@ -205,15 +205,15 @@ end

Deprecated as of v0.7

function ExaCore(::Type{T}, backend) where {T<:AbstractFloat}
@warn "ExaCore(T, backend) is deprecated. Use ExaCore(T; backend = backend) instead"

  • return ExaCore(T; backend=backend)
  • return ExaCore(T; backend = backend)
    end
    function ExaCore(backend)
    @warn "ExaCore(backend) is deprecated. Use ExaCore(T; backend = backend) instead"
  • return ExaCore(; backend=backend)
  • return ExaCore(; backend = backend)
    end

-ExaCore(::Type{T}; backend=nothing, kwargs...) where {T<:AbstractFloat} =

  • ExaCore(x0=convert_array(zeros(T, 0), backend); backend=backend, kwargs...)
    +ExaCore(::Type{T}; backend = nothing, kwargs...) where {T <: AbstractFloat} =
  • ExaCore(x0 = convert_array(zeros(T, 0), backend); backend = backend, kwargs...)

depth(a) = depth(a.inner) + 1
depth(a::ObjectiveNull) = 0
@@ -234,7 +234,7 @@ An ExaCore
""",
)

-struct ExaModel{T,VT,VI,E,O,C,EX,VII} <: NLPModels.AbstractNLPModel{T,VT}
+struct ExaModel{T, VT, VI, E, O, C, EX, VII} <: NLPModels.AbstractNLPModel{T, VT}
objs::O
cons::C
exps::EX
@@ -300,7 +300,7 @@ julia> result = ipopt(m; print_level=0) # solve the problem

"""
-function ExaModel(c::C; prod=nothing) where {C<:ExaCore}
+function ExaModel(c::C; prod = nothing) where {C <: ExaCore}
    return ExaModel(
        c.obj,
        c.con,
@@ -317,16 +317,16 @@ function ExaModel(c::C; prod=nothing) where {C<:ExaCore}
        c.θ,
        NLPModels.NLPModelMeta(
            c.nvar,
-            ncon=c.ncon,
-            nnzj=c.nnzj,
-            nnzh=c.nnzh,
-            x0=c.x0,
-            lvar=c.lvar,
-            uvar=c.uvar,
-            y0=c.y0,
-            lcon=c.lcon,
-            ucon=c.ucon,
-            minimize=c.minimize,
+            ncon = c.ncon,
+            nnzj = c.nnzj,
+            nnzh = c.nnzh,
+            x0 = c.x0,
+            lvar = c.lvar,
+            uvar = c.uvar,
+            y0 = c.y0,
+            lcon = c.lcon,
+            ucon = c.ucon,
+            minimize = c.minimize,
        ),
        NLPModels.Counters(),
        nothing,
@@ -343,14 +343,14 @@ end
    Var(v.offset + idxx(is .- (_start.(v.size) .- 1), _length.(v.size)))
end

-@inline function Base.getindex(e::E, i) where {E<:Expression}
+@inline function Base.getindex(e::E, i) where {E <: Expression}
    _bound_check(e.size, i)
-    Exp(i + (e.offset - _start(e.size[1]) + 1))
+    return Exp(i + (e.offset - _start(e.size[1]) + 1))
end
-@inline function Base.getindex(e::E, is...) where {E<:Expression}
+@inline function Base.getindex(e::E, is...) where {E <: Expression}
    @assert(length(is) == length(e.size), "Expression index dimension error. Got $(length(is)) dimensions, expected $(length(e.size)).")
    _bound_check(e.size, is)
-    Exp(e.offset + idxx(is .- (_start.(e.size) .- 1), _length.(e.size)))
+    return Exp(e.offset + idxx(is .- (_start.(e.size) .- 1), _length.(e.size)))
end

@inline function Base.getindex(p::P, i) where {P<:Parameter}
@@ -453,14 +453,14 @@ Variable
function variable(
    c::C,
    ns...;
-    start=zero(T),
-    lvar=T(-Inf),
-    uvar=T(Inf),
-) where {T,C<:ExaCore{T}}
+        start = zero(T),
+        lvar = T(-Inf),
+        uvar = T(Inf),
+    ) where {T, C <: ExaCore{T}}
    o = c.nvar
    len = total(ns)
    c.nvar += len
-    c.varis = vcat(c.varis, (o+1):c.nvar)
+    c.varis = vcat(c.varis, (o + 1):c.nvar)
    append!(c.backend, c.isexp, 0, len)
    append!(c.backend, c.e1_starts, [(0, 0) for _ in 1:len], len)
    append!(c.backend, c.e2_starts, [(0, 0) for _ in 1:len], len)
@@ -574,7 +574,7 @@ end

Adds objective terms specified by a `expr` and `pars` to `core`, and returns an `Objective` object.
"""
-function objective(c::C, expr::N, pars=1:1) where {C<:ExaCore,N<:AbstractNode}
+function objective(c::C, expr::N, pars = 1:1) where {C <: ExaCore, N <: AbstractNode}
    f = _simdfunction(expr, c.offset_exps, c.exp, c.isexp, c.nobj, c.nnzg, c.nnzh)

    _objective(c, f, pars)
@@ -619,9 +619,9 @@ Constraint
function constraint(
    c::C,
    gen::Base.Generator;
-    start=zero(T),
-    lcon=zero(T),
-    ucon=zero(T),
+        start = zero(T),
+        lcon = zero(T),
+        ucon = zero(T),
) where {T,C<:ExaCore{T}}

    gen = _adapt_gen(gen)
@@ -639,10 +639,10 @@ Adds constraints specified by a `expr` and `pars` to `core`, and returns an `Con
function constraint(
    c::C,
    expr::N,
-    pars=1:1;
-    start=zero(T),
-    lcon=zero(T),
-    ucon=zero(T),
+        pars = 1:1;
+        start = zero(T),
+        lcon = zero(T),
+        ucon = zero(T),
) where {T,C<:ExaCore{T},N<:AbstractNode}

    f = _simdfunction(expr, c.offset_exps, c.isexp, c.ncon, c.nnzj, c.nnzh)
@@ -658,9 +658,9 @@ Adds empty constraints of dimension n, so that later the terms can be added with
function constraint(
    c::C,
    n;
-    start=zero(T),
-    lcon=zero(T),
-    ucon=zero(T),
+        start = zero(T),
+        lcon = zero(T),
+        ucon = zero(T),
) where {T,C<:ExaCore{T}}

    f = _simdfunction(Null(), c.offset_exps, c.isexp, c.ncon, c.nnzj, c.nnzh)
@@ -768,16 +768,16 @@ Expression
function subexpr(
        c::C,
        gen::I,
-    ) where {T, C <: ExaCore{T}, I<:Base.Iterators.Flatten}
-    ns=[]
+    ) where {T, C <: ExaCore{T}, I <: Base.Iterators.Flatten}
+    ns = []
    it = gen.it
    while typeof(it) <: Union{Base.Generator, Base.Iterators.Flatten}
        push!(ns, length(it))
        (it, _) = Base.iterate(it)
    end
-    subexpr(c, (nsi for nsi in ns), gen)
+    return subexpr(c, (nsi for nsi in ns), gen)
end
-subexpr(c::C, gen::G) where {T, C <: ExaCore{T}, G<:Base.Generator} = subexpr(c, Base.size(gen.iter), gen)
+subexpr(c::C, gen::G) where {T, C <: ExaCore{T}, G <: Base.Generator} = subexpr(c, Base.size(gen.iter), gen)

function subexpr(
        c::C,
@@ -810,7 +810,7 @@ function subexpr(
    return c.exp
end

-function simd_expr(c::ExaCore, gen,)
+function simd_expr(c::ExaCore, gen)
    f = gen.f(ParSource())
    nitr = length(gen.iter)

@@ -826,9 +826,11 @@ function simd_expr(c::ExaCore, gen,)
    o1step = length(a1)
    e1_cnts = compress_ref_cnts(y1, a1)
    c1 = Compressor(Tuple(findfirst(isequal(di), a1) for di in y1))
-    append!(c.backend, c.e1_starts, [
+    append!(
+        c.backend, c.e1_starts, [
            (length(c.e1_cnts) + 1, (i - 1) * o1step + c.e1_len + 1) for i in 1:nitr
-        ], nitr)
+        ], nitr
+    )
    o1 = c.e1_len
    c.e1_len += nitr * o1step
    push!(c.e1_cnts, o1step)
@@ -838,21 +840,23 @@ function simd_expr(c::ExaCore, gen,)
    e2_cnts = compress_ref_cnts(y2, a2)
    o2step = length(a2)
    c2 = Compressor(Tuple(findfirst(isequal(di), a2) for di in y2))
-    append!(c.backend, c.e2_starts, [
+    append!(
+        c.backend, c.e2_starts, [
            (length(c.e2_cnts) + 1, (i - 1) * o2step + c.e2_len + 1) for i in 1:nitr
-        ], nitr)
+        ], nitr
+    )
    o2 = c.e2_len
    c.e2_len += nitr * o2step
    push!(c.e2_cnts, o2step)
    append!(c.backend, c.e2_cnts, e2_cnts, length(e2_cnts))

-    SIMDFunction(f, c1, c2, c.nvar, o1, o2, o1step, o2step)
+    return SIMDFunction(f, c1, c2, c.nvar, o1, o2, o1step, o2step)
end

expr!(m, x, θ) = _expr!(m.exps, m, x, θ)
function _expr!(expr, m, x, θ)
    _expr!(expr.inner, m, x, θ)
-    @simd for i in eachindex(expr.itr)
+    return @simd for i in eachindex(expr.itr)
        x[offset0(expr, i)] = expr.f(expr.itr[i], x, θ)
    end
end
@@ -870,7 +874,7 @@ _jac_structure!(cons::ExpressionNull, m, e1_uint, rows, cols) = nothing
_jac_structure!(cons::ConstraintNull, m, e1_uint, rows, cols) = nothing
function _jac_structure!(f, m, e1_uint, rows, cols)
    _jac_structure!(f.inner, m, e1_uint, rows, cols)
-    sjacobian!(e1_uint, m.e1_starts, m.e1_cnts, m.isexp, rows, cols, f, nothing, nothing, NaN)
+    return sjacobian!(e1_uint, m.e1_starts, m.e1_cnts, m.isexp, rows, cols, f, nothing, nothing, NaN)
end

function hess_structure!(m::ExaModel, rows::AbstractVector, cols::AbstractVector)
@@ -888,22 +892,24 @@ end
_exp_hess_structure!(exps::ExpressionNull, m, e2_uint) = nothing
function _exp_hess_structure!(exps, m, e2_uint)
    _exp_hess_structure!(exps.inner, m, e2_uint)
-    shessian!(e2_uint, e2_uint, exps, nothing, nothing,
+    return shessian!(
+        e2_uint, e2_uint, exps, nothing, nothing,
        reinterpret(UInt, m.e1), m.e1_starts, m.e1_cnts,
        e2_uint, m.e2_starts, m.e2_cnts,
-        NaN, NaN, m.isexp)
+        NaN, NaN, m.isexp
+    )
end

_obj_hess_structure!(objs::ObjectiveNull, m, rows, cols, e1_uint, e2_uint) = nothing
function _obj_hess_structure!(objs, m, rows, cols, e1_uint, e2_uint)
    _obj_hess_structure!(objs.inner, m, rows, cols, e1_uint, e2_uint)
-    shessian!(rows, cols, objs, nothing, nothing, e1_uint, m.e1_starts, m.e1_cnts, e2_uint, m.e2_starts, m.e2_cnts, NaN, NaN, m.isexp)
+    return shessian!(rows, cols, objs, nothing, nothing, e1_uint, m.e1_starts, m.e1_cnts, e2_uint, m.e2_starts, m.e2_cnts, NaN, NaN, m.isexp)
end

_con_hess_structure!(cons::ConstraintNull, m, rows, cols, e1_uint, e2_uint) = nothing
function _con_hess_structure!(cons, m, rows, cols, e1_uint, e2_uint)
    _con_hess_structure!(cons.inner, m, rows, cols, e1_uint, e2_uint)
-    shessian!(rows, cols, cons, nothing, nothing, e1_uint, m.e1_starts, m.e1_cnts, e2_uint, m.e2_starts, m.e2_cnts, NaN, NaN, m.isexp)
+    return shessian!(rows, cols, cons, nothing, nothing, e1_uint, m.e1_starts, m.e1_cnts, e2_uint, m.e2_starts, m.e2_cnts, NaN, NaN, m.isexp)
end

function obj(m::ExaModel, x::AbstractVector)
@@ -943,7 +949,7 @@ _grad!(f::ObjectiveNull, m, x, out) = nothing
_grad!(f::ExpressionNull, m, x, out) = nothing
function _grad!(f, m, x, out)
    _grad!(f.inner, m, x, out)
-    gradient!(m.isexp, m.e1, m.e1_starts, m.e1_cnts, out, m.objs, x, m.θ, one(eltype(out)))
+    return gradient!(m.isexp, m.e1, m.e1_starts, m.e1_cnts, out, m.objs, x, m.θ, one(eltype(out)))
end

function jac_coord!(m::ExaModel, x::AbstractVector, jac::AbstractVector)
@@ -958,7 +964,7 @@ _jac_coord!(f::ConstraintNull, x, m, jac) = nothing
_jac_coord!(f::ExpressionNull, x, m, jac) = nothing
function _jac_coord!(f, x, m, jac)
    _jac_coord!(f.inner, x, m, jac)
-    sjacobian!(m.e1, m.e1_starts, m.e1_cnts, m.isexp, jac, nothing, f, x, m.θ, one(eltype(jac)))
+    return sjacobian!(m.e1, m.e1_starts, m.e1_cnts, m.isexp, jac, nothing, f, x, m.θ, one(eltype(jac)))
end

function jprod_nln!(m::ExaModel, x::AbstractVector, v::AbstractVector, Jv::AbstractVector)
@@ -973,7 +979,7 @@ _jprod_nln!(f::ConstraintNull, x, m, v, Jv) = nothing
_jprod_nln!(f::ExpressionNull, x, m, v, Jv) = nothing
function _jprod_nln!(f, x, m, v, Jv)
    _jprod_nln!(f.inner, x, m, v, Jv)
-    sjacobian!(m.e1, m.e1_starts, m.e1_cnts, m.isexp, (Jv, v), nothing, f, x, m.θ, one(eltype(Jv)))
+    return sjacobian!(m.e1, m.e1_starts, m.e1_cnts, m.isexp, (Jv, v), nothing, f, x, m.θ, one(eltype(Jv)))
end

function jtprod_nln!(m::ExaModel, x::AbstractVector, v::AbstractVector, Jtv::AbstractVector)
@@ -988,14 +994,14 @@ _jtprod_nln!(f::ConstraintNull, x, m, v, Jtv) = nothing
_jtprod_nln!(f::ExpressionNull, x, m, v, Jtv) = nothing
function _jtprod_nln!(f, x, m, v, Jtv)
    _jtprod_nln!(f.inner, x, m, v, Jtv)
-    sjacobian!(m.e1, m.e1_starts, m.e1_cnts, m.isexp, nothing, (Jtv, v), f, x, m.θ, one(eltype(Jv)))
+    return sjacobian!(m.e1, m.e1_starts, m.e1_cnts, m.isexp, nothing, (Jtv, v), f, x, m.θ, one(eltype(Jv)))
end

function hess_coord!(
    m::ExaModel,
    x::AbstractVector,
    hess::AbstractVector;
-    obj_weight=one(eltype(x)),
+        obj_weight = one(eltype(x)),
)
    fill!(hess, zero(eltype(hess)))
    fill!(m.e1, zero(eltype(m.e1)))
@@ -1012,7 +1018,7 @@ function hess_coord!(
    x::AbstractVector,
    y::AbstractVector,
    hess::AbstractVector;
-    obj_weight=one(eltype(x)),
+        obj_weight = one(eltype(x)),
)
    fill!(hess, zero(eltype(hess)))
    fill!(m.e1, zero(eltype(m.e1)))
@@ -1028,19 +1034,19 @@ end
_exp_hess_coord!(exps::ExpressionNull, x, m) = nothing
function _exp_hess_coord!(exps, x, m)
    _exp_hess_coord!(exps.inner, x, m)
-    shessian!(m.e2, nothing, exps, x, m.θ, m.e1, m.e1_starts, m.e1_cnts, m.e2, m.e2_starts, m.e2_cnts, one(eltype(m.e2)), zero(eltype(m.e2)), m.isexp)
+    return shessian!(m.e2, nothing, exps, x, m.θ, m.e1, m.e1_starts, m.e1_cnts, m.e2, m.e2_starts, m.e2_cnts, one(eltype(m.e2)), zero(eltype(m.e2)), m.isexp)
end

_obj_hess_coord!(objs::ObjectiveNull, x, m, hess, w) = nothing
function _obj_hess_coord!(objs, x, m, hess, w)
    _obj_hess_coord!(objs.inner, x, m, hess, w)
-    shessian!(hess, nothing, objs, x, m.θ, m.e1, m.e1_starts, m.e1_cnts, m.e2, m.e2_starts, m.e2_cnts, w, zero(eltype(hess)), m.isexp)
+    return shessian!(hess, nothing, objs, x, m.θ, m.e1, m.e1_starts, m.e1_cnts, m.e2, m.e2_starts, m.e2_cnts, w, zero(eltype(hess)), m.isexp)
end

_con_hess_coord!(cons::ConstraintNull, x, m, y, hess) = nothing
function _con_hess_coord!(cons, x, m, y, hess)
    _con_hess_coord!(cons.inner, x, m, y, hess)
-    shessian!(hess, nothing, cons, x, m.θ, m.e1, m.e1_starts, m.e1_cnts, m.e2, m.e2_starts, m.e2_cnts, y, zero(eltype(hess)), m.isexp)
+    return shessian!(hess, nothing, cons, x, m.θ, m.e1, m.e1_starts, m.e1_cnts, m.e2, m.e2_starts, m.e2_cnts, y, zero(eltype(hess)), m.isexp)
end

function hprod!(
@@ -1048,7 +1054,7 @@ function hprod!(
    x::AbstractVector,
    v::AbstractVector,
    Hv::AbstractVector;
-    obj_weight=one(eltype(x)),
+        obj_weight = one(eltype(x)),
)
    fill!(Hv, zero(eltype(Hv)))
    fill!(m.e1, zero(eltype(m.e1)))
@@ -1066,7 +1072,7 @@ function hprod!(
    y::AbstractVector,
    v::AbstractVector,
    Hv::AbstractVector;
-    obj_weight=one(eltype(x)),
+        obj_weight = one(eltype(x)),
)
    fill!(Hv, zero(eltype(Hv)))
    fill!(m.e1, zero(eltype(m.e1)))
@@ -1082,13 +1088,13 @@ end
_obj_hprod!(objs::ObjectiveNull, x, m, v, Hv, obj_weight) = nothing
function _obj_hprod!(objs, x, m, v, Hv, obj_weight)
    _obj_hprod!(objs.inner, x, m, v, Hv, obj_weight)
-    shessian!((Hv, v), nothing, objs, x, m.θ, m.e1, m.e1_starts, m.e1_cnts, m.e2, m.e2_starts, m.e2_cnts, obj_weight, zero(eltype(Hv)), m.isexp)
+    return shessian!((Hv, v), nothing, objs, x, m.θ, m.e1, m.e1_starts, m.e1_cnts, m.e2, m.e2_starts, m.e2_cnts, obj_weight, zero(eltype(Hv)), m.isexp)
end

_con_hprod!(cons::ConstraintNull, x, m, y, v, Hv) = nothing
function _con_hprod!(cons, x, m, y, v, Hv)
    _con_hprod!(cons.inner, x, m, y, v, Hv)
-    shessian!((Hv, v), nothing, cons, x, m.θ, m.e1, m.e1_starts, m.e1_cnts, m.e2, m.e2_starts, m.e2_cnts, y, zero(eltype(Hv)), m.isexp)
+    return shessian!((Hv, v), nothing, cons, x, m.θ, m.e1, m.e1_starts, m.e1_cnts, m.e2, m.e2_starts, m.e2_cnts, y, zero(eltype(Hv)), m.isexp)
end

@inbounds @inline offset0(a, i) = offset0(a.f, i)
diff --git a/src/register.jl b/src/register.jl
index a3088eb..4feaed5 100644
--- a/src/register.jl
+++ b/src/register.jl
@@ -36,7 +36,7 @@ macro register_univariate(f, df, ddf)
            @inline $f(t::T) where {T<:ExaModels.AbstractSecondAdjointNode} =
                ExaModels.SecondAdjointNode1($f, $f(t.x), $df(t.x), $ddf(t.x), t)

-            @inline (n::ExaModels.Node1{typeof($f),I})(i, x, θ) where {I} = $f(n.inner(i, x, θ))
+            @inline (n::ExaModels.Node1{typeof($f), I})(i, x, θ) where {I} = $f(n.inner(i, x, θ))
        end,
    )
end
@@ -203,9 +203,9 @@ macro register_bivariate(f, df1, df2, ddf11, ddf12, ddf22)
                )
            end

-            @inline (n::ExaModels.Node2{typeof($f),I1,I2})(i, x, θ) where {I1,I2} = $f(n.inner1(i, x, θ), n.inner2(i, x, θ))
-            @inline (n::ExaModels.Node2{typeof($f),I1,I2})(i, x, θ) where {I1<:Real,I2} = $f(n.inner1, n.inner2(i, x, θ))
-            @inline (n::ExaModels.Node2{typeof($f),I1,I2})(i, x, θ) where {I1,I2<:Real} = $f(n.inner1(i, x, θ), n.inner2)
+            @inline (n::ExaModels.Node2{typeof($f), I1, I2})(i, x, θ) where {I1, I2} = $f(n.inner1(i, x, θ), n.inner2(i, x, θ))
+            @inline (n::ExaModels.Node2{typeof($f), I1, I2})(i, x, θ) where {I1 <: Real, I2} = $f(n.inner1, n.inner2(i, x, θ))
+            @inline (n::ExaModels.Node2{typeof($f), I1, I2})(i, x, θ) where {I1, I2 <: Real} = $f(n.inner1(i, x, θ), n.inner2)
        end,
    )
end
diff --git a/src/simdfunction.jl b/src/simdfunction.jl
index 801a16a..b8f1a66 100644
--- a/src/simdfunction.jl
+++ b/src/simdfunction.jl
@@ -25,7 +25,7 @@ struct SIMDFunction{F,C1,C2}
end

@inline (sf::SIMDFunction{F,C1,C2})(i, x, θ) where {F,C1,C2} = sf.f(i, x, θ)
-@inline (sf::SIMDFunction{F,C1,C2})(i, x, θ) where {F<:Real,C1,C2} = sf.f
+@inline (sf::SIMDFunction{F, C1, C2})(i, x, θ) where {F <: Real, C1, C2} = sf.f

"""
    SIMDFunction(gen::Base.Generator, o0 = 0, o1 = 0, o2 = 0)
@@ -38,12 +38,12 @@ Returns a `SIMDFunction` using the `gen`.
- `o1`: offset for the derivative evalution
- `o2`: offset for the second-order derivative evalution
"""
-function SIMDFunction(gen::Base.Generator, offset_exps, isexp, o0=0, o1=0, o2=0)
+function SIMDFunction(gen::Base.Generator, offset_exps, isexp, o0 = 0, o1 = 0, o2 = 0)
    f = gen.f(ParSource())
-    _simdfunction(f, offset_exps, isexp, o0, o1, o2)
+    return _simdfunction(f, offset_exps, isexp, o0, o1, o2)
end

-function _simdfunction(f::F, exps, isexp, o0, o1, o2) where {F<:Real}
+function _simdfunction(f::F, exps, isexp, o0, o1, o2) where {F <: Real}
    SIMDFunction(
        f,
        ExaModels.Compressor{Tuple{}}(()),
@@ -73,7 +73,7 @@ function compress_ref_cnts(y, a)
        end
    end
    push!(ret, cnt)
-    ret
+    return ret
end

function _simdfunction(f, offset_exps, isexp, o0, o1, o2)
diff --git a/test/ADTest/expression.jl b/test/ADTest/expression.jl
index ca7bc47..9efe9e0 100644
--- a/test/ADTest/expression.jl
+++ b/test/ADTest/expression.jl
@@ -1,6 +1,6 @@

function test_expression()
-    @testset "AD Expression Tests" begin
+    return @testset "AD Expression Tests" begin
        @testset "Basic tests" begin
            m = ExaCore()
            v = variable(m, 5)
@@ -31,13 +31,13 @@ function test_expression()

            # Test Hessian values (objective only)
            hess_buffer = zeros(mod.meta.nnzh)
-            hess_coord!(mod, x, hess_buffer; obj_weight=1.0)
+            hess_coord!(mod, x, hess_buffer; obj_weight = 1.0)
            @test all(isfinite, hess_buffer)

            # Test Hessian values (with constraints)
            y = ones(mod.meta.ncon)
            hess_buffer2 = zeros(mod.meta.nnzh)
-            hess_coord!(mod, x, y, hess_buffer2; obj_weight=1.0)
+            hess_coord!(mod, x, y, hess_buffer2; obj_weight = 1.0)
            @test all(isfinite, hess_buffer2)
        end

@@ -48,7 +48,7 @@ function test_expression()
            m = ExaCore()
            v = variable(m, 1)
            o = objective(m, v[1]^2)
-            c = constraint(m, v[1]^2; lcon=1.0, ucon=1.0)
+            c = constraint(m, v[1]^2; lcon = 1.0, ucon = 1.0)
            mod = ExaModel(m)

            x = [3.0]  # arbitrary point
@@ -60,13 +60,13 @@ function test_expression()

            # Objective Hessian only
            hess_obj = zeros(mod.meta.nnzh)
-            hess_coord!(mod, x, hess_obj; obj_weight=1.0)
+            hess_coord!(mod, x, hess_obj; obj_weight = 1.0)
            @test any(h ≈ 2.0 for h in hess_obj)

            # Full Hessian (obj + constraints)
            y = [1.0]  # constraint multiplier
            hess_full = zeros(mod.meta.nnzh)
-            hess_coord!(mod, x, y, hess_full; obj_weight=1.0)
+            hess_coord!(mod, x, y, hess_full; obj_weight = 1.0)
            @test sum(hess_full) ≈ 4.0
        end

@@ -76,7 +76,7 @@ function test_expression()
            v = variable(m, 2)
            e1 = subexpr(m, (1,), v[1] * v[2] for _ in 1:1)  # e = x*y
            o = objective(m, e1[1] for _ in 1:1)  # f = e = x*y
-            c = constraint(m, e1[1] for _ in 1:1; lcon=1.0, ucon=1.0)  # c = x*y = 1
+            c = constraint(m, e1[1] for _ in 1:1; lcon = 1.0, ucon = 1.0)  # c = x*y = 1
            mod = ExaModel(m)

            x = zeros(mod.meta.nvar)
@@ -89,13 +89,13 @@ function test_expression()

            # Objective Hessian only
            hess_obj = zeros(mod.meta.nnzh)
-            hess_coord!(mod, x, hess_obj; obj_weight=1.0)
+            hess_coord!(mod, x, hess_obj; obj_weight = 1.0)
            @test any(h ≈ 1.0 for h in hess_obj)

            # Full Hessian
            y = [1.0]
            hess_full = zeros(mod.meta.nnzh)
-            hess_coord!(mod, x, y, hess_full; obj_weight=1.0)
+            hess_coord!(mod, x, y, hess_full; obj_weight = 1.0)
            @test any(h ≈ 2.0 for h in hess_full) || (sum(hess_full) ≈ 2.0)
        end

@@ -118,7 +118,7 @@ function test_expression()

            # Objective Hessian only
            hess_obj = zeros(mod.meta.nnzh)
-            hess_coord!(mod, x, hess_obj; obj_weight=1.0)
+            hess_coord!(mod, x, hess_obj; obj_weight = 1.0)
            # Sum of hessian entries should be 2
            @test sum(hess_obj) ≈ 2.0
        end
diff --git a/test/NLPTest/NLPTest.jl b/test/NLPTest/NLPTest.jl
index 76445ca..27c9eac 100644
--- a/test/NLPTest/NLPTest.jl
+++ b/test/NLPTest/NLPTest.jl
@@ -69,9 +69,9 @@ function test_nlp((m1, varis1), (m2, varis2); full = false)
        v = randn(eltype(m1.meta.x0), m1.meta.ncon)
        u2 = length(x02) == length(x01) ? u : u[varis2]

-        @test NLPModels.obj(m1, x01) ≈ NLPModels.obj(m2, x02) atol = 1e-6
-        @test NLPModels.cons(m1, x01) ≈ NLPModels.cons(m2, x02) atol = 1e-6
-        @test NLPModels.grad(m1,...*[Comment body truncated]*

@amontoison
Copy link
Copy Markdown
Member

@hfytr The diff is very big and hard to review. Can you update the PR to only keep the modifications related to expression support?

@michel2323
Copy link
Copy Markdown
Member

@hfytr I think you used runic on all files. Let me know if you need help with this.

@hfytr
Copy link
Copy Markdown
Collaborator Author

hfytr commented Feb 18, 2026

Hello michel. That’s correct. I’m traveling, but expect code by Tuesday / Wednesday.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants