@@ -4,7 +4,7 @@ if isdefined(Base, :get_extension)
44 import Zygote
55 using LinearAlgebra
66 using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
7- using ForwardDiff: ForwardDiff, Dual
7+ using ForwardDiff: ForwardDiff, Dual, partials
88 using SciMLOperators: FunctionOperator
99 using Tricks: static_hasmethod
1010else
1818
1919# ## Jac, Hes products
2020
21- function numback_hesvec! (dy, f, x, v, cache1 = similar (v), cache2 = similar (v))
21+ function SparseDiffTools . numback_hesvec! (dy, f, x, v, cache1 = similar (v), cache2 = similar (v))
2222 g = let f = f
2323 (dx, x) -> dx .= first (Zygote. gradient (f, x))
2424 end
@@ -32,7 +32,7 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
3232 @. dy = (cache1 - cache2) / (2 ϵ)
3333end
3434
35- function numback_hesvec (f, x, v)
35+ function SparseDiffTools . numback_hesvec (f, x, v)
3636 g = x -> first (Zygote. gradient (f, x))
3737 T = eltype (x)
3838 # Should it be min? max? mean?
@@ -44,7 +44,7 @@ function numback_hesvec(f, x, v)
4444 (gxp - gxm) / (2 ϵ)
4545end
4646
47- function autoback_hesvec! (dy, f, x, v,
47+ function SparseDiffTools . autoback_hesvec! (dy, f, x, v,
4848 cache1 = Dual{typeof (ForwardDiff. Tag (DeivVecTag, eltype (x))),
4949 eltype (x), 1
5050 }. (x,
@@ -62,7 +62,7 @@ function autoback_hesvec!(dy, f, x, v,
6262 dy .= partials .(cache2, 1 )
6363end
6464
65- function autoback_hesvec (f, x, v)
65+ function SparseDiffTools . autoback_hesvec (f, x, v)
6666 g = x -> first (Zygote. gradient (f, x))
6767 y = Dual{typeof (ForwardDiff. Tag (DeivVecTag, eltype (x))), eltype (x), 1
6868 }. (x, ForwardDiff. Partials .(Tuple .(reshape (v, size (x)))))
7171
7272# Operator Forms
7373
74- function ZygoteHesVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
74+ function SparseDiffTools . ZygoteHesVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
7575
7676 if autodiff
7777 cache1 = Dual{
@@ -85,8 +85,8 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff =
8585
8686 cache = (cache1, cache2,)
8787
88- vecprod = autodiff ? autoback_hesvec : numback_hesvec
89- vecprod! = autodiff ? autoback_hesvec! : numback_hesvec!
88+ vecprod = autodiff ? SparseDiffTools . autoback_hesvec : SparseDiffTools . numback_hesvec
89+ vecprod! = autodiff ? SparseDiffTools . autoback_hesvec! : SparseDiffTools . numback_hesvec!
9090
9191 outofplace = static_hasmethod (f, typeof ((u,)))
9292 isinplace = static_hasmethod (f, typeof ((u,)))
@@ -105,16 +105,18 @@ end
105105
106106# # VecJac products
107107
108- function auto_vecjac! (du, f, x, v, cache1 = nothing , cache2 = nothing )
108+ function SparseDiffTools . auto_vecjac! (du, f, x, v, cache1 = nothing , cache2 = nothing )
109109 ! hasmethod (f, (typeof (x),)) && error (" For inplace function use autodiff = false" )
110- du .= reshape (auto_vecjac (f, x, v), size (du))
110+ du .= reshape (SparseDiffTools . auto_vecjac (f, x, v), size (du))
111111end
112112
113- function auto_vecjac (f, x, v)
113+ function SparseDiffTools . auto_vecjac (f, x, v)
114114 vv, back = Zygote. pullback (f, x)
115115 return vec (back (reshape (v, size (vv)))[1 ])
116116end
117117
118- ZygoteVecJac (args... ; autodiff = true , kwargs... ) = VecJac (args... ; autodiff = autodiff, kwargs... )
118+ function SparseDiffTools. ZygoteVecJac (args... ; autodiff = true , kwargs... )
119+ VecJac (args... ; autodiff = autodiff, kwargs... )
120+ end
119121
120122end # module
0 commit comments