Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit f05a7a5

Browse files
committed
separated VJP implementation for AutoZygote, FiniteDiff
1 parent 3c743ff commit f05a7a5

2 files changed

Lines changed: 88 additions & 68 deletions

File tree

ext/SparseDiffToolsZygote.jl

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
module SparseDiffToolsZygote
22

3-
if isdefined(Base, :get_extension)
4-
import Zygote
5-
using LinearAlgebra
6-
using SparseDiffTools: SparseDiffTools, DeivVecTag
7-
using ForwardDiff: ForwardDiff, Dual, partials
8-
else
9-
import ..Zygote
10-
using ..LinearAlgebra
11-
using ..SparseDiffTools: SparseDiffTools, DeivVecTag
12-
using ..ForwardDiff: ForwardDiff, Dual, partials
13-
end
3+
import Zygote
4+
using ADTypes
5+
using LinearAlgebra
6+
using SparseDiffTools: SparseDiffTools, DeivVecTag, AutoDiffVJP
7+
using ForwardDiff: ForwardDiff, Dual, partials
8+
import SciMLOperators: update_coefficients, update_coefficients!
9+
import Setfield: @set!
1410

1511
### Jac, Hes products
1612

@@ -75,14 +71,56 @@ end
7571

7672
## VecJac products
7773

74+
# VJP methods
7875
function SparseDiffTools.auto_vecjac!(du, f, x, v)
7976
!hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = AutoFiniteDiff()")
8077
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
8178
end
8279

8380
function SparseDiffTools.auto_vecjac(f, x, v)
84-
vv, back = Zygote.pullback(f, x)
85-
return vec(back(reshape(v, size(vv)))[1])
81+
y, back = Zygote.pullback(f, x)
82+
return vec(back(reshape(v, size(y)))[1])
83+
end
84+
85+
# overload operator interface
86+
function SparseDiffTools._vecjac(f, u, autodiff::AutoZygote)
87+
88+
cache = ()
89+
pullback = Zygote.pullback(f, u)
90+
91+
AutoDiffVJP(f, u, cache, autodiff, pullback)
92+
end
93+
94+
function update_coefficients(L::AutoDiffVJP{AD}, u, p, t) where{AD <: AutoZygote}
95+
@set! L.f = update_coefficients(L.f, u, p, t)
96+
@set! L.u = u
97+
@set! L.pullback = Zygote.pullback(L.f, u)
98+
end
99+
100+
function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t) where{AD <: AutoZygote}
101+
update_coefficients!(L.f, u, p, t)
102+
copy!(L.u, u)
103+
L.pullback = Zygote.pullback(L.f, u)
104+
L
105+
end
106+
107+
# Interpret the call as df/du' * v
108+
function (L::AutoDiffVJP{AD})(v, p, t) where{AD <: AutoZygote}
109+
110+
y, back = L.pullback
111+
V = reshape(v, size(y))
112+
113+
back(V)[1] |> vec
114+
end
115+
116+
# prefer non in-place method
117+
function (L::AutoDiffVJP{AD, IIP, true})(dv, v, p, t) where {AD <: AutoZygote, IIP}
118+
_dv = L(v, p, t)
119+
copy!(dv, _dv)
120+
end
121+
122+
function (L::AutoDiffVJP{AD, true, false})(dv, v, p, t) where {AD <: AutoZygote}
123+
SparseDiffTools.auto_vecjac!(dv, L.f, L.u, v, L.cache...)
86124
end
87125

88126
end # module

src/differentiation/vecjac_products.jl

Lines changed: 37 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,35 @@ end
3737

3838
### Operator Forms
3939

40-
struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffVecProd
40+
"""
41+
VecJac(f, u, [p, t]; autodiff = AutoFiniteDiff())
42+
"""
43+
function VecJac(f, u::AbstractArray, p = nothing, t = nothing;
44+
autodiff = AutoFiniteDiff(), kwargs...)
45+
46+
L = _vecjac(f, u, autodiff)
47+
IIP, OOP = get_iip_oop(L)
48+
49+
FunctionOperator(L, u, u; isinplace = IIP, outofplace = OOP,
50+
p = p, t = t, islinear = true, kwargs...)
51+
end
52+
53+
function _vecjac(f, u, autodiff::AutoFiniteDiff)
54+
55+
cache = (similar(u), similar(u))
56+
pullback = nothing
57+
58+
AutoDiffVJP(f, u, cache, autodiff, pullback)
59+
end
60+
61+
mutable struct AutoDiffVJP{AD, IIP, OOP, F, U, C, PB} <: AbstractAutoDiffVecProd
4162
f::F
4263
u::U
4364
cache::C
44-
vecprod::V
45-
vecprod!::V!
46-
autodiff::ad
65+
autodiff::AD
66+
pullback::PB
4767

