@@ -25,21 +25,21 @@ function numback_hesvec(f, x, v)
2525end
2626
2727function autoback_hesvec! (dy, f, x, v,
28- cache2 = Dual{typeof (ForwardDiff. Tag (DeivVecTag, eltype (x))),
28+ cache1 = Dual{typeof (ForwardDiff. Tag (DeivVecTag, eltype (x))),
2929 eltype (x), 1
3030 }. (x,
3131 ForwardDiff. Partials .(Tuple .(reshape (v, size (x))))),
32- cache3 = Dual{typeof (ForwardDiff. Tag (DeivVecTag, eltype (x))),
32+ cache2 = Dual{typeof (ForwardDiff. Tag (DeivVecTag, eltype (x))),
3333 eltype (x), 1
3434 }. (x,
3535 ForwardDiff. Partials .(Tuple .(reshape (v, size (x))))))
3636 g = let f = f
3737 (dx, x) -> dx .= first (Zygote. gradient (f, x))
3838 end
39- cache2 .= Dual{typeof (ForwardDiff. Tag (DeivVecTag, eltype (x))), eltype (x), 1
39+ cache1 .= Dual{typeof (ForwardDiff. Tag (DeivVecTag, eltype (x))), eltype (x), 1
4040 }. (x, ForwardDiff. Partials .(Tuple .(reshape (v, size (x)))))
41- g (cache3, cache2 )
42- dy .= partials .(cache3 , 1 )
41+ g (cache2, cache1 )
42+ dy .= partials .(cache2 , 1 )
4343end
4444
4545function autoback_hesvec (f, x, v)
@@ -48,3 +48,38 @@ function autoback_hesvec(f, x, v)
4848 }. (x, ForwardDiff. Partials .(Tuple .(reshape (v, size (x)))))
4949 ForwardDiff. partials .(g (y), 1 )
5050end
51+
52+ # ## Operator Forms
53+
54+ function ZygoteHesVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
55+
56+ if autodiff
57+ cache1 = Dual{
58+ typeof (ForwardDiff. Tag (DeivVecTag (),eltype (u))), eltype (u), 1
59+ }. (u, ForwardDiff. Partials .(tuple .(u)))
60+ cache2 = copy (u)
61+ else
62+ cache1 = similar (u)
63+ cache2 = similar (u)
64+ end
65+
66+ cache = (cache1, cache2,)
67+
68+ vecprod = autodiff ? autoback_hesvec : numback_hesvec
69+ vecprod! = autodiff ? autoback_hesvec! : numback_hesvec!
70+
71+ outofplace = static_hasmethod (f, typeof ((u,)))
72+ isinplace = static_hasmethod (f, typeof ((u,)))
73+
74+ if ! (isinplace) & ! (outofplace)
75+ error (" $f must have signature f(u)." )
76+ end
77+
78+ L = FwdModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!)
79+
80+ FunctionOperator (L, u, u;
81+ isinplace = isinplace, outofplace = outofplace,
82+ p = p, t = t, islinear = true ,
83+ )
84+ end
85+ #
0 commit comments