Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 3c743ff

Browse files
committed
add tests
1 parent 5a02552 commit 3c743ff

3 files changed

Lines changed: 44 additions & 34 deletions

File tree

ext/SparseDiffToolsZygote.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ end
7575

7676
## VecJac products
7777

78-
function SparseDiffTools.auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing)
79-
!hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = false")
78+
function SparseDiffTools.auto_vecjac!(du, f, x, v)
79+
!hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = AutoFiniteDiff()")
8080
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
8181
end
8282

test/test_jaches_products.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ L = JacVec(f, copy(x), 1.0, 1.0; tag = MyTag())
143143

144144
# Resize test
145145
for M in (100, 400)
146-
L = JacVec(f2, copy(x), 1.0, 1.0)
146+
local L = JacVec(f2, copy(x), 1.0, 1.0)
147147
resize!(L, M)
148148
_x = resize!(copy(x), M)
149149
_u = rand(M)

test/test_vecjac_products.jl

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using SparseDiffTools, ForwardDiff, FiniteDiff, Zygote, IterativeSolvers
1+
using SparseDiffTools, Zygote
22
using LinearAlgebra, Test
33

44
using Random
@@ -13,52 +13,62 @@ v = rand(Float32, N)
1313
x0 = copy(x)
1414
v0 = copy(v)
1515

16-
a, b = rand(2)
17-
dy = similar(x)
16+
a, b = rand(Float32, 2)
1817

1918
A = rand(Float32, N, N)
20-
_f(du, u) = mul!(du, A, u)
21-
_f(u) = A * u
19+
_f(y, x) = mul!(y, A, x .^ 2)
20+
_f(x) = A * (x .^ 2)
2221

2322
# Define state-dependent functions for operator tests
2423
include("update_coeffs_testutils.jl")
2524
f = WrapFunc(_f, 1.0f0, 1.0f0)
2625

26+
@test auto_vecjac(f, x, v) Zygote.jacobian(f, x)[1]' * v
27+
@test auto_vecjac!(zero(x), f, x, v) auto_vecjac(f, x, v)
28+
@test num_vecjac!(zero(x), f, copy(x), v) num_vecjac(f, copy(x), v)
29+
@test auto_vecjac(f, x, v) num_vecjac(f, copy(x), copy(v)) rtol = 1e-2
30+
2731
# Compute Jacobian via Zygote
2832

29-
@info "VecJac"
33+
@info "VecJac AutoZygote"
3034

3135
L = VecJac(f, copy(x), 1.0f0, 1.0f0; autodiff = AutoZygote())
32-
update_coefficients!(f, x, 1.0, 1.0)
33-
actual_jac = Zygote.jacobian(f, x)[1]
34-
@test L * x actual_jac' * x
35-
@test L * v actual_jac' * v
36-
@test mul!(dy, L, v) actual_jac' * v
36+
37+
Jtrue = Zygote.jacobian(f, x)[1]
38+
39+
@test L * x Jtrue' * x
40+
y=zero(x); @test mul!(y, L, v) Jtrue' * v
41+
@test L(x, 1.0f0, 1.0f0) Jtrue' * x
42+
y=zero(x); @test L(y, x, 1.0f0, 1.0f0) Jtrue' * x
43+
44+
@test L * v Jtrue' * v
45+
y=zero(x); @test mul!(y, L, v) Jtrue' * v
46+
# @test L(v, 1.0f0, 1.0f0) ≈ Jtrue' * v
47+
# y=zero(v); @test L(y, v, 1.0f0, 1.0f0) ≈ Jtrue' * v
48+
3749
update_coefficients!(L, v, 3.0, 4.0)
38-
update_coefficients!(f, v, 3.0, 4.0)
39-
actual_jac = Zygote.jacobian(f, v)[1]
40-
@test mul!(dy, L, x) actual_jac' * x
41-
_dy = copy(dy);
42-
@test mul!(dy, L, x, a, b) a * actual_jac' * x + b * _dy;
50+
Jtrue = Zygote.jacobian(f, v)[1]
51+
@test mul!(y, L, x) Jtrue' * x
52+
_y=copy(y); @test mul!(y, L, x, a, b) a * Jtrue' * x + b * _y;
53+
4354
update_coefficients!(f, v, 5.0, 6.0)
44-
actual_jac = Zygote.jacobian(f, v)[1]
45-
@test L(dy, v, 5.0, 6.0) actual_jac' * v
55+
Jtrue = Zygote.jacobian(f, v)[1]
56+
y=zero(x); @test L(y, v, 5.0, 6.0) Jtrue' * v
57+
58+
@info "VecJac AutoFiniteDiff"
4659

4760
L = VecJac(f, copy(x), 1.0f0, 1.0f0; autodiff = AutoFiniteDiff())
48-
update_coefficients!(f, x, 1.0, 1.0)
49-
actual_jac = Zygote.jacobian(f, x)[1]
50-
@test L * x actual_jac' * x
51-
@test L * v actual_jac' * v
52-
@test mul!(dy, L, v) actual_jac' * v
61+
62+
@test L * x num_vecjac(f, copy(x), x)
63+
@test L * v num_vecjac(f, copy(x), v)
64+
y=zero(x); @test mul!(y, L, v) num_vecjac(f, copy(x), v)
65+
5366
update_coefficients!(L, v, 3.0, 4.0)
54-
update_coefficients!(f, v, 3.0, 4.0)
55-
actual_jac = Zygote.jacobian(f, v)[1]
56-
@test mul!(dy, L, x) actual_jac' * x
57-
_dy = copy(dy);
58-
@test mul!(dy, L, x, a, b) a * actual_jac' * x + b * _dy;
67+
@test mul!(y, L, x) num_vecjac(f, copy(v), x)
68+
_y = copy(y); @test mul!(y, L, x, a, b) a * num_vecjac(f,copy(v),x) + b * _y
69+
5970
update_coefficients!(f, v, 5.0, 6.0)
60-
actual_jac = Zygote.jacobian(f, v)[1]
61-
@test L(dy, v, 5.0, 6.0) actual_jac' * v
71+
@test L(y, v, 5.0, 6.0) num_vecjac(f, copy(v), v)
6272

6373
# Test that x and v were not mutated
6474
@test x x0
@@ -69,7 +79,7 @@ f2(x) = 2x
6979
f2(y, x) = (copy!(y, x); lmul!(2, y); y)
7080

7181
for M in (100, 400)
72-
L = VecJac(f2, copy(x), 1.0f0, 1.0f0; autodiff = AutoZygote())
82+
local L = VecJac(f2, copy(x), 1.0f0, 1.0f0; autodiff = AutoZygote())
7383
resize!(L, M)
7484
_x = resize!(copy(x), M)
7585
_u = rand(M)

0 commit comments

Comments
 (0)