Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit ab697a0

Browse files
vpuri3ChrisRackauckas
authored andcommitted
multivalue autodiff
1 parent e7150bf commit ab697a0

7 files changed

Lines changed: 65 additions & 58 deletions

File tree

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <cont
44
version = "2.0.0"
55

66
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
910
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

ext/SparseDiffToolsZygote.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ if isdefined(Base, :get_extension)
77
using ForwardDiff: ForwardDiff, Dual, partials
88
using SciMLOperators: FunctionOperator
99
using Tricks: static_hasmethod
10+
using ADTypes
1011
else
1112
import ..Zygote
1213
using ..LinearAlgebra
1314
using ..SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
1415
using ..ForwardDiff: ForwardDiff, Dual, partials
1516
using ..SciMLOperators: FunctionOperator
1617
using ..Tricks: static_hasmethod
18+
using ..ADTypes
1719
end
1820

1921
### Jac, Hes products
@@ -71,22 +73,21 @@ end
7173

7274
# Operator Forms
7375

74-
function SparseDiffTools.ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
76+
function SparseDiffTools.ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoZygote())
7577

76-
if autodiff
78+
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
79+
cache1 = similar(u)
80+
cache2 = similar(u)
81+
82+
(cache1, cache2), SparseDiffTools.numback_hesvec, SparseDiffTools.numback_hesvec!
83+
elseif autodiff isa AutoZygote
7784
cache1 = Dual{
7885
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
7986
}.(u, ForwardDiff.Partials.(tuple.(u)))
8087
cache2 = copy(u)
81-
else
82-
cache1 = similar(u)
83-
cache2 = similar(u)
84-
end
8588

86-
cache = (cache1, cache2,)
87-
88-
vecprod = autodiff ? SparseDiffTools.autoback_hesvec : SparseDiffTools.numback_hesvec
89-
vecprod! = autodiff ? SparseDiffTools.autoback_hesvec! : SparseDiffTools.numback_hesvec!
89+
(cache1, cache2), SparseDiffTools.autoback_hesvec, SparseDiffTools.autoback_hesvec!
90+
end
9091

9192
outofplace = static_hasmethod(f, typeof((u,)))
9293
isinplace = static_hasmethod(f, typeof((u,)))
@@ -115,8 +116,8 @@ function SparseDiffTools.auto_vecjac(f, x, v)
115116
return vec(back(reshape(v, size(vv)))[1])
116117
end
117118

118-
function SparseDiffTools.ZygoteVecJac(args...; autodiff = true, kwargs...)
119-
VecJac(args...; autodiff = autodiff, kwargs...)
119+
function SparseDiffTools.ZygoteVecJac(args...; kwargs...)
120+
VecJac(args...; autodiff = AutoZygote(), kwargs...)
120121
end
121122

122123
end # module

src/SparseDiffTools.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ using Graphs
77
using Graphs: SimpleGraph
88
using VertexSafeGraphs
99
using Adapt
10+
using Reexport
11+
@reexport using ADTypes
1012

1113
using LinearAlgebra
1214
using SparseArrays, ArrayInterface
@@ -69,7 +71,6 @@ parameterless_type(x) = parameterless_type(typeof(x))
6971
parameterless_type(x::Type) = __parameterless_type(x)
7072

7173
import Requires
72-
import Reexport
7374

7475
function numback_hesvec end
7576
function numback_hesvec! end
@@ -84,7 +85,7 @@ function ZygoteHesVec end
8485
function __init__()
8586
Requires.@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
8687
include("../ext/SparseDiffToolsZygote.jl")
87-
Reexport.@reexport using .SparseDiffToolsZygote
88+
@reexport using .SparseDiffToolsZygote
8889
end
8990
end
9091
end

src/differentiation/jaches_products.jl

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -223,24 +223,25 @@ function (L::FwdModeAutoDiffVecProd)(dv, v, p, t)
223223
L.vecprod!(dv, L.f, L.u, v, L.cache...)
224224
end
225225

226-
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
226+
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff())
227227

