3737
3838# ## Operator Forms
3939
40- struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffVecProd
40+ """
41+ VecJac(f, u, [p, t]; autodiff = AutoFiniteDiff())
42+ """
43+ function VecJac (f, u:: AbstractArray , p = nothing , t = nothing ;
44+ autodiff = AutoFiniteDiff (), kwargs... )
45+
46+ L = _vecjac (f, u, autodiff)
47+ IIP, OOP = get_iip_oop (L)
48+
49+ FunctionOperator (L, u, u; isinplace = IIP, outofplace = OOP,
50+ p = p, t = t, islinear = true , kwargs... )
51+ end
52+
53+ function _vecjac (f, u, autodiff:: AutoFiniteDiff )
54+
55+ cache = (similar (u), similar (u))
56+ pullback = nothing
57+
58+ AutoDiffVJP (f, u, cache, autodiff, pullback)
59+ end
60+
61+ mutable struct AutoDiffVJP{AD, IIP, OOP, F, U, C, PB} <: AbstractAutoDiffVecProd
4162 f:: F
4263 u:: U
4364 cache:: C
44- vecprod:: V
45- vecprod!:: V!
46- autodiff:: ad
65+ autodiff:: AD
66+ pullback:: PB
4767
48- function RevModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!, autodiff )
68+ function AutoDiffVJP (f, u, cache, autodiff, pullback )
4969
5070 outofplace = static_hasmethod (f, typeof ((u,)))
5171 isinplace = static_hasmethod (f, typeof ((u, u)))
@@ -62,83 +82,45 @@ struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffV
6282 typeof (f),
6383 typeof (u),
6484 typeof (cache),
65- typeof (vecprod),
66- typeof (vecprod!)
85+ typeof (pullback),
6786 }(
68- f, u, cache, vecprod, vecprod!, autodiff ,
87+ f, u, cache, autodiff, pullback ,
6988 )
7089 end
7190end
7291
73- function get_iip_oop (:: RevModeAutoDiffVecProd{ad, iip, oop } ) where {ad, iip, oop }
74- iip, oop
92+ function get_iip_oop (:: AutoDiffVJP{AD, IIP, OOP } ) where {AD, IIP, OOP }
93+ IIP, OOP
7594end
7695
77- function update_coefficients (L:: RevModeAutoDiffVecProd , u, p, t)
96+ function update_coefficients (L:: AutoDiffVJP{AD} , u, p, t) where {AD <: AutoFiniteDiff }
7897 @set! L. f = update_coefficients (L. f, u, p, t)
7998 @set! L. u = u
8099end
81100
82- function update_coefficients! (L:: RevModeAutoDiffVecProd , u, p, t)
101+ function update_coefficients! (L:: AutoDiffVJP{AD} , u, p, t) where {AD <: AutoFiniteDiff }
83102 update_coefficients! (L. f, u, p, t)
84103 copy! (L. u, u)
85104 L
86105end
87106
88- # Interpret the call as df/du' * u
89- function (L:: RevModeAutoDiffVecProd )(v, p, t)
90- L . vecprod (L. f, L. u, v)
107+ # Interpret the call as df/du' * v
108+ function (L:: AutoDiffVJP{AD} )(v, p, t) where {AD <: AutoFiniteDiff }
109+ num_vecjac (L. f, L. u, v)
91110end
92111
93- # prefer non in-place method
94- function (L:: RevModeAutoDiffVecProd{ad, iip, true} )(dv, v, p, t) where {ad, iip}
95- L. vecprod! (dv, L. f, L. u, v, L. cache... )
112+ function (L:: AutoDiffVJP{AD} )(dv, v, p, t) where {AD <: AutoFiniteDiff }
113+ num_vecjac! (dv, L. f, L. u, v, L. cache... )
96114end
97115
98- function (L:: RevModeAutoDiffVecProd{ad, true, false} )(dv, v, p, t) where {ad}
99- L. vecprod! (dv, L. f, L. u, v, L. cache... )
100- end
101-
102- function Base. resize! (L:: RevModeAutoDiffVecProd , n:: Integer )
116+ function Base. resize! (L:: AutoDiffVJP , n:: Integer )
103117
104118 static_hasmethod (resize!, typeof ((L. f, n))) && resize! (L. f, n)
105119 resize! (L. u, n)
106120
107121 for v in L. cache
108122 resize! (v, n)
109123 end
110- end
111-
112- """
113- VecJac(f, u, [p, t]; autodiff = AutoFiniteDiff())
114-
115- Returns FunctionOperator that computes
116- """
117- function VecJac (f, u:: AbstractArray , p = nothing , t = nothing ;
118- autodiff = AutoFiniteDiff (), kwargs... )
119-
120- vecprod, vecprod!, cache = if autodiff isa AutoFiniteDiff
121- num_vecjac, num_vecjac!, (similar (u), similar (u))
122- elseif autodiff isa AutoZygote
123- @assert static_hasmethod (auto_vecjac, typeof ((f, u, u))) " To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
124-
125- auto_vecjac, auto_vecjac!, ()
126- end
127-
128- L = RevModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!, autodiff)
129-
130- iip, oop = get_iip_oop (L)
131-
132- FunctionOperator (L, u, u; isinplace = iip, outofplace = oop,
133- p = p, t = t, islinear = true , kwargs... )
134- end
135-
136-
137- function FixedVecJac (f, u:: AbstractArray , p = nothing , t = nothing ;
138- autodiff = AutoFiniteDiff (), kwargs... )
139- _fixedvecjac (f, u, p, t, autodiff, kwargs)
140- end
141124
142- function _fixedvecjac (f, u, p, t, ad:: AutoFiniteDiff , kwargs)
143125end
144126#
0 commit comments