48-
function RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!, autodiff)
68+
function AutoDiffVJP(f, u, cache, autodiff, pullback)
4969

5070
outofplace = static_hasmethod(f, typeof((u,)))
5171
isinplace = static_hasmethod(f, typeof((u, u)))
@@ -62,83 +82,45 @@ struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffV
6282
typeof(f),
6383
typeof(u),
6484
typeof(cache),
65-
typeof(vecprod),
66-
typeof(vecprod!)
85+
typeof(pullback),
6786
}(
68-
f, u, cache, vecprod, vecprod!, autodiff,
87+
f, u, cache, autodiff, pullback,
6988
)
7089
end
7190
end
7291

73-
function get_iip_oop(::RevModeAutoDiffVecProd{ad, iip, oop}) where{ad, iip, oop}
74-
iip, oop
92+
function get_iip_oop(::AutoDiffVJP{AD, IIP, OOP}) where{AD, IIP, OOP}
93+
IIP, OOP
7594
end
7695

77-
function update_coefficients(L::RevModeAutoDiffVecProd, u, p, t)
96+
function update_coefficients(L::AutoDiffVJP{AD}, u, p, t) where{AD <: AutoFiniteDiff}
7897
@set! L.f = update_coefficients(L.f, u, p, t)
7998
@set! L.u = u
8099
end
81100

82-
function update_coefficients!(L::RevModeAutoDiffVecProd, u, p, t)
101+
function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t) where{AD <: AutoFiniteDiff}
83102
update_coefficients!(L.f, u, p, t)
84103
copy!(L.u, u)
85104
L
86105
end
87106

88-
# Interpret the call as df/du' * u
89-
function (L::RevModeAutoDiffVecProd)(v, p, t)
90-
L.vecprod(L.f, L.u, v)
107+
# Interpret the call as df/du' * v
108+
function (L::AutoDiffVJP{AD})(v, p, t) where{AD <: AutoFiniteDiff}
109+
num_vecjac(L.f, L.u, v)
91110
end
92111

93-
# prefer non in-place method
94-
function (L::RevModeAutoDiffVecProd{ad, iip, true})(dv, v, p, t) where {ad, iip}
95-
L.vecprod!(dv, L.f, L.u, v, L.cache...)
112+
function (L::AutoDiffVJP{AD})(dv, v, p, t) where{AD <: AutoFiniteDiff}
113+
num_vecjac!(dv, L.f, L.u, v, L.cache...)
96114
end
97115

98-
function (L::RevModeAutoDiffVecProd{ad, true, false})(dv, v, p, t) where {ad}
99-
L.vecprod!(dv, L.f, L.u, v, L.cache...)
100-
end
101-
102-
function Base.resize!(L::RevModeAutoDiffVecProd, n::Integer)
116+
function Base.resize!(L::AutoDiffVJP, n::Integer)
103117

104118
static_hasmethod(resize!, typeof((L.f, n))) && resize!(L.f, n)
105119
resize!(L.u, n)
106120

107121
for v in L.cache
108122
resize!(v, n)
109123
end
110-
end
111-
112-
"""
113-
VecJac(f, u, [p, t]; autodiff = AutoFiniteDiff())
114-
115-
Returns FunctionOperator that computes
116-
"""
117-
function VecJac(f, u::AbstractArray, p = nothing, t = nothing;
118-
autodiff = AutoFiniteDiff(), kwargs...)
119-
120-
vecprod, vecprod!, cache = if autodiff isa AutoFiniteDiff
121-
num_vecjac, num_vecjac!, (similar(u), similar(u))
122-
elseif autodiff isa AutoZygote
123-
@assert static_hasmethod(auto_vecjac, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
124-
125-
auto_vecjac, auto_vecjac!, ()
126-
end
127-
128-
L = RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!, autodiff)
129-
130-
iip, oop = get_iip_oop(L)
131-
132-
FunctionOperator(L, u, u; isinplace = iip, outofplace = oop,
133-
p = p, t = t, islinear = true, kwargs...)
134-
end
135-
136-
137-
function FixedVecJac(f, u::AbstractArray, p = nothing, t = nothing;
138-
autodiff = AutoFiniteDiff(), kwargs...)
139-
_fixedvecjac(f, u, p, t, autodiff, kwargs)
140-
end
141124

142-
function _fixedvecjac(f, u, p, t, ad::AutoFiniteDiff, kwargs)
143125
end
144126
#

0 commit comments

Comments
 (0)