228-
if autodiff
228+
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
229+
cache1 = similar(u)
230+
cache2 = similar(u)
231+
232+
(cache1, cache2), num_jacvec, num_jacvec!
233+
elseif autodiff isa AutoForwardDiff
229234
cache1 = Dual{
230235
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
231236
}.(u, ForwardDiff.Partials.(tuple.(u)))
232237

233238
cache2 = copy(cache1)
239+
240+
(cache1, cache2), auto_jacvec, auto_jacvec!
234241
else
235-
cache1 = similar(u)
236-
cache2 = similar(u)
242+
@error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()")
237243
end
238244

239-
cache = (cache1, cache2,)
240-
241-
vecprod = autodiff ? auto_jacvec : num_jacvec
242-
vecprod! = autodiff ? auto_jacvec! : num_jacvec!
243-
244245
outofplace = static_hasmethod(f, typeof((u,)))
245246
isinplace = static_hasmethod(f, typeof((u, u,)))
246247

@@ -256,22 +257,23 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
256257
)
257258
end
258259

259-
function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
260+
function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff())
260261

261-
if autodiff
262-
cache1 = ForwardDiff.GradientConfig(f, u)
263-
cache2 = similar(u)
264-
cache3 = similar(u)
265-
else
262+
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
266263
cache1 = similar(u)
267264
cache2 = similar(u)
268265
cache3 = similar(u)
269-
end
270266

271-
cache = (cache1, cache2, cache3,)
267+
(cache1, cache2, cache3), num_hesvec, num_hesvec!
268+
elseif autodiff isa AutoForwardDiff
269+
cache1 = ForwardDiff.GradientConfig(f, u)
270+
cache2 = similar(u)
271+
cache3 = similar(u)
272272

273-
vecprod = autodiff ? numauto_hesvec : num_hesvec
274-
vecprod! = autodiff ? numauto_hesvec! : num_hesvec!
273+
(cache1, cache2, cache3), numauto_hesvec, numauto_hesvec!
274+
else
275+
@error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()")
276+
end
275277

276278
outofplace = static_hasmethod(f, typeof((u,)))
277279
isinplace = static_hasmethod(f, typeof((u,)))
@@ -288,24 +290,24 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
288290
)
289291
end
290292

291-
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
293+
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff())
294+
295+
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
296+
cache1 = similar(u)
297+
cache2 = similar(u)
292298

293-
if autodiff
299+
(cache1, cache2), num_hesvecgrad, num_hesvecgrad!
300+
elseif autodiff isa AutoForwardDiff
294301
cache1 = Dual{
295302
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
296303
}.(u, ForwardDiff.Partials.(tuple.(u)))
297-
298304
cache2 = copy(cache1)
305+
306+
(cache1, cache2), auto_hesvecgrad, auto_hesvecgrad!
299307
else
300-
cache1 = similar(u)
301-
cache2 = similar(u)
308+
@error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()")
302309
end
303310

304-
cache = (cache1, cache2,)
305-
306-
vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad
307-
vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad!
308-
309311
outofplace = static_hasmethod(f, typeof((u,)))
310312
isinplace = static_hasmethod(f, typeof((u, u,)))
311313

src/differentiation/vecjac_products.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,13 @@ struct RevModeAutoDiffVecProd{ad,iip,oop,F,U,C,V,V!} <: AbstractAutoDiffVecProd
4444
vecprod::V
4545
vecprod!::V!
4646

47-
function RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = false,
47+
function RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!;
48+
autodiff = AutoFiniteDiff(),
4849
isinplace = false, outofplace = true)
4950
@assert isinplace || outofplace
5051

5152
new{
52-
autodiff,
53+
typeof(autodiff),
5354
isinplace,
5455
outofplace,
5556
typeof(f),
@@ -86,18 +87,19 @@ function (L::RevModeAutoDiffVecProd{ad,true,false})(dv, v, p, t) where{ad}
8687
L.vecprod!(dv, (_du, _u) -> L.f(_du, _u, p, t), L.u, v, L.cache...)
8788
end
8889

89-
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = false,
90+
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff(),
9091
ishermitian = false, opnrom = true)
9192

