@@ -101,24 +101,33 @@ end
101101
102102# # autoauto_color_hessian
103103
104- mutable struct ForwardAutoColorHesCache{TS,TC}
105- jac_cache:: Any
106- grad!:: Any
104+ mutable struct ForwardAutoColorHesCache{TJC,TG, TS,TC}
105+ jac_cache:: TJC
106+ grad!TG
107107 sparsity:: TS
108108 colorvec:: TC
109109end
110110
111111function ForwardAutoColorHesCache (f,
112- x:: AbstractVector{<:Number } ,
112+ x:: AbstractVector{V } ,
113113 colorvec:: AbstractVector{<:Integer} = eachindex (x),
114- sparsity:: Union{AbstractMatrix,Nothing} = nothing )
114+ sparsity:: Union{AbstractMatrix,Nothing} = nothing ) where V
115115
116116 if sparsity === nothing
117117 sparsity = sparse (ones (length (x), length (x)))
118118 end
119119
120- jac_cache = nothing
121- g! = nothing
120+ tag = ForwardDiff. Tag (f, V)
121+ chunksize = ForwardDiff. pickchunksize (maximum (colorvec))
122+ chunk = ForwardDiff. Chunk (chunksize)
123+
124+ jacobian_config = ForwardDiff. JacobianConfig (f, x, chunk, tag)
125+ gradient_config = ForwardDiff. GradientConfig (f, jacobian_config. duals, chunk, tag)
126+
127+ outer_tag = get_tag (jacobian_config. duals)
128+ g! = (G, x) -> ForwardDiff. gradient! (G, f, x, gradient_config, Val (false ))
129+
130+ jac_cache = ForwardColorJacCache (g!, x; colorvec, sparsity, tag= outer_tag)
122131
123132 return ForwardAutoColorHesCache (jac_cache, g!, sparsity, colorvec)
124133end
@@ -128,17 +137,6 @@ function autoauto_color_hessian!(H::AbstractMatrix{<:Number},
128137 x:: AbstractArray{<:Number} ,
129138 hes_cache:: ForwardAutoColorHesCache )
130139
131- if hes_cache. jac_cache === nothing
132- grad_config = nothing
133- g! = function (G, x)
134- if grad_config === nothing
135- grad_config = ForwardDiff. GradientConfig (f, x)
136- end
137- ForwardDiff. gradient! (G, f, x, grad_config)
138- end
139- hes_cache. grad! = g!
140- hes_cache. jac_cache = ForwardColorJacCache (hes_cache. grad!, x; hes_cache. colorvec, hes_cache. sparsity)
141- end
142140 forwarddiff_color_jacobian! (H, hes_cache. grad!, x, hes_cache. jac_cache)
143141end
144142
154152
155153function autoauto_color_hessian (f,
156154 x:: AbstractArray{<:Number} ,
157- hes_cache:: ForwardColorHesCache )
155+ hes_cache:: ForwardAutoColorHesCache )
158156 H = convert .(eltype (x), hes_cache. sparsity)
159157 autoauto_color_hessian! (H, f, x, hes_cache)
160158 return H
@@ -168,4 +166,4 @@ function autoauto_color_hessian(f,
168166 H = convert .(eltype (x), hes_cache. sparsity)
169167 autoauto_color_hessian! (H, f, x, hes_cache)
170168 return H
171- end
169+ end
0 commit comments