Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 9c3b77d

Browse files
committed
add auto-auto hessian
1 parent 108df97 commit 9c3b77d

1 file changed

Lines changed: 100 additions & 27 deletions

File tree

Lines changed: 100 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct ForwardColorHesCache{THS, THC, TI<:Integer, TD, TGF, TGC, TG}
1+
struct ForwardColorHesCache{THS,THC,TI<:Integer,TD,TGF,TGC,TG}
22
sparsity::THS
33
colors::THC
44
ncolors::TI
@@ -16,40 +16,40 @@ function make_hessian_buffers(colorvec, x)
1616
buffer = similar(D)
1717
G1 = similar(x)
1818
G2 = similar(x)
19-
return (;ncolors, D, buffer, G1, G2)
19+
return (; ncolors, D, buffer, G1, G2)
2020
end
2121

22-
function ForwardColorHesCache(f,
23-
x::AbstractVector{<:Number},
24-
colorvec::AbstractVector{<:Integer}=eachindex(x),
25-
sparsity::Union{AbstractMatrix, Nothing}=nothing,
26-
g! = (G, x, grad_config) -> ForwardDiff.gradient!(G, f, x, grad_config))
22+
function ForwardColorHesCache(f,
23+
x::AbstractVector{<:Number},
24+
colorvec::AbstractVector{<:Integer}=eachindex(x),
25+
sparsity::Union{AbstractMatrix,Nothing}=nothing,
26+
(g!)=(G, x, grad_config) -> ForwardDiff.gradient!(G, f, x, grad_config))
2727
ncolors, D, buffer, G, G2 = make_hessian_buffers(colorvec, x)
2828
grad_config = ForwardDiff.GradientConfig(f, x)
29-
29+
3030
# If user supplied their own gradient function, make sure it has the right
3131
# signature (i.e. g!(G, x) or g!(G, x, grad_config::ForwardDiff.GradientConfig))
32-
if ! hasmethod(g!, (typeof(G), typeof(G), typeof(grad_config)))
33-
if ! hasmethod(g!, (typeof(G), typeof(G)))
32+
if !hasmethod(g!, (typeof(G), typeof(G), typeof(grad_config)))
33+
if !hasmethod(g!, (typeof(G), typeof(G)))
3434
throw(ArgumentError("Signature of `g!` must be either `g!(G, x)` or `g!(G, x, grad_config::ForwardDiff.GradientConfig)`"))
3535
end
3636
# define new method that takes a GradientConfig but doesn't use it
3737
g1!(G, x, grad_config) = g!(G, x)
3838
else
3939
g1! = g!
4040
end
41-
41+
4242
if sparsity === nothing
4343
sparsity = sparse(ones(length(x), length(x)))
4444
end
4545
return ForwardColorHesCache(sparsity, colorvec, ncolors, D, buffer, g1!, grad_config, G, G2)
4646
end
4747

48-
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
49-
f,
50-
x::AbstractArray{<:Number},
51-
hes_cache::ForwardColorHesCache;
52-
safe = true)
48+
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
49+
f,
50+
x::AbstractArray{<:Number},
51+
hes_cache::ForwardColorHesCache;
52+
safe=true)
5353
ϵ = cbrt(eps(eltype(x)))
5454
for j in 1:hes_cache.ncolors
5555
x .+= ϵ .* @view hes_cache.D[:, j]
@@ -69,30 +69,103 @@ function numauto_color_hessian!(H::AbstractMatrix{<:Number},
6969
return H
7070
end
7171

72-
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
73-
f,
74-
x::AbstractArray{<:Number},
75-
colorvec::AbstractVector{<:Integer}=eachindex(x),
76-
sparsity::Union{AbstractMatrix, Nothing}=nothing)
72+
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
73+
f,
74+
x::AbstractArray{<:Number},
75+
colorvec::AbstractVector{<:Integer}=eachindex(x),
76+
sparsity::Union{AbstractMatrix,Nothing}=nothing)
7777
hes_cache = ForwardColorHesCache(f, x, colorvec, sparsity)
7878
numauto_color_hessian!(H, f, x, hes_cache)
7979
return H
8080
end
8181

