@@ -223,24 +223,25 @@ function (L::FwdModeAutoDiffVecProd)(dv, v, p, t)
223223 L. vecprod! (dv, L. f, L. u, v, L. cache... )
224224end
225225
226- function JacVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
226+ function JacVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = AutoForwardDiff () )
227227
228- if autodiff
228+ cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
229+ cache1 = similar (u)
230+ cache2 = similar (u)
231+
232+ (cache1, cache2), num_jacvec, num_jacvec!
233+ elseif autodiff isa AutoForwardDiff
229234 cache1 = Dual{
230235 typeof (ForwardDiff. Tag (DeivVecTag (),eltype (u))), eltype (u), 1
231236 }. (u, ForwardDiff. Partials .(tuple .(u)))
232237
233238 cache2 = copy (cache1)
239+
240+ (cache1, cache2), auto_jacvec, auto_jacvec!
234241 else
235- cache1 = similar (u)
236- cache2 = similar (u)
242+ @error (" Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()" )
237243 end
238244
239- cache = (cache1, cache2,)
240-
241- vecprod = autodiff ? auto_jacvec : num_jacvec
242- vecprod! = autodiff ? auto_jacvec! : num_jacvec!
243-
244245 outofplace = static_hasmethod (f, typeof ((u,)))
245246 isinplace = static_hasmethod (f, typeof ((u, u,)))
246247
@@ -256,22 +257,32 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
256257 )
257258end
258259
259- function HesVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
260+ function HesVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = AutoForwardDiff () )
260261
261- if autodiff
262- cache1 = ForwardDiff . GradientConfig (f, u)
262+ cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
263+ cache1 = similar ( u)
263264 cache2 = similar (u)
264265 cache3 = similar (u)
265- else
266- cache1 = similar (u)
266+
267+ (cache1, cache2, cache3), num_hesvec, num_hesvec!
268+ elseif autodiff isa AutoForwardDiff
269+ cache1 = ForwardDiff. GradientConfig (f, u)
267270 cache2 = similar (u)
268271 cache3 = similar (u)
269- end
270272
271- cache = (cache1, cache2, cache3,)
273+ (cache1, cache2, cache3), numauto_hesvec, numauto_hesvec!
274+ elseif autodiff isa AutoZygote
275+ @assert static_hasmethod (autoback_hesvec, typeof ((f, u, u))) " To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
272276
273- vecprod = autodiff ? numauto_hesvec : num_hesvec
274- vecprod! = autodiff ? numauto_hesvec! : num_hesvec!
277+ cache1 = Dual{
278+ typeof (ForwardDiff. Tag (DeivVecTag (),eltype (u))), eltype (u), 1
279+ }. (u, ForwardDiff. Partials .(tuple .(u)))
280+ cache2 = copy (u)
281+
282+ (cache1, cache2), autoback_hesvec, autoback_hesvec!
283+ else
284+ @error (" Set autodiff to either AutoForwardDiff(), AutoZygote(), or AutoFiniteDiff()" )
285+ end
275286
276287 outofplace = static_hasmethod (f, typeof ((u,)))
277288 isinplace = static_hasmethod (f, typeof ((u,)))
@@ -288,24 +299,24 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
288299 )
289300end
290301
291- function HesVecGrad (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
302+ function HesVecGrad (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = AutoForwardDiff () )
292303
293- if autodiff
304+ cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
305+ cache1 = similar (u)
306+ cache2 = similar (u)
307+
308+ (cache1, cache2), num_hesvecgrad, num_hesvecgrad!
309+ elseif autodiff isa AutoForwardDiff
294310 cache1 = Dual{
295311 typeof (ForwardDiff. Tag (DeivVecTag (), eltype (u))), eltype (u), 1
296312 }. (u, ForwardDiff. Partials .(tuple .(u)))
297-
298313 cache2 = copy (cache1)
314+
315+ (cache1, cache2), auto_hesvecgrad, auto_hesvecgrad!
299316 else
300- cache1 = similar (u)
301- cache2 = similar (u)
317+ @error (" Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()" )
302318 end
303319
304- cache = (cache1, cache2,)
305-
306- vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad
307- vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad!
308-
309320 outofplace = static_hasmethod (f, typeof ((u,)))
310321 isinplace = static_hasmethod (f, typeof ((u, u,)))
311322
0 commit comments