Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 91dc8ad

Browse files
committed
preinitialize caches for auto-auto hessians
1 parent 9c3b77d commit 91dc8ad

2 files changed

Lines changed: 21 additions & 20 deletions

File tree

src/SparseDiffTools.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ export contract_color,
3030
ForwardColorJacCache,
3131
numauto_color_hessian!,
3232
numauto_color_hessian,
33+
autoauto_color_hessian!,
34+
autoauto_color_hessian,
3335
ForwardColorHesCache,
36+
ForwardAutoColorHesCache,
3437
auto_jacvec,auto_jacvec!,
3538
num_jacvec,num_jacvec!,
3639
num_vecjac,num_vecjac!,

src/differentiation/compute_hessian_ad.jl

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -101,24 +101,33 @@ end
101101

102102
## autoauto_color_hessian
103103

104-
mutable struct ForwardAutoColorHesCache{TS,TC}
105-
jac_cache::Any
106-
grad!::Any
104+
mutable struct ForwardAutoColorHesCache{TJC,TG,TS,TC}
105+
jac_cache::TJC
106+
grad!TG
107107
sparsity::TS
108108
colorvec::TC
109109
end
110110

111111
function ForwardAutoColorHesCache(f,
112-
x::AbstractVector{<:Number},
112+
x::AbstractVector{V},
113113
colorvec::AbstractVector{<:Integer}=eachindex(x),
114-
sparsity::Union{AbstractMatrix,Nothing}=nothing)
114+
sparsity::Union{AbstractMatrix,Nothing}=nothing) where V
115115

116116
if sparsity === nothing
117117
sparsity = sparse(ones(length(x), length(x)))
118118
end
119119

120-
jac_cache = nothing
121-
g! = nothing
120+
tag = ForwardDiff.Tag(f, V)
121+
chunksize = ForwardDiff.pickchunksize(maximum(colorvec))
122+
chunk = ForwardDiff.Chunk(chunksize)
123+
124+
jacobian_config = ForwardDiff.JacobianConfig(f, x, chunk, tag)
125+
gradient_config = ForwardDiff.GradientConfig(f, jacobian_config.duals, chunk, tag)
126+
127+
outer_tag = get_tag(jacobian_config.duals)
128+
g! = (G, x) -> ForwardDiff.gradient!(G, f, x, gradient_config, Val(false))
129+
130+
jac_cache = ForwardColorJacCache(g!, x; colorvec, sparsity, tag=outer_tag)
122131

123132
return ForwardAutoColorHesCache(jac_cache, g!, sparsity, colorvec)
124133
end
@@ -128,17 +137,6 @@ function autoauto_color_hessian!(H::AbstractMatrix{<:Number},
128137
x::AbstractArray{<:Number},
129138
hes_cache::ForwardAutoColorHesCache)
130139

131-
if hes_cache.jac_cache === nothing
132-
grad_config = nothing
133-
g! = function (G, x)
134-
if grad_config === nothing
135-
grad_config = ForwardDiff.GradientConfig(f, x)
136-
end
137-
ForwardDiff.gradient!(G, f, x, grad_config)
138-
end
139-
hes_cache.grad! = g!
140-
hes_cache.jac_cache = ForwardColorJacCache(hes_cache.grad!, x; hes_cache.colorvec, hes_cache.sparsity)
141-
end
142140
forwarddiff_color_jacobian!(H, hes_cache.grad!, x, hes_cache.jac_cache)
143141
end
144142

@@ -154,7 +152,7 @@ end
154152

155153
function autoauto_color_hessian(f,
156154
x::AbstractArray{<:Number},
157-
hes_cache::ForwardColorHesCache)
155+
hes_cache::ForwardAutoColorHesCache)
158156
H = convert.(eltype(x), hes_cache.sparsity)
159157
autoauto_color_hessian!(H, f, x, hes_cache)
160158
return H
@@ -168,4 +166,4 @@ function autoauto_color_hessian(f,
168166
H = convert.(eltype(x), hes_cache.sparsity)
169167
autoauto_color_hessian!(H, f, x, hes_cache)
170168
return H
171-
end
169+
end

0 commit comments

Comments
 (0)