Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 9375983

Browse files
vpuri3ChrisRackauckas
authored andcommitted
tests
1 parent 84542bb commit 9375983

4 files changed

Lines changed: 77 additions & 122 deletions

File tree

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ if GROUP == "All"
2222
@time @safetestset "Integration test" begin include("test_integration.jl") end
2323
@time @safetestset "Special matrices" begin include("test_specialmatrices.jl") end
2424
@time @safetestset "Jac Vecs and Hes Vecs" begin include("test_jaches_products.jl") end
25-
@time @safetestset "Operator tests" begin include("test_ops.jl") end
26-
25+
@time @safetestset "Vec Jac Products" begin include("test_vecjac_products.jl") end
2726
end
2827

2928
if GROUP == "GPU"

test/test_jaches_products.jl

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ using LinearAlgebra, Test
33

44
using Random
55
Random.seed!(123)
6-
7-
const A = rand(300, 300)
6+
N = 300
7+
const A = rand(N, N)
88
f(y, x) = mul!(y, A, x)
99
f(x) = A * x
10-
x = rand(300)
11-
v = rand(300)
10+
x = rand(N)
11+
v = rand(N)
12+
a, b = rand(2)
1213
dy = similar(x)
1314
g(x) = sum(abs2, x)
1415
function h(x)
@@ -20,8 +21,7 @@ end
2021

2122
cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))),
2223
eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v)))
23-
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))),
24-
eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v)))
24+
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(Tuple.(v)))
2525
@test num_jacvec!(dy, f, x, v)ForwardDiff.jacobian(f, similar(x), x) * v rtol=1e-6
2626
@test num_jacvec!(dy, f, x, v, similar(v),
2727
similar(v))ForwardDiff.jacobian(f, similar(x), x) * v rtol=1e-6
@@ -65,61 +65,90 @@ cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x)
6565
@test auto_hesvecgrad!(dy, h, x, v, cache1, cache2)ForwardDiff.hessian(g, x) * v rtol=1e-2
6666
@test auto_hesvecgrad(h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
6767

68+
# JacVec
69+
6870
L = JacVec(f, x)
6971
@test L * x auto_jacvec(f, x, x)
7072
@test L * v auto_jacvec(f, x, v)
7173
@test mul!(dy, L, v) auto_jacvec(f, x, v)
72-
L.x .= v
74+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
75+
update_coefficients!(L, v, nothing, 0.0)
7376
@test mul!(dy, L, v) auto_jacvec(f, v, v)
77+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
7478

7579
L = JacVec(f, x, autodiff = false)
7680
@test L * x num_jacvec(f, x, x)
7781
@test L * v num_jacvec(f, x, v)
78-
L.x == x
7982
@test mul!(dy, L, v)num_jacvec(f, x, v) rtol=1e-6
80-
L.x .= v
83+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
84+
update_coefficients!(L, v, nothing, 0.0)
8185
@test mul!(dy, L, v)num_jacvec(f, v, v) rtol=1e-6
86+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
8287

83-
### Integration test with IterativeSolvers
8488
out = similar(v)
8589
gmres!(out, L, v)
8690

87-
x = rand(300)
88-
v = rand(300)
91+
#=
92+
ff1 = ODEFunction(lorenz, jac_prototype = JacVec{Float64}(lorenz, u0))
93+
ff2 = ODEFunction(lorenz, jac_prototype = JacVec{Float64}(lorenz, u0, autodiff=false))
94+
95+
for ff in [ff1, ff2]
96+
prob = ODEProblem(ff, u0, tspan)
97+
@test solve(prob, TRBDF2()).retcode == :Success
98+
@test solve(prob, TRBDF2(linsolve = KrylovJL_GMRES())).retcode == :Success
99+
@test solve(prob, Exprb32()).retcode == :Success
100+
@test solve(prob, Rosenbrock23()).retcode == :Success
101+
@test solve(prob, Rosenbrock23(linsolve = KrylovJL_GMRES())).retcode == :Success
102+
end
103+
=#
104+
105+
# HesVec
106+
107+
x = rand(N)
108+
v = rand(N)
89109
L = HesVec(g, x, autodiff = false)
90110
@test L * x num_hesvec(g, x, x)
91111
@test L * v num_hesvec(g, x, v)
92112
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
93-
L.x .= v
113+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
114+
update_coefficients!(L, v, nothing, 0.0)
94115
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
116+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
95117

96118
L = HesVec(g, x)
97119
@test L * x numauto_hesvec(g, x, x)
98120
@test L * v numauto_hesvec(g, x, v)
99121
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
100-
L.x .= v
122+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
123+
update_coefficients!(L, v, nothing, 0.0)
101124
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
125+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
102126

103-
### Integration test with IterativeSolvers
104127
out = similar(v)
105128
gmres!(out, L, v)
106129

107-
x = rand(300)
108-
v = rand(300)
130+
# HesVecGrad
131+
132+
x = rand(N)
133+
v = rand(N)
109134
L = HesVecGrad(h, x, autodiff = false)
110135
@test L * x num_hesvec(g, x, x)
111136
@test L * v num_hesvec(g, x, v)
112137
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
113-
L.x .= v
138+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
139+
update_coefficients!(L, v, nothing, 0.0)
114140
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
141+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
115142

116143
L = HesVecGrad(h, x, autodiff = true)
117144
@test L * x autonum_hesvec(g, x, x)
118145
@test L * v numauto_hesvec(g, x, v)
119146
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
120-
L.x .= v
147+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
148+
update_coefficients!(L, v, nothing, 0.0)
121149
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
150+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
122151

123-
### Integration test with IterativeSolvers
124152
out = similar(v)
125153
gmres!(out, L, v)
154+
#

test/test_ops.jl

Lines changed: 0 additions & 100 deletions
This file was deleted.

test/test_vecjac_products.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using SparseDiffTools, ForwardDiff, FiniteDiff, Zygote, IterativeSolvers
2+
using LinearAlgebra, Test
3+
4+
using Random
5+
Random.seed!(123)
6+
N = 300
7+
const A = rand(N, N)
8+
a, b = rand(2)
9+
10+
x = rand(Float32, N)
11+
v = rand(Float32, N)
12+
13+
f(du,u,p,t) = mul!(du, A, u)
14+
f(u,p,t) = A * u
15+
16+
# VecJac
17+
18+
L = VecJac(f, x)
19+
actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v
20+
update_coefficients!(L, v, nothing, 0.0)
21+
@test L * v actual_vjp
22+
L = VecJac(f, x; autodiff = false)
23+
update_coefficients!(L, v, nothing, 0.0)
24+
@test L * v actual_vjp
25+
#dy=rand(N); @test mul!(dy, L, v) ≈ actual_vjp
26+
#dy=rand(N); _dy=copy(dy); @test mul!(dy,L,v,a,b) ≈ a * actual_vjp + b * _dy
27+
#

0 commit comments

Comments
 (0)