@@ -6,12 +6,15 @@ Random.seed!(123)
66N = 300
77
88# Use Float32 since Zygote defaults to Float32
9- x = rand (Float32, N)
9+ x1 = rand (Float32, N)
10+ x2 = rand (Float32, N)
11+
1012v = rand (Float32, N)
1113
1214# Save original values of x and v to make sure they are not ever mutated
13- x0 = copy (x)
14- v0 = copy (v)
15+ _x1 = copy (x1)
16+ _x2 = copy (x2)
17+ _v = copy (v)
1518
1619a, b = rand (Float32, 2 )
1720
@@ -23,81 +26,106 @@ _f(x) = A * (x .^ 2)
2326include (" update_coeffs_testutils.jl" )
2427f = WrapFunc (_f, 1.0f0 , 1.0f0 )
2528
26- @test auto_vecjac (f, x , v) ≈ Zygote. jacobian (f, x )[1 ]' * v
27- @test auto_vecjac! (zero (x ), f, x , v) ≈ auto_vecjac (f, x , v)
28- @test num_vecjac! (zero (x ), f, copy (x ), v) ≈ num_vecjac (f, copy (x ), v)
29- @test auto_vecjac (f, x , v) ≈ num_vecjac (f, copy (x ), copy (v)) rtol = 1e-2
29+ @test auto_vecjac (f, x1 , v) ≈ Zygote. jacobian (f, x1 )[1 ]' * v
30+ @test auto_vecjac! (zero (x1 ), f, x1 , v) ≈ auto_vecjac (f, x1 , v)
31+ @test num_vecjac! (zero (x1 ), f, copy (x1 ), v) ≈ num_vecjac (f, copy (x1 ), v)
32+ @test auto_vecjac (f, x1 , v) ≈ num_vecjac (f, copy (x1 ), copy (v)) rtol = 1e-2
3033
3134# Compute Jacobian via Zygote
3235
3336@info " VecJac AutoZygote"
3437
35- L = VecJac (f, copy (x), 1.0f0 , 1.0f0 ; autodiff = AutoZygote ())
38+ p, t = rand (Float32, 2 )
39+ L = VecJac (f, copy (x1), p, t; autodiff = AutoZygote ())
40+ update_coefficients! (L, v, p, t)
3641
37- Jx = Zygote. jacobian (f, x)[1 ]
38- Jv = Zygote. jacobian (f, v)[1 ]
42+ update_coefficients! (f, v, p, t)
43+ J1 = Zygote. jacobian (f, x1)[1 ]
44+ J2 = Zygote. jacobian (f, x2)[1 ]
3945
40- @test L * x ≈ Jx' * x
41- @test L * v ≈ Jx' * v
42- y= zero (x); @test mul! (y, L, v) ≈ Jx' * v
43- y= zero (x); @test mul! (y, L, v) ≈ Jx' * v
46+ # test operator application
47+ @test L * v ≈ J1' * v
48+ @test L (v, p, t) ≈ J1' * v
49+ y= zeros (N); @test mul! (y, L, v) ≈ J1' * v
50+ y= zeros (N); @test L (y, v, p, t) ≈ J1' * v
4451
45- @test L (x, 1.0f0 , 1.0f0 ) ≈ Jx' * x
46- y= zero (x); @test L (y, x, 1.0f0 , 1.0f0 ) ≈ Jx' * x
47- @test L (v, 1.0f0 , 1.0f0 ) ≈ Jv' * v
48- y= zero (v); @test L (y, v, 1.0f0 , 1.0f0 ) ≈ Jv' * v
52+ # use kwarg VJP_input = x2
53+ @test L (v, p, t; VJP_input = x2) ≈ J2' * v
54+ y= zeros (N); @test L (y, v, p, t; VJP_input = x2) ≈ J2' * v
4955
50- update_coefficients! (L, v, 3.0 , 4.0 )
56+ # update_coefficients
57+ p, t = rand (Float32, 2 )
58+ L = update_coefficients (L, v, p, t; JVP_input = x2)
5159
52- Jx = Zygote. jacobian (f, x)[1 ]
53- Jv = Zygote. jacobian (f, v)[1 ]
60+ update_coefficients! (f, v, p, t)
61+ J1 = Zygote. jacobian (f, x1)[1 ]
62+ J2 = Zygote. jacobian (f, x2)[1 ]
5463
55- @test L * x ≈ Jv' * x
56- @test L * v ≈ Jv' * v
57- y= zero (x); @test mul! (y, L, v) ≈ Jv' * v
58- y= zero (x); @test mul! (y, L, v) ≈ Jv' * v
64+ # @show p, t
65+ # @show f.p, f.t
66+ # @show L.op.f.p, L.op.f.t
5967
60- @test L (x, 3.0f0 , 4.0f0 ) ≈ Jx' * x
61- y= zero (x); @test L (y, x, 3.0f0 , 4.0f0 ) ≈ Jx' * x
62- @test L (v, 3.0f0 , 4.0f0 ) ≈ Jv' * v
63- y= zero (v); @test L (y, v, 3.0f0 , 4.0f0 ) ≈ Jv' * v
68+ @test L * v ≈ J2' * v
69+ @test L (v, p, t) ≈ J2' * v
70+ y= zeros (N); @test mul! (y, L, v) ≈ J2' * v
71+ y= zeros (N); @test L (y, v, p, t) ≈ J2' * v
72+
73+ # use kwarg VJP_input = x1
74+ @test L (v, p, t; VJP_input = x1) ≈ J1' * v
75+ y= zeros (N); @test L (y, v, p, t; VJP_input = x1) ≈ J1' * v
6476
6577@info " VecJac AutoFiniteDiff"
6678
67- L = VecJac (f, copy (x), 1.0f0 , 1.0f0 ; autodiff = AutoFiniteDiff ())
79+ p, t = rand (Float32, 2 )
80+ L = VecJac (f, copy (x1), 1.0f0 , 1.0f0 ; autodiff = AutoFiniteDiff ())
81+ update_coefficients! (L, v, p, t)
82+ update_coefficients! (f, v, p, t)
83+
84+ @test L * v ≈ num_vecjac (f, copy (x1), v)
85+ @test L (v, p, t) ≈ num_vecjac (f, copy (x1), v)
86+ y= zeros (N); @test mul! (y, L, v) ≈ num_vecjac (f, copy (x1), v)
87+ y= zeros (N); @test L (y, v, p, t) ≈ num_vecjac (f, copy (x1), v)
6888
69- @test L * x ≈ num_vecjac (f, copy (x), x)
70- @test L * v ≈ num_vecjac (f, copy (x ), v)
71- y= zero (x ); @test mul! (y, L, v ) ≈ num_vecjac (f, copy (x ), v)
89+ # use kwarg VJP_input = x2
90+ @test L (v, p, t; VJP_input = x2) ≈ num_vecjac (f, copy (x2 ), v)
91+ y= zeros (N ); @test L (y, v, p, t; VJP_input = x2 ) ≈ num_vecjac (f, copy (x2 ), v)
7292
73- update_coefficients! (L, v, 3.0 , 4.0 )
74- @test mul! (y, L, x) ≈ num_vecjac (f, copy (v), x)
75- _y = copy (y); @test mul! (y, L, x, a, b) ≈ a * num_vecjac (f,copy (v),x) + b * _y
93+ # update_coefficients
94+ p, t = rand (Float32, 2 )
95+ L = update_coefficients (L, v, p, t; JVP_input = x2)
96+ update_coefficients! (f, v, p, t)
7697
77- update_coefficients! (f, v, 5.0 , 6.0 )
78- @test L (y, v, 5.0 , 6.0 ) ≈ num_vecjac (f, copy (v), v)
98+ @test L * v ≈ num_vecjac (f, copy (x2), v)
99+ @test L (v, p, t) ≈ num_vecjac (f, copy (x2), v)
100+ y= zeros (N); @test mul! (y, L, v) ≈ num_vecjac (f, copy (x2), v)
101+ y= zeros (N); @test L (y, v, p, t) ≈ num_vecjac (f, copy (x2), v)
102+
103+ # use kwarg VJP_input = x2
104+ @test L (v, p, t; VJP_input = x1) ≈ num_vecjac (f, copy (x1), v)
105+ y= zeros (N); @test L (y, v, p, t; VJP_input = x1) ≈ num_vecjac (f, copy (x1), v)
79106
80107# Test that x and v were not mutated
81- @test x ≈ x0
82- @test v ≈ v0
108+ @test x1 ≈ _x1
109+ @test x2 ≈ _x2
110+ @test v ≈ v
83111
84112@info " Base.resize!"
85113
86114# Resize test
87115f2 (x) = 2 x
88116f2 (y, x) = (copy! (y, x); lmul! (2 , y); y)
89117
118+ x = rand (Float32, N)
90119for M in (100 , 400 )
91120 local L = VecJac (f2, copy (x), 1.0f0 , 1.0f0 ; autodiff = AutoZygote ())
92121 resize! (L, M)
93122
94123 _x = resize! (copy (x), M)
95124 _u = rand (M)
96- J2 = Zygote. jacobian (f2, _x)[1 ]
125+ local J2 = Zygote. jacobian (f2, _x)[1 ]
97126
98- update_coefficients! (L, _x , 1.0f0 , 1.0f0 )
127+ update_coefficients! (L, _u , 1.0f0 , 1.0f0 ; VJP_input = _x )
99128 @test L * _u ≈ J2' * _u rtol= 1e-6
100- _v = zeros (M); @test mul! (_v, L, _u) ≈ J2' * _u rtol= 1e-6
129+ local _v = zeros (M); @test mul! (_v, L, _u) ≈ J2' * _u rtol= 1e-6
101130end
102-
103131#
0 commit comments