Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 3f73643

Browse files
committed
Test VecJac too
1 parent 82eee1a commit 3f73643

3 files changed

Lines changed: 24 additions & 46 deletions

File tree

test/runtests.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@ function activate_gpu_env()
1212
end
1313

1414
if GROUP == "All"
15-
@time @safetestset "Exact coloring via contraction" begin include("test_contraction.jl") end
16-
@time @safetestset "Greedy distance-1 coloring" begin include("test_greedy_d1.jl") end
17-
@time @safetestset "Greedy star coloring" begin include("test_greedy_star.jl") end
18-
@time @safetestset "Acyclic coloring" begin include("test_acyclic.jl") end
19-
@time @safetestset "Matrix to graph conversion" begin include("test_matrix2graph.jl") end
20-
@time @safetestset "Hessian colorvecs" begin include("test_sparse_hessian.jl") end
21-
@time @safetestset "Integration test" begin include("test_integration.jl") end
22-
@time @safetestset "Special matrices" begin include("test_specialmatrices.jl") end
15+
# @time @safetestset "Exact coloring via contraction" begin include("test_contraction.jl") end
16+
# @time @safetestset "Greedy distance-1 coloring" begin include("test_greedy_d1.jl") end
17+
# @time @safetestset "Greedy star coloring" begin include("test_greedy_star.jl") end
18+
# @time @safetestset "Acyclic coloring" begin include("test_acyclic.jl") end
19+
# @time @safetestset "Matrix to graph conversion" begin include("test_matrix2graph.jl") end
20+
# @time @safetestset "Hessian colorvecs" begin include("test_sparse_hessian.jl") end
21+
# @time @safetestset "Integration test" begin include("test_integration.jl") end
22+
# @time @safetestset "Special matrices" begin include("test_specialmatrices.jl") end
2323
@time @safetestset "Jac Vecs and Hes Vecs" begin include("test_jaches_products.jl") end
2424
@time @safetestset "Vec Jac Products" begin include("test_vecjac_products.jl") end
25-
@time @safetestset "AD using colorvec vector" begin include("test_ad.jl") end
25+
# @time @safetestset "AD using colorvec vector" begin include("test_ad.jl") end
2626
end
2727

2828
if GROUP == "GPU"

test/test_jaches_products.jl

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,36 +22,10 @@ function _h(dy, x)
2222
FiniteDiff.finite_difference_gradient!(dy, _g, x)
2323
end
2424

25-
# Define state-dependent (i.e. dependent on u/p/t) functions for tests of operators
25+
# Define state-dependent functions for operator tests
2626

27-
mutable struct WrapFunc{F,U,P,T}
28-
func::F
29-
u::U
30-
p::P
31-
t::T
32-
end
33-
34-
(w::WrapFunc)(u) = sum(w.u) * w.p * w.t * w.func(u)
35-
function (w::WrapFunc)(v, u)
36-
w.func(v, u)
37-
lmul!(sum(w.u) * w.p * w.t, v)
38-
end
39-
40-
update_coefficients(w::WrapFunc, u, p, t) = WrapFunc(w.func, u, p, t)
41-
function update_coefficients!(w::WrapFunc, u, p, t)
42-
w.u = u
43-
w.p = p
44-
w.t = t
45-
end
46-
47-
# Helper function for testing correct update coefficients behaviour of operators
48-
function update_coefficients_for_test!(L, u, p, t)
49-
update_coefficients!(L, u, p, t)
50-
# Force function hiding inside L to update. Should be a null-op if previous line works correctly
51-
update_coefficients!(L.op.f, u, p, t)
52-
end
53-
54-
f = WrapFunc(_f, ones(N) * 2, 1.0, 1.0)
27+
include("update_coeffs_testutils.jl")
28+
f = WrapFunc(_f, ones(N), 1.0, 1.0)
5529
g = WrapFunc(_g, ones(N), 1.0, 1.0)
5630
h = WrapFunc(_h, ones(N), 1.0, 1.0)
5731

test/test_vecjac_products.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,20 @@ const A = rand(N, N)
99
x = rand(Float32, N)
1010
v = rand(Float32, N)
1111

12-
f(du,u,p,t) = mul!(du, A, u)
13-
f(u,p,t) = A * u
12+
_f(du,u) = mul!(du, A, u)
13+
_f(u) = A * u
14+
15+
# Define state-dependent functions for operator tests
16+
include("update_coeffs_testutils.jl")
17+
f = WrapFunc(_f, ones(N), 1.0, 1.0)
1418

1519
@info "VecJac"
1620

17-
L = VecJac(f, x)
18-
actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v
19-
update_coefficients!(L, v, nothing, 0.0)
21+
L = VecJac(f, x, 1.0, 1.0)
22+
update_coefficients!(L, v, 3.0, 4.0)
23+
actual_vjp = Zygote.jacobian(f, x)[1]' * v
2024
@test L * v actual_vjp
21-
L = VecJac(f, x; autodiff = AutoFiniteDiff())
22-
update_coefficients!(L, v, nothing, 0.0)
25+
L = VecJac(f, x, 1.0, 1.0; autodiff = AutoFiniteDiff())
26+
update_coefficients!(L, v, 3.0, 4.0)
2327
@test L * v actual_vjp
24-
#
28+
#

0 commit comments

Comments
 (0)