@@ -198,112 +198,105 @@ end
198198
199199# ## Operator Forms
200200
201- struct JacVec {F, T1, T2, xType}
201+ mutable struct FwdModeAutoDiffVecProd {F,U,C,V,V!} <: AbstractAutoDiffVecProd
202202 f:: F
203- cache1 :: T1
204- cache2 :: T2
205- x :: xType
206- autodiff :: Bool
203+ u :: U
204+ cache :: C
205+ vecprod :: V
206+ vecprod! :: V!
207207end
208208
209- function JacVec (f, x:: AbstractArray , tag = DeivVecTag (); autodiff = true )
210- if autodiff
211- cache1 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
212- }. (x, ForwardDiff. Partials .(tuple .(x)))
213- cache2 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
214- }. (x, ForwardDiff. Partials .(tuple .(x)))
215- else
216- cache1 = similar (x)
217- cache2 = similar (x)
218- end
219- JacVec (f, cache1, cache2, x, autodiff)
209+ function update_coefficients (L:: FwdModeAutoDiffVecProd , u, p, t)
210+ FwdModeAutoDiffVecProd (L. f, u, L. vecprod, L. vecprod!, L. cache)
220211end
221212
222- Base. eltype (L:: JacVec ) = eltype (L. x)
223- Base. size (L:: JacVec ) = (length (L. cache1), length (L. cache1))
224- Base. size (L:: JacVec , i:: Int ) = length (L. cache1)
225- function Base.:* (L:: JacVec , v:: AbstractVector )
226- L. autodiff ? auto_jacvec (_x -> L. f (_x), L. x, v) :
227- num_jacvec (_x -> L. f (_x), L. x, v)
213+ function update_coefficients! (L:: FwdModeAutoDiffVecProd , u, p, t)
214+ L. u .= u
215+ L
228216end
229217
230- function LinearAlgebra. mul! (dy:: AbstractVector , L:: JacVec , v:: AbstractVector )
231- if L. autodiff
232- auto_jacvec! (dy, (_y, _x) -> L. f (_y, _x), L. x, v, L. cache1, L. cache2)
233- else
234- num_jacvec! (dy, (_y, _x) -> L. f (_y, _x), L. x, v, L. cache1, L. cache2)
235- end
218+ function (L:: FwdModeAutoDiffVecProd )(v, p, t)
219+ L. vecprod (L. f, L. u, v)
236220end
237221
238- struct HesVec{F, T1, T2, xType}
239- f:: F
240- cache1:: T1
241- cache2:: T2
242- cache3:: T2
243- x:: xType
244- autodiff:: Bool
222+ function (L:: FwdModeAutoDiffVecProd )(dv, v, p, t)
223+ L. vecprod! (dv, L. f, L. u, v, L. cache... )
245224end
246225
247- function HesVec (f, x:: AbstractArray ; autodiff = true )
226+ function JacVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
227+
248228 if autodiff
249- cache1 = ForwardDiff. GradientConfig (f, x)
250- cache2 = similar (x)
251- cache3 = similar (x)
229+ cache1 = Dual{
230+ typeof (ForwardDiff. Tag (DeivVecTag (),eltype (u))), eltype (u), 1
231+ }. (u, ForwardDiff. Partials .(tuple .(u)))
232+
233+ cache2 = copy (cache1)
252234 else
253- cache1 = similar (x)
254- cache2 = similar (x)
255- cache3 = similar (x)
235+ cache1 = similar (u)
236+ cache2 = similar (u)
256237 end
257- HesVec (f, cache1, cache2, cache3, x, autodiff)
258- end
259238
260- Base. size (L:: HesVec ) = (length (L. cache2), length (L. cache2))
261- Base. size (L:: HesVec , i:: Int ) = length (L. cache2)
262- function Base.:* (L:: HesVec , v:: AbstractVector )
263- L. autodiff ? numauto_hesvec (L. f, L. x, v) : num_hesvec (L. f, L. x, v)
264- end
239+ cache = (cache1, cache2,)
265240
266- function LinearAlgebra. mul! (dy:: AbstractVector , L:: HesVec , v:: AbstractVector )
267- if L. autodiff
268- numauto_hesvec! (dy, L. f, L. x, v, L. cache1, L. cache2, L. cache3)
269- else
270- num_hesvec! (dy, L. f, L. x, v, L. cache1, L. cache2, L. cache3)
271- end
272- end
241+ vecprod = autodiff ? auto_jacvec : num_jacvec
242+ vecprod! = autodiff ? auto_jacvec! : num_jacvec!
273243
274- struct HesVecGrad{G, T1, T2, uType}
275- g :: G
276- cache1 :: T1
277- cache2 :: T2
278- x :: uType
279- autodiff :: Bool
244+ L = FwdModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!)
245+
246+ FunctionOperator (L, u, u;
247+ isinplace = true , outofplace = true ,
248+ p = p, t = t, islinear = true ,
249+ )
280250end
281251
282- function HesVecGrad (g, x:: AbstractArray , tag = DeivVecTag (); autodiff = false )
252+ function HesVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
253+
283254 if autodiff
284- cache1 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
285- }. (x, ForwardDiff. Partials .(tuple .(x)))
286- cache2 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
287- }. (x, ForwardDiff. Partials .(tuple .(x)))
255+ cache1 = ForwardDiff. GradientConfig (f, u)
256+ cache2 = similar (u)
257+ cache3 = similar (u)
288258 else
289- cache1 = similar (x)
290- cache2 = similar (x)
259+ cache1 = similar (u)
260+ cache2 = similar (u)
261+ cache3 = similar (u)
291262 end
292- HesVecGrad (g, cache1, cache2, x, autodiff)
293- end
294263
295- Base. size (L:: HesVecGrad ) = (length (L. cache2), length (L. cache2))
296- Base. size (L:: HesVecGrad , i:: Int ) = length (L. cache2)
297- function Base.:* (L:: HesVecGrad , v:: AbstractVector )
298- L. autodiff ? auto_hesvecgrad (L. g, L. x, v) : num_hesvecgrad (L. g, L. x, v)
264+ cache = (cache1, cache2, cache3,)
265+
266+ vecprod = autodiff ? numauto_hesvec : num_hesvec
267+ vecprod! = autodiff ? numauto_hesvec! : num_hesvec!
268+
269+ L = FwdModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!)
270+
271+ FunctionOperator (L, u, u;
272+ isinplace = true , outofplace = true ,
273+ p = p, t = t, islinear = true ,
274+ )
299275end
300276
301- function LinearAlgebra. mul! (dy:: AbstractVector ,
302- L:: HesVecGrad ,
303- v:: AbstractVector )
304- if L. autodiff
305- auto_hesvecgrad! (dy, L. g, L. x, v, L. cache1, L. cache2)
277+ function HesVecGrad (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
278+
279+ if autodiff
280+ cache1 = Dual{
281+ typeof (ForwardDiff. Tag (DeivVecTag (), eltype (u))), eltype (u), 1
282+ }. (u, ForwardDiff. Partials .(tuple .(u)))
283+
284+ cache2 = copy (cache1)
306285 else
307- num_hesvecgrad! (dy, L. g, L. x, v, L. cache1, L. cache2)
286+ cache1 = similar (u)
287+ cache2 = similar (u)
308288 end
289+
290+ cache = (cache1, cache2,)
291+
292+ vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad
293+ vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad!
294+
295+ L = FwdModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!)
296+
297+ FunctionOperator (L, u, u;
298+ isinplace = true , outofplace = true ,
299+ p = p, t = t, islinear = true ,
300+ )
309301end
302+ #
0 commit comments