@@ -3,19 +3,13 @@ module SparseDiffToolsZygote
33if isdefined (Base, :get_extension )
44 import Zygote
55 using LinearAlgebra
6- using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
6+ using SparseDiffTools: SparseDiffTools, DeivVecTag
77 using ForwardDiff: ForwardDiff, Dual, partials
8- using SciMLOperators: FunctionOperator
9- using Tricks: static_hasmethod
10- using ADTypes
118else
129 import .. Zygote
1310 using .. LinearAlgebra
14- using .. SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
11+ using .. SparseDiffTools: SparseDiffTools, DeivVecTag
1512 using .. ForwardDiff: ForwardDiff, Dual, partials
16- using .. SciMLOperators: FunctionOperator
17- using .. Tricks: static_hasmethod
18- using .. ADTypes
1913end
2014
2115# ## Jac, Hes products
@@ -71,39 +65,6 @@ function SparseDiffTools.autoback_hesvec(f, x, v)
7165 ForwardDiff. partials .(g (y), 1 )
7266end
7367
74- # Operator Forms
75-
76- function SparseDiffTools. ZygoteHesVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = AutoZygote ())
77-
78- cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
79- cache1 = similar (u)
80- cache2 = similar (u)
81-
82- (cache1, cache2), SparseDiffTools. numback_hesvec, SparseDiffTools. numback_hesvec!
83- elseif autodiff isa AutoZygote
84- cache1 = Dual{
85- typeof (ForwardDiff. Tag (DeivVecTag (),eltype (u))), eltype (u), 1
86- }. (u, ForwardDiff. Partials .(tuple .(u)))
87- cache2 = copy (u)
88-
89- (cache1, cache2), SparseDiffTools. autoback_hesvec, SparseDiffTools. autoback_hesvec!
90- end
91-
92- outofplace = static_hasmethod (f, typeof ((u,)))
93- isinplace = static_hasmethod (f, typeof ((u,)))
94-
95- if ! (isinplace) & ! (outofplace)
96- error (" $f must have signature f(u)." )
97- end
98-
99- L = FwdModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!)
100-
101- FunctionOperator (L, u, u;
102- isinplace = isinplace, outofplace = outofplace,
103- p = p, t = t, islinear = true ,
104- )
105- end
106-
10768# # VecJac products
10869
10970function SparseDiffTools. auto_vecjac! (du, f, x, v, cache1 = nothing , cache2 = nothing )
@@ -116,8 +77,4 @@ function SparseDiffTools.auto_vecjac(f, x, v)
11677 return vec (back (reshape (v, size (vv)))[1 ])
11778end
11879
119- function SparseDiffTools. ZygoteVecJac (args... ; kwargs... )
120- VecJac (args... ; autodiff = AutoZygote (), kwargs... )
121- end
122-
12380end # module
0 commit comments