Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit c20400a

Browse files
committed
Add tests for recursive updates of f in JacVec etc.
1 parent 59b5c76 commit c20400a

1 file changed

Lines changed: 59 additions & 23 deletions

File tree

test/test_jaches_products.jl

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,62 @@
11
using SparseDiffTools, ForwardDiff, FiniteDiff, Zygote, IterativeSolvers
22
using LinearAlgebra, Test
3+
import SciMLOperators: update_coefficients, update_coefficients!
34

45
using Random
56
Random.seed!(123)
67
N = 300
78
const A = rand(N, N)
8-
f(y, x) = mul!(y, A, x)
9-
f(x) = A * x
9+
10+
_f(y, x) = mul!(y, A, x)
11+
_f(x) = A * x
12+
1013
x = rand(N)
1114
v = rand(N)
1215
a, b = rand(2)
1316
dy = similar(x)
14-
g(x) = sum(abs2, x)
15-
function h(x)
16-
FiniteDiff.finite_difference_gradient(g, x)
17+
_g(x) = sum(abs2, x)
18+
function _h(x)
19+
FiniteDiff.finite_difference_gradient(_g, x)
20+
end
21+
function _h(dy, x)
22+
FiniteDiff.finite_difference_gradient!(dy, _g, x)
23+
end
24+
25+
# Define state-dependent (i.e. dependent on u/p/t) functions for tests of operators
26+
27+
mutable struct WrapFunc{F,U,P,T}
28+
func::F
29+
u::U
30+
p::P
31+
t::T
32+
end
33+
34+
(w::WrapFunc)(u) = sum(w.u) * w.p * w.t * w.func(u)
35+
function (w::WrapFunc)(v, u)
36+
w.func(v, u)
37+
lmul!(sum(w.u) * w.p * w.t, v)
38+
end
39+
40+
update_coefficients(w::WrapFunc, u, p, t) = WrapFunc(w.func, u, p, t)
41+
function update_coefficients!(w::WrapFunc, u, p, t)
42+
w.u = u
43+
w.p = p
44+
w.t = t
1745
end
18-
function h(dy, x)
19-
FiniteDiff.finite_difference_gradient!(dy, g, x)
46+
47+
# Helper function for testing correct update coefficients behaviour of operators
48+
function update_coefficients_for_test!(L, u, p, t)
49+
update_coefficients!(L, u, p, t)
50+
# Force function hiding inside L to update. Should be a null-op if previous line works correctly
51+
update_coefficients!(L.op.f, u, p, t)
2052
end
2153

54+
f = WrapFunc(_f, ones(N) * 2, 1.0, 1.0)
55+
g = WrapFunc(_g, ones(N), 1.0, 1.0)
56+
h = WrapFunc(_h, ones(N), 1.0, 1.0)
57+
58+
###
59+
2260
cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))),
2361
eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v)))
2462
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v)))
@@ -67,21 +105,21 @@ cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x)
67105

68106
@info "JacVec"
69107

70-
L = JacVec(f, x)
108+
L = JacVec(f, x, 1.0, 1.0)
71109
@test L * x auto_jacvec(f, x, x)
72110
@test L * v auto_jacvec(f, x, v)
73111
@test mul!(dy, L, v) auto_jacvec(f, x, v)
74112
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
75-
update_coefficients!(L, v, nothing, 0.0)
113+
update_coefficients_for_test!(L, v, 3.0, 4.0)
76114
@test mul!(dy, L, v) auto_jacvec(f, v, v)
77115
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
78116

79-
L = JacVec(f, x, autodiff = AutoFiniteDiff())
117+
L = JacVec(f, x, 1.0, 1.0; autodiff = AutoFiniteDiff())
80118
@test L * x num_jacvec(f, x, x)
81119
@test L * v num_jacvec(f, x, v)
82120
@test mul!(dy, L, v)num_jacvec(f, x, v) rtol=1e-6
83121
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
84-
update_coefficients!(L, v, nothing, 0.0)
122+
update_coefficients_for_test!(L, v, 3.0, 4.0)
85123
@test mul!(dy, L, v)num_jacvec(f, v, v) rtol=1e-6
86124
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
87125

@@ -92,38 +130,36 @@ gmres!(out, L, v)
92130

93131
x = rand(N)
94132
v = rand(N)
95-
L = HesVec(g, x, autodiff = AutoFiniteDiff())
133+
L = HesVec(g, x, 1.0, 1.0, autodiff = AutoFiniteDiff())
96134
@test L * x num_hesvec(g, x, x)
97135
@test L * v num_hesvec(g, x, v)
98136
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
99137
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
100-
update_coefficients!(L, v, nothing, 0.0)
138+
update_coefficients_for_test!(L, v, 3.0, 4.0)
101139
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
102140
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
103141

104-
L = HesVec(g, x)
142+
L = HesVec(g, x, 1.0, 1.0)
105143
@test L * x numauto_hesvec(g, x, x)
106144
@test L * v numauto_hesvec(g, x, v)
107145
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
108146
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
109-
update_coefficients!(L, v, nothing, 0.0)
147+
update_coefficients_for_test!(L, v, 3.0, 4.0)
110148
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
111149
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
112150

113151
out = similar(v)
114152
gmres!(out, L, v)
115153

116-
using Zygote
117-
118154
x = rand(N)
119155
v = rand(N)
120156

121-
L = HesVec(g, x, autodiff = AutoZygote())
157+
L = HesVec(g, x, 1.0, 1.0; autodiff = AutoZygote())
122158
@test L * x autoback_hesvec(g, x, x)
123159
@test L * v autoback_hesvec(g, x, v)
124160
@test mul!(dy, L, v)autoback_hesvec(g, x, v) rtol=1e-8
125161
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
126-
update_coefficients!(L, v, nothing, 0.0)
162+
update_coefficients_for_test!(L, v, 3.0, 4.0)
127163
@test mul!(dy, L, v)autoback_hesvec(g, v, v) rtol=1e-8
128164
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
129165

@@ -134,21 +170,21 @@ gmres!(out, L, v)
134170

135171
x = rand(N)
136172
v = rand(N)
137-
L = HesVecGrad(h, x, autodiff = AutoFiniteDiff())
173+
L = HesVecGrad(h, x, 1.0, 1.0; autodiff = AutoFiniteDiff())
138174
@test L * x num_hesvec(g, x, x) rtol=1e-2
139175
@test L * v num_hesvec(g, x, v) rtol=1e-2
140176
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
141177
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
142-
update_coefficients!(L, v, nothing, 0.0)
178+
update_coefficients_for_test!(L, v, 3.0, 4.0)
143179
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
144180
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
145181

146-
L = HesVecGrad(h, x)
182+
L = HesVecGrad(h, x, 1.0, 1.0)
147183
@test L * x autonum_hesvec(g, x, x)
148184
@test L * v numauto_hesvec(g, x, v)
149185
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
150186
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
151-
update_coefficients!(L, v, nothing, 0.0)
187+
update_coefficients_for_test!(L, v, 3.0, 4.0)
152188
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
153189
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
154190

0 commit comments

Comments
 (0)