1- function numback_hesvec! (dy, f, x, v, cache1 = similar (v), cache2 = similar (v))
1+ module SparseDiffToolsZygote
2+
3+ if isdefined (Base, :get_extension )
4+ import Zygote
5+ using LinearAlgebra
6+ using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
7+ using ForwardDiff: ForwardDiff, Dual, partials
8+ using SciMLOperators: FunctionOperator
9+ using Tricks: static_hasmethod
10+ else
11+ import .. Zygote
12+ using .. LinearAlgebra
13+ using .. SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
14+ using .. ForwardDiff: ForwardDiff, Dual, partials
15+ using .. SciMLOperators: FunctionOperator
16+ using .. Tricks: static_hasmethod
17+ end
18+
19+ # ## Jac, Hes products
20+
21+ function SparseDiffTools. numback_hesvec! (dy, f, x, v, cache1 = similar (v), cache2 = similar (v))
222 g = let f = f
323 (dx, x) -> dx .= first (Zygote. gradient (f, x))
424 end
@@ -12,7 +32,7 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
1232 @. dy = (cache1 - cache2) / (2 ϵ)
1333end
1434
15- function numback_hesvec (f, x, v)
35+ function SparseDiffTools . numback_hesvec (f, x, v)
1636 g = x -> first (Zygote. gradient (f, x))
1737 T = eltype (x)
1838 # Should it be min? max? mean?
@@ -24,7 +44,7 @@ function numback_hesvec(f, x, v)
2444 (gxp - gxm) / (2 ϵ)
2545end
2646
27- function autoback_hesvec! (dy, f, x, v,
47+ function SparseDiffTools . autoback_hesvec! (dy, f, x, v,
2848 cache1 = Dual{typeof (ForwardDiff. Tag (DeivVecTag, eltype (x))),
2949 eltype (x), 1
3050 }. (x,
@@ -42,16 +62,16 @@ function autoback_hesvec!(dy, f, x, v,
4262 dy .= partials .(cache2, 1 )
4363end
4464
45- function autoback_hesvec (f, x, v)
65+ function SparseDiffTools . autoback_hesvec (f, x, v)
4666 g = x -> first (Zygote. gradient (f, x))
4767 y = Dual{typeof (ForwardDiff. Tag (DeivVecTag, eltype (x))), eltype (x), 1
4868 }. (x, ForwardDiff. Partials .(Tuple .(reshape (v, size (x)))))
4969 ForwardDiff. partials .(g (y), 1 )
5070end
5171
52- # ## Operator Forms
72+ # Operator Forms
5373
54- function ZygoteHesVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
74+ function SparseDiffTools . ZygoteHesVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
5575
5676 if autodiff
5777 cache1 = Dual{
@@ -65,8 +85,8 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff =
6585
6686 cache = (cache1, cache2,)
6787
68- vecprod = autodiff ? autoback_hesvec : numback_hesvec
69- vecprod! = autodiff ? autoback_hesvec! : numback_hesvec!
88+ vecprod = autodiff ? SparseDiffTools . autoback_hesvec : SparseDiffTools . numback_hesvec
89+ vecprod! = autodiff ? SparseDiffTools . autoback_hesvec! : SparseDiffTools . numback_hesvec!
7090
7191 outofplace = static_hasmethod (f, typeof ((u,)))
7292 isinplace = static_hasmethod (f, typeof ((u,)))
@@ -82,4 +102,21 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff =
82102 p = p, t = t, islinear = true ,
83103 )
84104end
85- #
105+
106+ # # VecJac products
107+
108+ function SparseDiffTools. auto_vecjac! (du, f, x, v, cache1 = nothing , cache2 = nothing )
109+ ! hasmethod (f, (typeof (x),)) && error (" For inplace function use autodiff = false" )
110+ du .= reshape (SparseDiffTools. auto_vecjac (f, x, v), size (du))
111+ end
112+
113+ function SparseDiffTools. auto_vecjac (f, x, v)
114+ vv, back = Zygote. pullback (f, x)
115+ return vec (back (reshape (v, size (vv)))[1 ])
116+ end
117+
118+ function SparseDiffTools. ZygoteVecJac (args... ; autodiff = true , kwargs... )
119+ VecJac (args... ; autodiff = autodiff, kwargs... )
120+ end
121+
122+ end # module
0 commit comments