Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/ext/ExaModelsMOI.jl b/ext/ExaModelsMOI.jl
index 174b8f4..cb10a32 100644
--- a/ext/ExaModelsMOI.jl
+++ b/ext/ExaModelsMOI.jl
@@ -315,8 +315,8 @@ function exafy_con(
set = MOI.get(moim, MOI.ConstraintSet(), ci)
con_to_idx[ci] = offset + i
start = if MOI.supports(
- moim, MOI.ConstraintPrimalStart(), typeof(ci)
- )
+ moim, MOI.ConstraintPrimalStart(), typeof(ci)
+ )
MOI.get(moim, MOI.ConstraintPrimalStart(), ci)
else
nothing
diff --git a/src/gradient.jl b/src/gradient.jl
index 03986e0..e124a56 100644
--- a/src/gradient.jl
+++ b/src/gradient.jl
@@ -8,25 +8,25 @@ Performs dense gradient evaluation via the reverse pass on the computation (sub)
- `y`: result vector
- `adj`: adjoint propagated up to the current node
"""
-@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNull}
+@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D <: AdjointNull}
nothing
end
-@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNode1}
+@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D <: AdjointNode1}
drpass(e, e_starts, e_cnts, d.inner, y, adj * d.y)
nothing
end
-@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNode2}
+@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D <: AdjointNode2}
drpass(e, e_starts, e_cnts, d.inner1, y, adj * d.y1)
drpass(e, e_starts, e_cnts, d.inner2, y, adj * d.y2)
nothing
end
-@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNodeVar}
+@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D <: AdjointNodeVar}
@inbounds y[d.i] += adj
nothing
end
-@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D<:AdjointNodeExpr}
+@inline function drpass(e, e_starts, e_cnts, d::D, y, adj) where {D <: AdjointNodeExpr}
y[d.i] += e[e_starts[d.i][2]]
- nothing
+ return nothing
end
"""
diff --git a/src/graph.jl b/src/graph.jl
index d19f56e..8b469b9 100644
--- a/src/graph.jl
+++ b/src/graph.jl
@@ -134,7 +134,7 @@ struct Identity end
@inline (v::Var{I})(i, x, θ) where {I<:AbstractNode} = @inbounds x[v.i(i, x, θ)]
@inline (v::Var{I})(i, x, θ) where {I} = @inbounds x[v.i]
-@inline (e::Exp{I})(i, x, θ) where {I<:AbstractNode} = @inbounds x[e.i(i, x, θ)]
+@inline (e::Exp{I})(i, x, θ) where {I <: AbstractNode} = @inbounds x[e.i(i, x, θ)]
@inline (e::Exp{I})(i, x, θ) where {I} = @inbounds x[e.i]
@inline (v::ParameterNode{I})(i, x, θ) where {I<:AbstractNode} = @inbounds θ[v.i(i, x, θ)]
@@ -200,7 +200,7 @@ struct AdjointNodeVar{I,T} <: AbstractAdjointNode
x::T
end
-struct AdjointNodeExpr{I,T} <: AbstractAdjointNode
+struct AdjointNodeExpr{I, T} <: AbstractAdjointNode
i::I
x::T
end
@@ -213,7 +213,7 @@ A source of `AdjointNode`. `adjoint_node_source[i]` returns an `AdjointNodeVar`
# Fields:
- `inner::VT`: variable vector
"""
-struct AdjointNodeSource{VT,OE}
+struct AdjointNodeSource{VT, OE}
inner::VT
offset_exps::OE
end
@@ -223,23 +223,23 @@ end
@inline AdjointNode2(f::F, x::T, y1, y2, inner1::I1, inner2::I2) where {F,T,I1,I2} =
AdjointNode2{F,T,I1,I2}(x, y1, y2, inner1, inner2)
-@inline function Base.getindex(x::I, i) where {I<:AdjointNodeSource{Nothing,Nothing}}
+@inline function Base.getindex(x::I, i) where {I <: AdjointNodeSource{Nothing, Nothing}}
(i, isexp, theta) = i
- @inbounds isexp ? AdjointNodeExpr(i, NaN) : AdjointNodeVar(i, NaN)
+ return @inbounds isexp ? AdjointNodeExpr(i, NaN) : AdjointNodeVar(i, NaN)
end
-@inline function Base.getindex(x::I, i) where {I<:AdjointNodeSource{Nothing}}
+@inline function Base.getindex(x::I, i) where {I <: AdjointNodeSource{Nothing}}
(i, isexp, theta) = i
- if isexp
+ return if isexp
dump(i)
- offset = typeof(i) <: Node2{typeof(+),T,Int} where T ? i.inner2 : 0
+ offset = typeof(i) <: Node2{typeof(+), T, Int} where {T} ? i.inner2 : 0
x.offset_exps[offset].f.f(Identity(), x, theta, i)
else
AdjointNodeVar(i, NaN)
end
end
-@inline function Base.getindex(x::I, i) where {I<:AdjointNodeSource}
+@inline function Base.getindex(x::I, i) where {I <: AdjointNodeSource}
(i, isexp, theta) = i
- @inbounds isexp ? AdjointNodeExpr(i, x.inner[i]) : AdjointNodeVar(i, x.inner[i])
+ return @inbounds isexp ? AdjointNodeExpr(i, x.inner[i]) : AdjointNodeVar(i, x.inner[i])
end
"""
@@ -299,7 +299,7 @@ struct SecondAdjointNodeVar{I,T} <: AbstractSecondAdjointNode
x::T
end
-struct SecondAdjointNodeExpr{I,T} <: AbstractSecondAdjointNode
+struct SecondAdjointNodeExpr{I, T} <: AbstractSecondAdjointNode
i::I
x::T
end
@@ -313,7 +313,7 @@ A source of `AdjointNode`. `adjoint_node_source[i]` returns an `AdjointNodeVar`
- `inner::VT`: variable vector
- 'isexp::VTI': expression vector
"""
-struct SecondAdjointNodeSource{VT,OE}
+struct SecondAdjointNodeSource{VT, OE}
inner::VT
offset_exps::OE
end
@@ -333,22 +333,22 @@ end
) where {F,T,I1,I2} =
SecondAdjointNode2{F,T,I1,I2}(x, y1, y2, h11, h12, h22, inner1, inner2)
-@inline function Base.getindex(x::I, i) where {I<:SecondAdjointNodeSource{Nothing,Nothing}}
+@inline function Base.getindex(x::I, i) where {I <: SecondAdjointNodeSource{Nothing, Nothing}}
(i, isexp, theta) = i
- @inbounds isexp ? SecondAdjointNodeExpr(i, NaN) : SecondAdjointNodeVar(i, NaN)
+ return @inbounds isexp ? SecondAdjointNodeExpr(i, NaN) : SecondAdjointNodeVar(i, NaN)
end
-@inline function Base.getindex(x::I, i) where {I<:SecondAdjointNodeSource{Nothing}}
+@inline function Base.getindex(x::I, i) where {I <: SecondAdjointNodeSource{Nothing}}
(i, isexp, theta) = i
- if isexp
- offset = typeof(i) <: Node2{typeof(+),T,Int} where T ? i.inner2 : 0
+ return if isexp
+ offset = typeof(i) <: Node2{typeof(+), T, Int} where {T} ? i.inner2 : 0
x.offset_exps[offset].f.f(Identity(), x, theta, i)
else
SecondAdjointNodeVar(i, NaN)
end
end
-@inline function Base.getindex(x::I, i) where {I<:SecondAdjointNodeSource}
+@inline function Base.getindex(x::I, i) where {I <: SecondAdjointNodeSource}
(i, isexp, theta) = i
- @inbounds isexp ? SecondAdjointNodeExpr(i, x.inner[i]) : SecondAdjointNodeVar(i, x.inner[i])
+ return @inbounds isexp ? SecondAdjointNodeExpr(i, x.inner[i]) : SecondAdjointNodeVar(i, x.inner[i])
end
@inline (v::Null{Nothing})(i, x::V, θ) where {T,V<:AbstractVector{T}} = zero(T)
@@ -356,29 +356,29 @@ end
@inline (v::Null{N})(i, x::AdjointNodeSource{T}, θ) where {N,T} = AdjointNull()
@inline (v::Null{N})(i, x::SecondAdjointNodeSource{T}, θ) where {N,T} = SecondAdjointNull()
-const NodeSource = Union{AdjointNodeSource,SecondAdjointNodeSource}
+const NodeSource = Union{AdjointNodeSource, SecondAdjointNodeSource}
-@inline (v::Var{I})(i, x::AdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds AdjointNodeVar(v.i(i, x, θ), NaN)
-@inline (v::Var{I})(i, x::SecondAdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds SecondAdjointNodeVar(v.i(i, x, θ), NaN)
+@inline (v::Var{I})(i, x::AdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds AdjointNodeVar(v.i(i, x, θ), NaN)
+@inline (v::Var{I})(i, x::SecondAdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds SecondAdjointNodeVar(v.i(i, x, θ), NaN)
@inline (v::Var{I})(i, x::AdjointNodeSource, θ) where {I} = @inbounds AdjointNodeVar(i, NaN)
@inline (v::Var{I})(i, x::SecondAdjointNodeSource, θ) where {I} = @inbounds SecondAdjointNodeVar(i, NaN)
-@inline (v::Var{I})(i::Identity, x::AdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds AdjointNodeVar(v.i, NaN)
-@inline (v::Var{I})(i::Identity, x::SecondAdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds SecondAdjointNodeVar(v.i, NaN)
+@inline (v::Var{I})(i::Identity, x::AdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds AdjointNodeVar(v.i, NaN)
+@inline (v::Var{I})(i::Identity, x::SecondAdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds SecondAdjointNodeVar(v.i, NaN)
-@inline (v::Exp{I})(i, x::AdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds AdjointNodeExpr(v.i(i, x, θ), NaN)
-@inline (v::Exp{I})(i, x::SecondAdjointNodeSource, θ) where {I<:AbstractNode} = @inbounds SecondAdjointNodeExpr(v.i(i, x, θ), NaN)
+@inline (v::Exp{I})(i, x::AdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds AdjointNodeExpr(v.i(i, x, θ), NaN)
+@inline (v::Exp{I})(i, x::SecondAdjointNodeSource, θ) where {I <: AbstractNode} = @inbounds SecondAdjointNodeExpr(v.i(i, x, θ), NaN)
-@inline function (e::Exp{I})(i::Identity, x::AdjointNodeSource, θ) where {I<:AbstractNode}
- offset = typeof(e.i) <: Node2{typeof(+),T,Int} where T ? e.i.inner2 : 0
- x.offset_exps[offset].f.f(e.i(i, x, θ), x, θ)
+@inline function (e::Exp{I})(i::Identity, x::AdjointNodeSource, θ) where {I <: AbstractNode}
+ offset = typeof(e.i) <: Node2{typeof(+), T, Int} where {T} ? e.i.inner2 : 0
+ return x.offset_exps[offset].f.f(e.i(i, x, θ), x, θ)
end
-@inline function (e::Exp{I})(i::Identity, x::SecondAdjointNodeSource, θ) where {I<:AbstractNode}
- offset = typeof(e.i) <: Node2{typeof(+),T,Int} where T ? e.i.inner2 : 0
- x.offset_exps[offset].f.f(e.i(i, x, θ), x, θ)
+@inline function (e::Exp{I})(i::Identity, x::SecondAdjointNodeSource, θ) where {I <: AbstractNode}
+ offset = typeof(e.i) <: Node2{typeof(+), T, Int} where {T} ? e.i.inner2 : 0
+ return x.offset_exps[offset].f.f(e.i(i, x, θ), x, θ)
end
-@inline (v::Exp{I})(i, x::X, θ) where {I, X<:AdjointNodeSource} = @inbounds AdjointNodeExpr(i, NaN)
-@inline (v::Exp{I})(i, x::X, θ) where {I, X<:SecondAdjointNodeSource} = @inbounds SecondAdjointNodeExpr(i, NaN)
+@inline (v::Exp{I})(i, x::X, θ) where {I, X <: AdjointNodeSource} = @inbounds AdjointNodeExpr(i, NaN)
+@inline (v::Exp{I})(i, x::X, θ) where {I, X <: SecondAdjointNodeSource} = @inbounds SecondAdjointNodeExpr(i, NaN)
diff --git a/src/hessian.jl b/src/hessian.jl
index 396beee..d76c97b 100644
--- a/src/hessian.jl
+++ b/src/hessian.jl
@@ -18,9 +18,9 @@ Performs sparse hessian evaluation (`(df1/dx)(df2/dx)'` portion) via the reverse
- `adj`: second adjoint propagated up to the current node
"""
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -34,9 +34,9 @@ Performs sparse hessian evaluation (`(df1/dx)(df2/dx)'` portion) via the reverse
cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::SecondAdjointNode1,
t2::SecondAdjointNode1,
comp::Nothing,
@@ -52,9 +52,9 @@ end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -63,14 +63,14 @@ end
o2,
cnt,
adj,
-) where {T1<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr},T2<:SecondAdjointNode1}
+ ) where {T1 <: Union{SecondAdjointNodeVar, SecondAdjointNodeExpr}, T2 <: SecondAdjointNode1}
cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner, comp, y1, y2, o2, cnt, adj * t2.y)
cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::SecondAdjointNodeVar,
t2::SecondAdjointNode1,
comp::Nothing,
@@ -81,29 +81,29 @@ function hdrpass(
adj,
) # despecialized
cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner, comp, y1, y2, o2, cnt, adj * t2.y)
- cnt
+ return cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
- t1::SecondAdjointNodeExpr,
- t2::SecondAdjointNode1,
- comp::Nothing,
- y1,
- y2,
- o2,
- cnt,
- adj,
-) # despecialized
+ e,
+ e_starts,
+ e_cnts,
+ t1::SecondAdjointNodeExpr,
+ t2::SecondAdjointNode1,
+ comp::Nothing,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ ) # despecialized
cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner, comp, y1, y2, o2, cnt, adj * t2.y)
cnt
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -112,14 +112,14 @@ end
o2,
cnt,
adj,
-) where {T1<:SecondAdjointNode1,T2<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr}}
+ ) where {T1 <: SecondAdjointNode1, T2 <: Union{SecondAdjointNodeVar, SecondAdjointNodeExpr}}
cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2, comp, y1, y2, o2, cnt, adj * t1.y)
cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::SecondAdjointNode1,
t2::SecondAdjointNodeVar,
comp::Nothing,
@@ -130,30 +130,30 @@ function hdrpass(
adj,
) # despecialized
cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2, comp, y1, y2, o2, cnt, adj * t1.y)
- cnt
+ return cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
- t1::SecondAdjointNode1,
- t2::SecondAdjointNodeExpr,
- comp::Nothing,
- y1,
- y2,
- o2,
- cnt,
- adj,
-) # despecialized
+ e,
+ e_starts,
+ e_cnts,
+ t1::SecondAdjointNode1,
+ t2::SecondAdjointNodeExpr,
+ comp::Nothing,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ ) # despecialized
cnt = hdrpass(e, e_starts, e_cnts, t1.inner, t2, comp, y1, y2, o2, cnt, adj * t1.y)
cnt
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -170,9 +170,9 @@ end
cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::SecondAdjointNode2,
t2::SecondAdjointNode2,
comp::Nothing,
@@ -191,9 +191,9 @@ end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -208,9 +208,9 @@ end
cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::SecondAdjointNode1,
t2::SecondAdjointNode2,
comp::Nothing,
@@ -226,9 +226,9 @@ function hdrpass(
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -243,9 +243,9 @@ end
cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::SecondAdjointNode2,
t2::SecondAdjointNode1,
comp::Nothing,
@@ -261,9 +261,9 @@ function hdrpass(
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -272,15 +272,15 @@ end
o2,
cnt,
adj,
-) where {T1<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr},T2<:SecondAdjointNode2}
+ ) where {T1 <: Union{SecondAdjointNodeVar, SecondAdjointNodeExpr}, T2 <: SecondAdjointNode2}
cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner1, comp, y1, y2, o2, cnt, adj * t2.y1)
cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner2, comp, y1, y2, o2, cnt, adj * t2.y2)
cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::SecondAdjointNodeVar,
t2::SecondAdjointNode2,
comp::Nothing,
@@ -292,30 +292,30 @@ function hdrpass(
) # despecialized
cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner1, comp, y1, y2, o2, cnt, adj * t2.y1)
cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner2, comp, y1, y2, o2, cnt, adj * t2.y2)
- cnt
+ return cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
- t1::SecondAdjointNodeExpr,
- t2::SecondAdjointNode2,
- comp::Nothing,
- y1,
- y2,
- o2,
- cnt,
- adj,
-) # despecialized
+ e,
+ e_starts,
+ e_cnts,
+ t1::SecondAdjointNodeExpr,
+ t2::SecondAdjointNode2,
+ comp::Nothing,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ ) # despecialized
cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner1, comp, y1, y2, o2, cnt, adj * t2.y1)
cnt = hdrpass(e, e_starts, e_cnts, t1, t2.inner2, comp, y1, y2, o2, cnt, adj * t2.y2)
cnt
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -324,15 +324,15 @@ end
o2,
cnt,
adj,
-) where {T1<:SecondAdjointNode2,T2<:Union{SecondAdjointNodeVar,SecondAdjointNodeExpr}}
+ ) where {T1 <: SecondAdjointNode2, T2 <: Union{SecondAdjointNodeVar, SecondAdjointNodeExpr}}
cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2, comp, y1, y2, o2, cnt, adj * t1.y1)
cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2, comp, y1, y2, o2, cnt, adj * t1.y2)
cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::SecondAdjointNode2,
t2::SecondAdjointNodeVar,
comp::Nothing,
@@ -344,30 +344,30 @@ function hdrpass(
) # despecialized
cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2, comp, y1, y2, o2, cnt, adj * t1.y1)
cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2, comp, y1, y2, o2, cnt, adj * t1.y2)
- cnt
+ return cnt
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
- t1::SecondAdjointNode2,
- t2::SecondAdjointNodeExpr,
- comp::Nothing,
- y1,
- y2,
- o2,
- cnt,
- adj,
-) # despecialized
+ e,
+ e_starts,
+ e_cnts,
+ t1::SecondAdjointNode2,
+ t2::SecondAdjointNodeExpr,
+ comp::Nothing,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ ) # despecialized
cnt = hdrpass(e, e_starts, e_cnts, t1.inner1, t2, comp, y1, y2, o2, cnt, adj * t1.y1)
cnt = hdrpass(e, e_starts, e_cnts, t1.inner2, t2, comp, y1, y2, o2, cnt, adj * t1.y2)
cnt
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -379,72 +379,72 @@ end
) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeVar}
i, j = t1.i, t2.i
@inbounds if i == j
- y1[o2+comp(cnt += 1)] += 2 * adj
+ y1[o2 + comp(cnt += 1)] += 2 * adj
else
- y1[o2+comp(cnt += 1)] += adj
+ y1[o2 + comp(cnt += 1)] += adj
end
cnt
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
- t1::T1,
- t2::T2,
- comp,
- y1,
- y2,
- o2,
- cnt,
- adj,
-) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeVar}
+ e,
+ e_starts,
+ e_cnts,
+ t1::T1,
+ t2::T2,
+ comp,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ ) where {T1 <: SecondAdjointNodeExpr, T2 <: SecondAdjointNodeVar}
(cnt_start, e_start) = e_starts[t1.i]
len = e_cnts[cnt_start]
cnt += 1
for i in 1:len
- @inbounds y1[o2+comp(cnt)] += e[e_start+i-1] * adj
- cnt += e_cnts[cnt_start+i]
+ @inbounds y1[o2 + comp(cnt)] += e[e_start + i - 1] * adj
+ cnt += e_cnts[cnt_start + i]
end
return cnt
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
- t1::T1,
- t2::T2,
- comp,
- y1,
- y2,
- o2,
- cnt,
- adj,
-) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeExpr}
+ e,
+ e_starts,
+ e_cnts,
+ t1::T1,
+ t2::T2,
+ comp,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ ) where {T1 <: SecondAdjointNodeVar, T2 <: SecondAdjointNodeExpr}
(cnt_start, e_start) = e_starts[t2.i]
len = e_cnts[cnt_start]
cnt += 1
for i in 1:len
- @inbounds y1[o2+comp(cnt)] += e[e_start+i-1] * adj
- cnt += e_cnts[cnt_start+i]
+ @inbounds y1[o2 + comp(cnt)] += e[e_start + i - 1] * adj
+ cnt += e_cnts[cnt_start + i]
end
return cnt
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
- t1::T1,
- t2::T2,
- comp,
- y1,
- y2,
- o2,
- cnt,
- adj,
-) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeExpr}
+ e,
+ e_starts,
+ e_cnts,
+ t1::T1,
+ t2::T2,
+ comp,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ ) where {T1 <: SecondAdjointNodeExpr, T2 <: SecondAdjointNodeExpr}
(cnt_start1, e_start1) = e_starts[t1.i]
len1 = e_cnts[cnt_start1]
(cnt_start2, e_start2) = e_starts[t2.i]
@@ -452,27 +452,27 @@ end
cnt += 1
for i in 1:len1
- val1 = e[e_start1+i-1]
+ val1 = e[e_start1 + i - 1]
for j in 1:len2
- val2 = e[e_start2+j-1]
+ val2 = e[e_start2 + j - 1]
ind = o2 + comp(cnt)
@inbounds if t1.i == t2.i && i == j
y1[ind] += 2 * val1 * val2 * adj
else
y1[ind] += val1 * val2 * adj
end
- cnt += e_cnts[cnt_start2+j]
+ cnt += e_cnts[cnt_start2 + j]
end
- cnt += e_cnts[cnt_start1+i]
+ cnt += e_cnts[cnt_start1 + i]
end
return cnt
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -521,12 +521,12 @@ Performs sparse hessian evaluation (`d²f/dx²` portion) via the reverse pass on
"""
@inline function hrpass(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::D,
comp,
y1,
@@ -540,12 +540,12 @@ Performs sparse hessian evaluation (`d²f/dx²` portion) via the reverse pass on
end
@inline function hrpass(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::D,
comp,
y1,
@@ -554,39 +554,39 @@ end
cnt,
adj,
adj2,
-) where {D<:SecondAdjointNodeExpr}
+ ) where {D <: SecondAdjointNodeExpr}
(cnt_start2, e_start2) = e2_starts[t.i]
len2 = e2_cnts[cnt_start2]
cnt += 1
for i in 1:len2
- @inbounds y1[o2+comp(cnt)] += adj * e2[e_start2+i-1]
- cnt += e2_cnts[cnt_start2+i]
+ @inbounds y1[o2 + comp(cnt)] += adj * e2[e_start2 + i - 1]
+ cnt += e2_cnts[cnt_start2 + i]
end
return cnt
end
@inline function hrpass(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
- t::D,
- comp,
- y1::V,
- y2::V,
- o2,
- cnt,
- adj,
- adj2,
-) where {D<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
+ t::D,
+ comp,
+ y1::V,
+ y2::V,
+ o2,
+ cnt,
+ adj,
+ adj2,
+ ) where {D <: SecondAdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
(cnt_start2, e_start2) = e2_starts[t.i]
len2 = e2_cnts[cnt_start2]
cnt += 1
for i in 1:len2
ind = o2 + comp(cnt)
- val = e2[e_start2+i-1]
+ val = e2[e_start2 + i - 1]
r = unpack_row(val)
c = unpack_col(val)
if y1 === y2
@@ -599,38 +599,38 @@ end
@inbounds y2[ind] = c
end
end
- cnt += e2_cnts[cnt_start2+i]
+ cnt += e2_cnts[cnt_start2 + i]
end
return cnt
end
@inline function hrpass(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
- t::D,
- comp,
- y1,
- y2,
- o2,
- cnt,
- adj,
- adj2,
-) where {D<:SecondAdjointNode1}
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
+ t::D,
+ comp,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ adj2,
+ ) where {D <: SecondAdjointNode1}
cnt = hrpass(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner, comp, y1, y2, o2, cnt, adj * t.y, adj2 * (t.y)^2 + adj * t.h)
cnt
end
@inline function hrpass(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::D,
comp,
y1,
@@ -639,7 +639,7 @@ end
cnt,
adj,
adj2,
-) where {D<:SecondAdjointNode2}
+ ) where {D <: SecondAdjointNode2}
adj2y1y2 = adj2 * t.y1 * t.y2
adjh12 = adj * t.h12
cnt = hrpass(e, e_starts, e_cnts, e2, e2_starts, e2_cnts, t.inner1, comp, y1, y2, o2, cnt, adj * t.y1, adj2 * (t.y1)^2 + adj * t.h11)
@@ -651,12 +651,12 @@ end
@inline hrpass0(args...) = hrpass(args...)
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::D,
comp,
y1,
@@ -671,12 +671,12 @@ end
end
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::D,
comp,
y1,
@@ -691,12 +691,12 @@ end
end
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::D,
comp,
y1,
@@ -711,12 +711,12 @@ end
end
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::D,
comp,
y1,
@@ -731,12 +731,12 @@ end
end
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::D,
comp,
y1,
@@ -751,12 +751,12 @@ end
end
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::D,
comp,
y1,
@@ -771,12 +771,12 @@ end
end
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::D,
comp,
y1,
@@ -792,12 +792,12 @@ end
end
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::D,
comp,
y1,
@@ -813,12 +813,12 @@ end
end
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::T,
comp,
y1,
@@ -832,12 +832,12 @@ end
end
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::T,
comp::Nothing,
y1,
@@ -851,27 +851,27 @@ end
end
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
- t::T,
- comp,
- y1,
- y2,
- o2,
- cnt,
- adj,
- adj2,
-) where {T<:SecondAdjointNodeExpr}
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
+ t::T,
+ comp,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ adj2,
+ ) where {T <: SecondAdjointNodeExpr}
(cnt_start2, e_start2) = e2_starts[t.i]
len2 = e2_cnts[cnt_start2]
cnt += 1
for i in 1:len2
- @inbounds y1[o2+comp(cnt)] += adj * e2[e_start2+i-1]
- cnt += e2_cnts[cnt_start2+i]
+ @inbounds y1[o2 + comp(cnt)] += adj * e2[e_start2 + i - 1]
+ cnt += e2_cnts[cnt_start2 + i]
end
@@ -879,51 +879,51 @@ end
end
function hdrpass(
- e,
- e_starts,
- e_cnts,
- t1::T1,
- t2::T2,
+ e,
+ e_starts,
+ e_cnts,
+ t1::T1,
+ t2::T2,
comp::Nothing,
y1,
y2,
o2,
cnt,
adj,
-) where {T1<:SecondAdjointNodeVar, T2<:SecondAdjointNodeVar}
+ ) where {T1 <: SecondAdjointNodeVar, T2 <: SecondAdjointNodeVar}
cnt += 1
push!(y1, (t1.i, t2.i))
cnt
end
function hrpass(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
- t::T,
- comp::Nothing,
- y1,
- y2,
- o2,
- cnt,
- adj,
- adj2
-) where {T<:SecondAdjointNodeVar}
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
+ t::T,
+ comp::Nothing,
+ y1,
+ y2,
+ o2,
+ cnt,
+ adj,
+ adj2
+ ) where {T <: SecondAdjointNodeVar}
cnt += 1
push!(y1, (t.i, t.i))
cnt
end
@inline function hrpass(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::T,
comp,
y1::Tuple{V1,V2},
@@ -939,12 +939,12 @@ end
end
@inline function hrpass(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::T,
comp,
y1,
@@ -954,17 +954,17 @@ end
adj,
adj2,
) where {T<:SecondAdjointNodeVar}
- @inbounds y1[o2+comp(cnt += 1)] += adj2
+ @inbounds y1[o2 + comp(cnt += 1)] += adj2
cnt
end
@inline function hrpass(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::T,
comp,
y1::V,
@@ -989,12 +989,12 @@ end
end
@inline function hrpass(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
t::T,
comp,
y1::V,
@@ -1014,9 +1014,9 @@ end
@inline unpack_col(v) = Int(v & 0xFFFFFFFF)
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -1052,9 +1052,9 @@ end
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
t1::T1,
t2::T2,
comp,
@@ -1080,27 +1080,27 @@ end
end
@inline function hrpass0(
- e,
- e_starts,
- e_cnts,
- e2,
- e2_starts,
- e2_cnts,
- t::T,
- comp,
- y1::V,
- y2::V,
- o2,
- cnt,
- adj,
- adj2,
-) where {T<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+ e,
+ e_starts,
+ e_cnts,
+ e2,
+ e2_starts,
+ e2_cnts,
+ t::T,
+ comp,
+ y1::V,
+ y2::V,
+ o2,
+ cnt,
+ adj,
+ adj2,
+ ) where {T <: SecondAdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
(cnt_start2, e_start2) = e2_starts[t.i]
len2 = e2_cnts[cnt_start2]
cnt += 1
for i in 1:len2
ind = o2 + comp(cnt)
- val = e2[e_start2+i-1]
+ val = e2[e_start2 + i - 1]
r = unpack_row(val)
c = unpack_col(val)
if y1 === y2
@@ -1113,31 +1113,31 @@ end
@inbounds y2[ind] = c
end
end
- cnt += e2_cnts[cnt_start2+i]
+ cnt += e2_cnts[cnt_start2 + i]
end
return cnt
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
- t1::T1,
- t2::T2,
- comp,
- y1::V,
- y2::V,
- o2,
- cnt,
- adj,
-) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeVar,I<:Integer,V<:AbstractVector{I}}
+ e,
+ e_starts,
+ e_cnts,
+ t1::T1,
+ t2::T2,
+ comp,
+ y1::V,
+ y2::V,
+ o2,
+ cnt,
+ adj,
+ ) where {T1 <: SecondAdjointNodeExpr, T2 <: SecondAdjointNodeVar, I <: Integer, V <: AbstractVector{I}}
(cnt_start, e_start) = e_starts[t1.i]
len = e_cnts[cnt_start]
j = t2.i
cnt += 1
for i in 1:len
ind = o2 + comp(cnt)
- idx = e[e_start+i-1]
+ idx = e[e_start + i - 1]
if y1 === y2
if idx != 0 || j != 0
@inbounds if idx >= j
@@ -1157,31 +1157,31 @@ end
end
end
end
- cnt += e_cnts[cnt_start+i]
+ cnt += e_cnts[cnt_start + i]
end
return cnt
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
- t1::T1,
- t2::T2,
- comp,
- y1::V,
- y2::V,
- o2,
- cnt,
- adj,
-) where {T1<:SecondAdjointNodeVar,T2<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+ e,
+ e_starts,
+ e_cnts,
+ t1::T1,
+ t2::T2,
+ comp,
+ y1::V,
+ y2::V,
+ o2,
+ cnt,
+ adj,
+ ) where {T1 <: SecondAdjointNodeVar, T2 <: SecondAdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
i = t1.i
(cnt_start, e_start) = e_starts[t2.i]
len = e_cnts[cnt_start]
cnt += 1
for k in 1:len
ind = o2 + comp(cnt)
- idx = e[e_start+k-1]
+ idx = e[e_start + k - 1]
if y1 === y2
if i != 0 || idx != 0
@inbounds if i >= idx
@@ -1201,24 +1201,24 @@ end
end
end
end
- cnt += e_cnts[cnt_start+k]
+ cnt += e_cnts[cnt_start + k]
end
return cnt
end
@inline function hdrpass(
- e,
- e_starts,
- e_cnts,
- t1::T1,
- t2::T2,
- comp,
- y1::V,
- y2::V,
- o2,
- cnt,
- adj,
-) where {T1<:SecondAdjointNodeExpr,T2<:SecondAdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+ e,
+ e_starts,
+ e_cnts,
+ t1::T1,
+ t2::T2,
+ comp,
+ y1::V,
+ y2::V,
+ o2,
+ cnt,
+ adj,
+ ) where {T1 <: SecondAdjointNodeExpr, T2 <: SecondAdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
(cnt_start1, e_start1) = e_starts[t1.i]
len1 = e_cnts[cnt_start1]
(cnt_start2, e_start2) = e_starts[t2.i]
@@ -1226,9 +1226,9 @@ end
cnt += 1
for i in 1:len1
- idx1 = e[e_start1+i-1]
+ idx1 = e[e_start1 + i - 1]
for j in 1:len2
- idx2 = e[e_start2+j-1]
+ idx2 = e[e_start2 + j - 1]
ind = o2 + comp(cnt)
if y1 === y2
if idx1 != 0 || idx2 != 0
@@ -1249,9 +1249,9 @@ end
end
end
end
- cnt += e_cnts[cnt_start2+j]
+ cnt += e_cnts[cnt_start2 + j]
end
- cnt += e_cnts[cnt_start1+i]
+ cnt += e_cnts[cnt_start1 + i]
end
return cnt
end
@@ -1297,7 +1297,7 @@ function shessian!(y1, y2, f, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_c
end
end
-function shessian!(y1, y2, f, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, adj1s::V, adj2, isexp) where {V<:AbstractVector}
+function shessian!(y1, y2, f, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, adj1s::V, adj2, isexp) where {V <: AbstractVector}
@simd for k in eachindex(f.itr)
@inbounds shessian!(
y1,
@@ -1319,5 +1319,5 @@ end
function shessian!(y1, y2, f, p, x, θ, e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, comp, o2, adj1, adj2, isexp)
graph = f(p, SecondAdjointNodeSource(x, nothing), θ)
- hrpass0(e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, graph, comp, y1, y2, o2, 0, adj1, adj2)
+ return hrpass0(e1, e1_starts, e1_cnts, e2, e2_starts, e2_cnts, graph, comp, y1, y2, o2, 0, adj1, adj2)
end
diff --git a/src/jacobian.jl b/src/jacobian.jl
index 21ab3d0..41f8df5 100644
--- a/src/jacobian.jl
+++ b/src/jacobian.jl
@@ -15,9 +15,9 @@ Performs sparse jacobian evaluation via the reverse pass on the computation (sub
"""
@inline function jrpass(
d::D,
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
comp,
i,
y1,
@@ -28,26 +28,26 @@ Performs sparse jacobian evaluation via the reverse pass on the computation (sub
) where {D<:Union{AdjointNull,Real}}
return cnt
end
-@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNode1}
+@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D <: AdjointNode1}
cnt = jrpass(d.inner, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj * d.y)
return cnt
end
-@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNode2}
+@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D <: AdjointNode2}
cnt = jrpass(d.inner1, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj * d.y1)
cnt = jrpass(d.inner2, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj * d.y2)
return cnt
end
# jac_coord
-@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNodeVar}
+@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D <: AdjointNodeVar}
@inbounds y1[o1+comp(cnt+=1)] += adj
return cnt
end
-@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D<:AdjointNodeExpr}
+@inline function jrpass(d::D, e, e_starts, e_cnts, comp, o0, y1, y2, o1, cnt, adj) where {D <: AdjointNodeExpr}
(cnt_start, e_start) = e_starts[d.i]
len = e_cnts[cnt_start]
cnt += 1
for i in 1:len
- @inbounds y1[o1+comp(cnt)] += adj * e[e_start + i - 1]
+ @inbounds y1[o1 + comp(cnt)] += adj * e[e_start + i - 1]
cnt += e_cnts[cnt_start + i]
end
return cnt
@@ -55,13 +55,13 @@ end
# jprod_nln
@inline function jrpass(
d::D,
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
comp,
- o0,
+ o0,
y1::Tuple{V1,V2},
- y2::Nothing,
+ y2::Nothing,
o1,
cnt,
adj,
@@ -74,12 +74,12 @@ end
# jtprod_nln
@inline function jrpass(
d::D,
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
comp,
- o0,
- y1::Nothing,
+ o0,
+ y1::Nothing,
y2::Tuple{V1,V2},
o1,
cnt,
@@ -93,11 +93,11 @@ end
# jac_structure
@inline function jrpass(
d::D,
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
comp,
- o0,
+ o0,
y1::V,
y2::V,
o1,
@@ -111,17 +111,17 @@ end
end
@inline function jrpass(
d::D,
- e,
- e_starts,
- e_cnts,
+ e,
+ e_starts,
+ e_cnts,
comp,
- o0,
- y1::V,
- y2::V,
- o1,
- cnt,
- adj,
-) where {D<:AdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+ o0,
+ y1::V,
+ y2::V,
+ o1,
+ cnt,
+ adj,
+ ) where {D <: AdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
(cnt_start, e_start) = e_starts[d.i]
len = e_cnts[cnt_start]
cnt += 1
@@ -135,35 +135,35 @@ end
end
# no rows when precomputing expressions
@inline function jrpass(
- d::D,
- e,
- e_starts,
- e_cnts,
- comp,
- o0,
- y1::Nothing,
- y2::V,
- o1,
- cnt,
- adj,
-) where {D<:AdjointNodeVar,I<:Integer,V<:AbstractVector{I}}
+ d::D,
+ e,
+ e_starts,
+ e_cnts,
+ comp,
+ o0,
+ y1::Nothing,
+ y2::V,
+ o1,
+ cnt,
+ adj,
+ ) where {D <: AdjointNodeVar, I <: Integer, V <: AbstractVector{I}}
ind = o1 + comp(cnt += 1)
@inbounds y2[ind] = d.i
return cnt
end
@inline function jrpass(
- d::D,
- e,
- e_starts,
- e_cnts,
- comp,
- o0,
- y1::Nothing,
- y2::V,
- o1,
- cnt,
- adj,
-) where {D<:AdjointNodeExpr,I<:Integer,V<:AbstractVector{I}}
+ d::D,
+ e,
+ e_starts,
+ e_cnts,
+ comp,
+ o0,
+ y1::Nothing,
+ y2::V,
+ o1,
+ cnt,
+ adj,
+ ) where {D <: AdjointNodeExpr, I <: Integer, V <: AbstractVector{I}}
(cnt_start, e_start) = e_starts[d.i]
len = e_cnts[cnt_start]
cnt += 1
@@ -175,12 +175,12 @@ end
return cnt
end
@inline function jrpass(
- d::D,
- e,
- e_starts,
- e_cnts,
- comp,
- o0,
+ d::D,
+ e,
+ e_starts,
+ e_cnts,
+ comp,
+ o0,
y1::V,
y2,
o1,
@@ -228,5 +228,5 @@ end
function sjacobian!(isexp, y1, y2, f, e, e_starts, e_cnts, p, x, θ, comp, o0, o1, adj)
s = AdjointNodeSource(x, nothing)
graph = f(p, s, θ)
- jrpass(graph, e, e_starts, e_cnts, comp, o0, y1, y2, o1, 0, adj)
+ return jrpass(graph, e, e_starts, e_cnts, comp, o0, y1, y2, o1, 0, adj)
end
diff --git a/src/nlp.jl b/src/nlp.jl
index b20e8d7..3034733 100644
--- a/src/nlp.jl
+++ b/src/nlp.jl
@@ -54,7 +54,7 @@ Objective
)
-struct Expression{R,F,I,O,S} <: AbstractExpression
+struct Expression{R, F, I, O, S} <: AbstractExpression
inner::R
f::F
itr::I
@@ -64,13 +64,13 @@ end
Base.show(io::IO, v::Expression) = print(
io,
"""
-Expression
+ Expression
- s.t. (...)
- g♭ ≤ [g(x,θ,p)]_{p ∈ P} ≤ g♯
+ s.t. (...)
+ g♭ ≤ [g(x,θ,p)]_{p ∈ P} ≤ g♯
- where |P| = $(length(v.itr))
-""",
+ where |P| = $(length(v.itr))
+ """,
)
@@ -92,7 +92,7 @@ Constraint
""",
)
-struct ExpressionAug{R,F,I} <: AbstractConstraint
+struct ExpressionAug{R, F, I} <: AbstractConstraint
inner::R
f::F
itr::I
@@ -161,21 +161,21 @@ An ExaCore"""
Deprecated as of v0.7function ExaCore(::Type{T}, backend) where {T<:AbstractFloat}
-ExaCore(::Type{T}; backend=nothing, kwargs...) where {T<:AbstractFloat} =
depth(a) = depth(a.inner) + 1 -struct ExaModel{T,VT,VI,E,O,C,EX,VII} <: NLPModels.AbstractNLPModel{T,VT} |
|
@hfytr The diff is very big and hard to review. Can you update the PR to only keep the modifications related to expression support? |
|
@hfytr I think you used runic on all files. Let me know if you need help with this. |
|
Hello michel. That’s correct. I’m traveling, but expect code by Tuesday / Wednesday. |
This PR changes expressions to no longer re-evaluate on every reference, and instead pre-computes once for every call to jac_coord / hess_coord / etc.