Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit bc46b6b

Browse files
committed
rm ZygoteVecJac. rename ZygoteHesVec to BackHesVec, and move it to main
1 parent ab697a0 commit bc46b6b

4 files changed

Lines changed: 40 additions & 53 deletions

File tree

ext/SparseDiffToolsZygote.jl

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,13 @@ module SparseDiffToolsZygote
33
if isdefined(Base, :get_extension)
44
import Zygote
55
using LinearAlgebra
6-
using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
6+
using SparseDiffTools: SparseDiffTools, DeivVecTag
77
using ForwardDiff: ForwardDiff, Dual, partials
8-
using SciMLOperators: FunctionOperator
9-
using Tricks: static_hasmethod
10-
using ADTypes
118
else
129
import ..Zygote
1310
using ..LinearAlgebra
14-
using ..SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
11+
using ..SparseDiffTools: SparseDiffTools, DeivVecTag
1512
using ..ForwardDiff: ForwardDiff, Dual, partials
16-
using ..SciMLOperators: FunctionOperator
17-
using ..Tricks: static_hasmethod
18-
using ..ADTypes
1913
end
2014

2115
### Jac, Hes products
@@ -71,39 +65,6 @@ function SparseDiffTools.autoback_hesvec(f, x, v)
7165
ForwardDiff.partials.(g(y), 1)
7266
end
7367

74-
# Operator Forms
75-
76-
function SparseDiffTools.ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoZygote())
77-
78-
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
79-
cache1 = similar(u)
80-
cache2 = similar(u)
81-
82-
(cache1, cache2), SparseDiffTools.numback_hesvec, SparseDiffTools.numback_hesvec!
83-
elseif autodiff isa AutoZygote
84-
cache1 = Dual{
85-
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
86-
}.(u, ForwardDiff.Partials.(tuple.(u)))
87-
cache2 = copy(u)
88-
89-
(cache1, cache2), SparseDiffTools.autoback_hesvec, SparseDiffTools.autoback_hesvec!
90-
end
91-
92-
outofplace = static_hasmethod(f, typeof((u,)))
93-
isinplace = static_hasmethod(f, typeof((u,)))
94-
95-
if !(isinplace) & !(outofplace)
96-
error("$f must have signature f(u).")
97-
end
98-
99-
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
100-
101-
FunctionOperator(L, u, u;
102-
isinplace = isinplace, outofplace = outofplace,
103-
p = p, t = t, islinear = true,
104-
)
105-
end
106-
10768
## VecJac products
10869

10970
function SparseDiffTools.auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing)
@@ -116,8 +77,4 @@ function SparseDiffTools.auto_vecjac(f, x, v)
11677
return vec(back(reshape(v, size(vv)))[1])
11778
end
11879

119-
function SparseDiffTools.ZygoteVecJac(args...; kwargs...)
120-
VecJac(args...; autodiff = AutoZygote(), kwargs...)
121-
end
122-
12380
end # module

src/SparseDiffTools.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ export contract_color,
4949
autonum_hesvec, autonum_hesvec!,
5050
num_hesvecgrad, num_hesvecgrad!,
5151
auto_hesvecgrad, auto_hesvecgrad!,
52-
JacVec, HesVec, HesVecGrad, VecJac,
52+
JacVec, HesVec, BackHesVec, HesVecGrad, VecJac,
5353
update_coefficients, update_coefficients!,
5454
value!
5555

@@ -78,8 +78,6 @@ function autoback_hesvec end
7878
function autoback_hesvec! end
7979
function auto_vecjac end
8080
function auto_vecjac! end
81-
function ZygoteVecJac end
82-
function ZygoteHesVec end
8381

8482
@static if !isdefined(Base, :get_extension)
8583
function __init__()
@@ -93,7 +91,6 @@ end
9391
export
9492
numback_hesvec, numback_hesvec!,
9593
autoback_hesvec, autoback_hesvec!,
96-
auto_vecjac, auto_vecjac!,
97-
ZygoteVecJac, ZygoteHesVec
94+
auto_vecjac, auto_vecjac!
9895

9996
end # module

src/differentiation/jaches_products.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,39 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
290290
)
291291
end
292292

293+
function BackHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff())
294+
295+
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
296+
cache1 = similar(u)
297+
cache2 = similar(u)
298+
299+
(cache1, cache2), numback_hesvec, numback_hesvec!
300+
elseif autodiff isa AutoZygote
301+
@assert static_hasmethod(autoback_hesvec, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
302+
303+
cache1 = Dual{
304+
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
305+
}.(u, ForwardDiff.Partials.(tuple.(u)))
306+
cache2 = copy(u)
307+
308+
(cache1, cache2), autoback_hesvec, autoback_hesvec!
309+
end
310+
311+
outofplace = static_hasmethod(f, typeof((u,)))
312+
isinplace = static_hasmethod(f, typeof((u,)))
313+
314+
if !(isinplace) & !(outofplace)
315+
error("$f must have signature f(u).")
316+
end
317+
318+
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
319+
320+
FunctionOperator(L, u, u;
321+
isinplace = isinplace, outofplace = outofplace,
322+
p = p, t = t, islinear = true,
323+
)
324+
end
325+
293326
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff())
294327

295328
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff

test/test_jaches_products.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,12 @@ dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy r
113113
out = similar(v)
114114
gmres!(out, L, v)
115115

116-
@info "ZygoteHesVec"
116+
@info "BackHesVec"
117117
using Zygote
118118
x = rand(N)
119119
v = rand(N)
120120

121-
L = ZygoteHesVec(g, x, autodiff = AutoFiniteDiff())
121+
L = BackHesVec(g, x)
122122
@test L * x numback_hesvec(g, x, x) rtol = 1e-2
123123
@test L * v numback_hesvec(g, x, v) rtol = 1e-2
124124
@test mul!(dy, L, v)numback_hesvec(g, x, v) rtol=1e-2
@@ -127,7 +127,7 @@ update_coefficients!(L, v, nothing, 0.0)
127127
@test mul!(dy, L, v)numback_hesvec(g, v, v) rtol=1e-2
128128
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*numback_hesvec(g,x,v) + b*_dy rtol=1e-2
129129

130-
L = ZygoteHesVec(g, x)
130+
L = BackHesVec(g, x, autodiff = AutoZygote())
131131
@test L * x autoback_hesvec(g, x, x)
132132
@test L * v autoback_hesvec(g, x, v)
133133
@test mul!(dy, L, v)autoback_hesvec(g, x, v) rtol=1e-8

0 commit comments

Comments
 (0)