92-
if autodiff
93-
@assert isdefined(SparseDiffTools, :auto_vecjac) "Please load Zygote with `using Zygote`, or `import Zygote` to use VecJac with `autodiff = true`."
93+
vecprod, vecprod! = if autodiff isa AutoFiniteDiff
94+
num_vecjac, num_vecjac!
95+
elseif autodiff isa AutoZygote
96+
@assert static_hasmethod(auto_vecjac, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
97+
98+
auto_vecjac, auto_vecjac!
9499
end
95100

96101
cache = (similar(u), similar(u),)
97102

98-
vecprod = autodiff ? auto_vecjac : num_vecjac
99-
vecprod! = autodiff ? auto_vecjac! : num_vecjac!
100-
101103
outofplace = static_hasmethod(f, typeof((u, p, t)))
102104
isinplace = static_hasmethod(f, typeof((u, u, p, t)))
103105

test/test_jaches_products.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ update_coefficients!(L, v, nothing, 0.0)
7676
@test mul!(dy, L, v) auto_jacvec(f, v, v)
7777
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
7878

79-
L = JacVec(f, x, autodiff = false)
79+
L = JacVec(f, x, autodiff = AutoFiniteDiff())
8080
@test L * x num_jacvec(f, x, x)
8181
@test L * v num_jacvec(f, x, v)
8282
@test mul!(dy, L, v)num_jacvec(f, x, v) rtol=1e-6
@@ -92,7 +92,7 @@ gmres!(out, L, v)
9292

9393
x = rand(N)
9494
v = rand(N)
95-
L = HesVec(g, x, autodiff = false)
95+
L = HesVec(g, x, autodiff = AutoFiniteDiff())
9696
@test L * x num_hesvec(g, x, x)
9797
@test L * v num_hesvec(g, x, v)
9898
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
@@ -118,7 +118,7 @@ using Zygote
118118
x = rand(N)
119119
v = rand(N)
120120

121-
L = ZygoteHesVec(g, x, autodiff = false)
121+
L = ZygoteHesVec(g, x, autodiff = AutoFiniteDiff())
122122
@test L * x numback_hesvec(g, x, x) rtol = 1e-2
123123
@test L * v numback_hesvec(g, x, v) rtol = 1e-2
124124
@test mul!(dy, L, v)numback_hesvec(g, x, v) rtol=1e-2
@@ -144,7 +144,7 @@ gmres!(out, L, v)
144144

145145
x = rand(N)
146146
v = rand(N)
147-
L = HesVecGrad(h, x, autodiff = false)
147+
L = HesVecGrad(h, x, autodiff = AutoFiniteDiff())
148148
@test L * x num_hesvec(g, x, x)
149149
@test L * v num_hesvec(g, x, v)
150150
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
@@ -153,7 +153,7 @@ update_coefficients!(L, v, nothing, 0.0)
153153
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
154154
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
155155

156-
L = HesVecGrad(h, x, autodiff = true)
156+
L = HesVecGrad(h, x)
157157
@test L * x autonum_hesvec(g, x, x)
158158
@test L * v numauto_hesvec(g, x, v)
159159
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8

test/test_vecjac_products.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ L = VecJac(f, x)
1818
actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v
1919
update_coefficients!(L, v, nothing, 0.0)
2020
@test L * v actual_vjp
21-
L = VecJac(f, x; autodiff = false)
21+
L = VecJac(f, x; autodiff = AutoFiniteDiff())
2222
update_coefficients!(L, v, nothing, 0.0)
2323
@test L * v actual_vjp
2424

@@ -28,7 +28,7 @@ L = ZygoteVecJac(f, x)
2828
actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v
2929
update_coefficients!(L, v, nothing, 0.0)
3030
@test L * v actual_vjp
31-
L = ZygoteVecJac(f, x; autodiff = false)
31+
L = ZygoteVecJac(f, x; autodiff = AutoFiniteDiff())
3232
update_coefficients!(L, v, nothing, 0.0)
3333
@test L * v actual_vjp
3434
#

0 commit comments

Comments
 (0)