3737
3838# ## Operator Forms
3939
40- struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffVecProd
40+ """
41+ VecJac(f, u, [p, t]; autodiff = AutoFiniteDiff())
42+
43+ Returns SciMLOperators.FunctionOperator which computes vector-jacobian
44+ product `df/du * v`.
45+
46+ ```
47+ L = VecJac(f, u)
48+
49+ L * v # = df/du * v
50+ mul!(w, L, v) # = df/du * v
51+
52+ L(v, p, t; VJP_input = w) # = df/dw * v
53+ L(x, v, p, t; VJP_input = w) # = df/dw * v
54+ ```
55+ """
56+ function VecJac (f, u:: AbstractArray , p = nothing , t = nothing ;
57+ autodiff = AutoFiniteDiff (), kwargs... )
58+
59+ L = _vecjac (f, u, autodiff)
60+ IIP, OOP = get_iip_oop (L)
61+
62+ if isa (autodiff, AutoZygote) & ! OOP
63+ msg = " Zygote requires an out of place method with signature f(u)."
64+ throw (ArgumentError (msg))
65+ end
66+
67+ FunctionOperator (L, u, u; isinplace = IIP, outofplace = OOP,
68+ p = p, t = t, islinear = true ,
69+ accepted_kwargs = (:VJP_input ,), kwargs... )
70+ end
71+
72+ function _vecjac (f, u, autodiff:: AutoFiniteDiff )
73+
74+ cache = (similar (u), similar (u))
75+ pullback = nothing
76+
77+ AutoDiffVJP (f, u, cache, autodiff, pullback)
78+ end
79+
80+ mutable struct AutoDiffVJP{AD, IIP, OOP, F, U, C, PB} <: AbstractAutoDiffVecProd
81+ """ Compute VJP of `f` at `u`, applied to vector `v`: `df/du' * u` """
4182 f:: F
83+ """ input to `f` """
4284 u:: U
85+ """ Cache for num_vecjac! when autodiff isa AutoFintieDiff """
4386 cache:: C
44- vecprod:: V
45- vecprod!:: V!
87+ """ Type of automatic differentiation algorithm """
88+ autodiff:: AD
89+ """ stores the result of Zygote.pullback for AutoZygote """
90+ pullback:: PB
91+
92+ function AutoDiffVJP (f, u, cache, autodiff, pullback)
4693
47- function RevModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!;
48- autodiff = AutoFiniteDiff (),
49- isinplace = false , outofplace = true )
50- @assert isinplace || outofplace
94+ outofplace = static_hasmethod (f, typeof ((u,)))
95+ isinplace = static_hasmethod (f, typeof ((u, u)))
96+
97+ if ! (isinplace) & ! (outofplace)
98+ msg = " $f must have signature f(u), or f(du, u)"
99+ throw (ArgumentError (msg))
100+ end
51101
52102 new{
53103 typeof (autodiff),
@@ -56,72 +106,58 @@ struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffV
56106 typeof (f),
57107 typeof (u),
58108 typeof (cache),
59- typeof (vecprod),
60- typeof (vecprod!)
61- }(f, u, cache, vecprod, vecprod!)
109+ typeof (pullback),
110+ }(
111+ f, u, cache, autodiff, pullback,
112+ )
62113 end
63114end
64115
65- function update_coefficients (L:: RevModeAutoDiffVecProd , u, p, t)
66- f = update_coefficients (L. f, u, p, t)
67- RevModeAutoDiffVecProd (f, u, L. vecprod, L. vecprod!, L. cache)
116+ function get_iip_oop (:: AutoDiffVJP{AD, IIP, OOP} ) where {AD, IIP, OOP}
117+ IIP, OOP
68118end
69119
70- function update_coefficients! (L:: RevModeAutoDiffVecProd , u, p, t)
71- update_coefficients! (L. f, u, p, t)
72- copy! (L. u, u)
73- L
120+ function update_coefficients (L:: AutoDiffVJP{AD} , u, p, t; VJP_input = nothing ,
121+ ) where {AD <: AutoFiniteDiff }
122+
123+ if ! isnothing (VJP_input)
124+ @set! L. u = VJP_input
125+ end
126+
127+ @set! L. f = update_coefficients (L. f, L. u, p, t)
74128end
75129
76- # Interpret the call as df/du' * u
77- function (L:: RevModeAutoDiffVecProd )(v, p, t)
78- L. vecprod (L. f, L. u, v)
130+ function update_coefficients! (L:: AutoDiffVJP{AD} , u, p, t; VJP_input = nothing ,
131+ ) where {AD <: AutoFiniteDiff }
132+
133+ if ! isnothing (VJP_input)
134+ copy! (L. u, VJP_input)
135+ end
136+
137+ update_coefficients! (L. f, L. u, p, t)
138+
139+ L
79140end
80141
81- # prefer non in-place method
82- function (L:: RevModeAutoDiffVecProd{ad, iip, true} )(dv, v, p, t) where {ad, iip}
83- L. vecprod! (dv, L. f, L. u, v, L. cache... )
142+ # Interpret the call as df/du' * v
143+ function (L:: AutoDiffVJP{AD} )(v, p, t; VJP_input = nothing ,) where {AD <: AutoFiniteDiff }
144+ # ignore VJP_input as L.u was set in update_coefficients(...)
145+ num_vecjac (L. f, L. u, v)
84146end
85147
86- function (L:: RevModeAutoDiffVecProd{ad, true, false} )(dv, v, p, t) where {ad}
87- L. vecprod! (dv, L. f, L. u, v, L. cache... )
148+ function (L:: AutoDiffVJP{AD} )(dv, v, p, t; VJP_input = nothing ,) where {AD <: AutoFiniteDiff }
149+ # ignore VJP_input as L.u was set in update_coefficients!(...)
150+ num_vecjac! (dv, L. f, L. u, v, L. cache... )
88151end
89152
90- function Base. resize! (L:: RevModeAutoDiffVecProd , n:: Integer )
153+ function Base. resize! (L:: AutoDiffVJP , n:: Integer )
91154
92155 static_hasmethod (resize!, typeof ((L. f, n))) && resize! (L. f, n)
93156 resize! (L. u, n)
94157
95158 for v in L. cache
96159 resize! (v, n)
97160 end
98- end
99-
100- function VecJac (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = AutoFiniteDiff (),
101- kwargs... )
102- vecprod, vecprod! = if autodiff isa AutoFiniteDiff
103- num_vecjac, num_vecjac!
104- elseif autodiff isa AutoZygote
105- @assert static_hasmethod (auto_vecjac, typeof ((f, u, u))) " To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
106161
107- auto_vecjac, auto_vecjac!
108- end
109-
110- cache = (similar (u), similar (u))
111-
112- outofplace = static_hasmethod (f, typeof ((u,)))
113- isinplace = static_hasmethod (f, typeof ((u, u)))
114-
115- if ! (isinplace) & ! (outofplace)
116- error (" $f must have signature f(u), or f(du, u)" )
117- end
118-
119- L = RevModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!; autodiff = autodiff,
120- isinplace = isinplace, outofplace = outofplace)
121-
122- FunctionOperator (L, u, u;
123- isinplace = isinplace, outofplace = outofplace,
124- p = p, t = t, islinear = true ,
125- kwargs... )
126162end
127163#
0 commit comments