82-
function numauto_color_hessian(f,
83-
x::AbstractArray{<:Number},
84-
hes_cache::ForwardColorHesCache)
82+
function numauto_color_hessian(f,
83+
x::AbstractArray{<:Number},
84+
hes_cache::ForwardColorHesCache)
8585
H = convert.(eltype(x), hes_cache.sparsity)
8686
numauto_color_hessian!(H, f, x, hes_cache)
8787
return H
8888
end
8989

9090
function numauto_color_hessian(f,
91-
x::AbstractArray{<:Number},
92-
colorvec::AbstractVector{<:Integer}=eachindex(x),
93-
sparsity::Union{AbstractMatrix, Nothing}=nothing)
91+
x::AbstractArray{<:Number},
92+
colorvec::AbstractVector{<:Integer}=eachindex(x),
93+
sparsity::Union{AbstractMatrix,Nothing}=nothing)
9494
hes_cache = ForwardColorHesCache(f, x, colorvec, sparsity)
9595
H = convert.(eltype(x), hes_cache.sparsity)
9696
numauto_color_hessian!(H, f, x, hes_cache)
9797
return H
9898
end
99+
100+
101+
102+
## autoauto_color_hessian
103+
104+
mutable struct ForwardAutoColorHesCache{TS,TC}
105+
jac_cache::Any
106+
grad!::Any
107+
sparsity::TS
108+
colorvec::TC
109+
end
110+
111+
function ForwardAutoColorHesCache(f,
112+
x::AbstractVector{<:Number},
113+
colorvec::AbstractVector{<:Integer}=eachindex(x),
114+
sparsity::Union{AbstractMatrix,Nothing}=nothing)
115+
116+
if sparsity === nothing
117+
sparsity = sparse(ones(length(x), length(x)))
118+
end
119+
120+
jac_cache = nothing
121+
g! = nothing
122+
123+
return ForwardAutoColorHesCache(jac_cache, g!, sparsity, colorvec)
124+
end
125+
126+
function autoauto_color_hessian!(H::AbstractMatrix{<:Number},
127+
f,
128+
x::AbstractArray{<:Number},
129+
hes_cache::ForwardAutoColorHesCache)
130+
131+
if hes_cache.jac_cache === nothing
132+
grad_config = nothing
133+
g! = function (G, x)
134+
if grad_config === nothing
135+
grad_config = ForwardDiff.GradientConfig(f, x)
136+
end
137+
ForwardDiff.gradient!(G, f, x, grad_config)
138+
end
139+
hes_cache.grad! = g!
140+
hes_cache.jac_cache = ForwardColorJacCache(hes_cache.grad!, x; hes_cache.colorvec, hes_cache.sparsity)
141+
end
142+
forwarddiff_color_jacobian!(H, hes_cache.grad!, x, hes_cache.jac_cache)
143+
end
144+
145+
function autoauto_color_hessian!(H::AbstractMatrix{<:Number},
146+
f,
147+
x::AbstractArray{<:Number},
148+
colorvec::AbstractVector{<:Integer}=eachindex(x),
149+
sparsity::Union{AbstractMatrix,Nothing}=nothing)
150+
hes_cache = ForwardAutoColorHesCache(f, x, colorvec, sparsity)
151+
autoauto_color_hessian!(H, f, x, hes_cache)
152+
return H
153+
end
154+
155+
function autoauto_color_hessian(f,
156+
x::AbstractArray{<:Number},
157+
hes_cache::ForwardColorHesCache)
158+
H = convert.(eltype(x), hes_cache.sparsity)
159+
autoauto_color_hessian!(H, f, x, hes_cache)
160+
return H
161+
end
162+
163+
function autoauto_color_hessian(f,
164+
x::AbstractArray{<:Number},
165+
colorvec::AbstractVector{<:Integer}=eachindex(x),
166+
sparsity::Union{AbstractMatrix,Nothing}=nothing)
167+
hes_cache = ForwardAutoColorHesCache(f, x, colorvec, sparsity)
168+
H = convert.(eltype(x), hes_cache.sparsity)
169+
autoauto_color_hessian!(H, f, x, hes_cache)
170+
return H
171+
end

0 commit comments

Comments
 (0)