Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 62cdc52

Browse files
committed
add tests for auto-auto hessians
1 parent 91dc8ad commit 62cdc52

2 files changed

Lines changed: 55 additions & 28 deletions

File tree

src/differentiation/compute_hessian_ad.jl

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct ForwardColorHesCache{THS,THC,TI<:Integer,TD,TGF,TGC,TG}
1+
struct ForwardColorHesCache{THS, THC, TI<:Integer, TD, TGF, TGC, TG}
22
sparsity::THS
33
colors::THC
44
ncolors::TI
@@ -16,40 +16,40 @@ function make_hessian_buffers(colorvec, x)
1616
buffer = similar(D)
1717
G1 = similar(x)
1818
G2 = similar(x)
19-
return (; ncolors, D, buffer, G1, G2)
19+
return (;ncolors, D, buffer, G1, G2)
2020
end
2121

22-
function ForwardColorHesCache(f,
23-
x::AbstractVector{<:Number},
24-
colorvec::AbstractVector{<:Integer}=eachindex(x),
25-
sparsity::Union{AbstractMatrix,Nothing}=nothing,
26-
(g!)=(G, x, grad_config) -> ForwardDiff.gradient!(G, f, x, grad_config))
22+
function ForwardColorHesCache(f,
23+
x::AbstractVector{<:Number},
24+
colorvec::AbstractVector{<:Integer}=eachindex(x),
25+
sparsity::Union{AbstractMatrix, Nothing}=nothing,
26+
g! = (G, x, grad_config) -> ForwardDiff.gradient!(G, f, x, grad_config))
2727
ncolors, D, buffer, G, G2 = make_hessian_buffers(colorvec, x)
2828
grad_config = ForwardDiff.GradientConfig(f, x)
29-
29+
3030
# If user supplied their own gradient function, make sure it has the right
3131
# signature (i.e. g!(G, x) or g!(G, x, grad_config::ForwardDiff.GradientConfig))
32-
if !hasmethod(g!, (typeof(G), typeof(G), typeof(grad_config)))
33-
if !hasmethod(g!, (typeof(G), typeof(G)))
32+
if ! hasmethod(g!, (typeof(G), typeof(G), typeof(grad_config)))
33+
if ! hasmethod(g!, (typeof(G), typeof(G)))
3434
throw(ArgumentError("Signature of `g!` must be either `g!(G, x)` or `g!(G, x, grad_config::ForwardDiff.GradientConfig)`"))
3535
end
3636
# define new method that takes a GradientConfig but doesn't use it
3737
g1!(G, x, grad_config) = g!(G, x)
3838
else
3939
g1! = g!
4040
end
41-
41+
4242
if sparsity === nothing
4343
sparsity = sparse(ones(length(x), length(x)))
4444
end
4545
return ForwardColorHesCache(sparsity, colorvec, ncolors, D, buffer, g1!, grad_config, G, G2)
4646
end
4747

48-
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
49-
f,
50-
x::AbstractArray{<:Number},
51-
hes_cache::ForwardColorHesCache;
52-
safe=true)
48+
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
49+
f,
50+
x::AbstractArray{<:Number},
51+
hes_cache::ForwardColorHesCache;
52+
safe = true)
5353
ϵ = cbrt(eps(eltype(x)))
5454
for j in 1:hes_cache.ncolors
5555
x .+= ϵ .* @view hes_cache.D[:, j]
@@ -69,28 +69,28 @@ function numauto_color_hessian!(H::AbstractMatrix{<:Number},
6969
return H
7070
end
7171

72-
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
73-
f,
74-
x::AbstractArray{<:Number},
75-
colorvec::AbstractVector{<:Integer}=eachindex(x),
76-
sparsity::Union{AbstractMatrix,Nothing}=nothing)
72+
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
73+
f,
74+
x::AbstractArray{<:Number},
75+
colorvec::AbstractVector{<:Integer}=eachindex(x),
76+
sparsity::Union{AbstractMatrix, Nothing}=nothing)
7777
hes_cache = ForwardColorHesCache(f, x, colorvec, sparsity)
7878
numauto_color_hessian!(H, f, x, hes_cache)
7979
return H
8080
end
8181

82-
function numauto_color_hessian(f,
83-
x::AbstractArray{<:Number},
84-
hes_cache::ForwardColorHesCache)
82+
function numauto_color_hessian(f,
83+
x::AbstractArray{<:Number},
84+
hes_cache::ForwardColorHesCache)
8585
H = convert.(eltype(x), hes_cache.sparsity)
8686
numauto_color_hessian!(H, f, x, hes_cache)
8787
return H
8888
end
8989

9090
function numauto_color_hessian(f,
91-
x::AbstractArray{<:Number},
92-
colorvec::AbstractVector{<:Integer}=eachindex(x),
93-
sparsity::Union{AbstractMatrix,Nothing}=nothing)
91+
x::AbstractArray{<:Number},
92+
colorvec::AbstractVector{<:Integer}=eachindex(x),
93+
sparsity::Union{AbstractMatrix, Nothing}=nothing)
9494
hes_cache = ForwardColorHesCache(f, x, colorvec, sparsity)
9595
H = convert.(eltype(x), hes_cache.sparsity)
9696
numauto_color_hessian!(H, f, x, hes_cache)
@@ -103,7 +103,7 @@ end
103103

104104
mutable struct ForwardAutoColorHesCache{TJC,TG,TS,TC}
105105
jac_cache::TJC
106-
grad!TG
106+
grad!::TG
107107
sparsity::TS
108108
colorvec::TC
109109
end

test/test_sparse_hessian.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,30 @@ for (i, hescache) in enumerate([hescache1, hescache2, hescache3, hescache4, hesc
9898
# for _ in 1:100)
9999
# @test t_unsafe <= t_safe
100100
end
101+
102+
103+
hescache1 = ForwardAutoColorHesCache(fscalar, x, colors, sparsity)
104+
hescache2 = ForwardAutoColorHesCache(fscalar, x)
105+
106+
107+
for (i, hescache) in enumerate([hescache1, hescache2])
108+
109+
H = SparseDiffTools.autoauto_color_hessian(fscalar, x, colors, sparsity)
110+
H1 = SparseDiffTools.autoauto_color_hessian(fscalar, x, hescache)
111+
H2 = SparseDiffTools.autoauto_color_hessian(fscalar, x)
112+
@test all(isapprox.(Hforward, H, rtol=1e-6))
113+
@test all(isapprox.(H, H1, rtol=1e-6))
114+
@test all(isapprox.(H2, H1, rtol=1e-6))
115+
116+
H1 = similar(H)
117+
118+
SparseDiffTools.autoauto_color_hessian!(H1, fscalar, x, collect(hescache.colorvec), hescache.sparsity)
119+
@test all(isapprox.(H1, H))
120+
121+
SparseDiffTools.autoauto_color_hessian!(H2, fscalar, x)
122+
@test all(isapprox.(H2, H))
123+
124+
SparseDiffTools.autoauto_color_hessian!(H1, fscalar, x, hescache)
125+
@test all(isapprox.(H1, H))
126+
127+
end

0 commit comments

Comments
 (0)