@@ -198,112 +198,126 @@ end
198198
199199# ## Operator Forms
200200
201- struct JacVec {F, T1, T2, xType}
201+ struct FwdModeAutoDiffVecProd {F,U,C,V,V!} <: AbstractAutoDiffVecProd
202202 f:: F
203- cache1 :: T1
204- cache2 :: T2
205- x :: xType
206- autodiff :: Bool
203+ u :: U
204+ cache :: C
205+ vecprod :: V
206+ vecprod! :: V!
207207end
208208
209- function JacVec (f, x:: AbstractArray , tag = DeivVecTag (); autodiff = true )
210- if autodiff
211- cache1 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
212- }. (x, ForwardDiff. Partials .(tuple .(x)))
213- cache2 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
214- }. (x, ForwardDiff. Partials .(tuple .(x)))
215- else
216- cache1 = similar (x)
217- cache2 = similar (x)
218- end
219- JacVec (f, cache1, cache2, x, autodiff)
209+ function update_coefficients (L:: FwdModeAutoDiffVecProd , u, p, t)
210+ FwdModeAutoDiffVecProd (L. f, u, L. vecprod, L. vecprod!, L. cache)
220211end
221212
222- Base. eltype (L:: JacVec ) = eltype (L. x)
223- Base. size (L:: JacVec ) = (length (L. cache1), length (L. cache1))
224- Base. size (L:: JacVec , i:: Int ) = length (L. cache1)
225- function Base.:* (L:: JacVec , v:: AbstractVector )
226- L. autodiff ? auto_jacvec (_x -> L. f (_x), L. x, v) :
227- num_jacvec (_x -> L. f (_x), L. x, v)
213+ function update_coefficients! (L:: FwdModeAutoDiffVecProd , u, p, t)
214+ copy! (L. u, u)
215+ L
228216end
229217
230- function LinearAlgebra. mul! (dy:: AbstractVector , L:: JacVec , v:: AbstractVector )
231- if L. autodiff
232- auto_jacvec! (dy, (_y, _x) -> L. f (_y, _x), L. x, v, L. cache1, L. cache2)
233- else
234- num_jacvec! (dy, (_y, _x) -> L. f (_y, _x), L. x, v, L. cache1, L. cache2)
235- end
218+ function (L:: FwdModeAutoDiffVecProd )(v, p, t)
219+ L. vecprod (L. f, L. u, v)
236220end
237221
238- struct HesVec{F, T1, T2, xType}
239- f:: F
240- cache1:: T1
241- cache2:: T2
242- cache3:: T2
243- x:: xType
244- autodiff:: Bool
222+ function (L:: FwdModeAutoDiffVecProd )(dv, v, p, t)
223+ L. vecprod! (dv, L. f, L. u, v, L. cache... )
245224end
246225
247- function HesVec (f, x:: AbstractArray ; autodiff = true )
226+ function JacVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
227+
248228 if autodiff
249- cache1 = ForwardDiff. GradientConfig (f, x)
250- cache2 = similar (x)
251- cache3 = similar (x)
229+ cache1 = Dual{
230+ typeof (ForwardDiff. Tag (DeivVecTag (),eltype (u))), eltype (u), 1
231+ }. (u, ForwardDiff. Partials .(tuple .(u)))
232+
233+ cache2 = copy (cache1)
252234 else
253- cache1 = similar (x)
254- cache2 = similar (x)
255- cache3 = similar (x)
235+ cache1 = similar (u)
236+ cache2 = similar (u)
256237 end
257- HesVec (f, cache1, cache2, cache3, x, autodiff)
258- end
259238
260- Base. size (L:: HesVec ) = (length (L. cache2), length (L. cache2))
261- Base. size (L:: HesVec , i:: Int ) = length (L. cache2)
262- function Base.:* (L:: HesVec , v:: AbstractVector )
263- L. autodiff ? numauto_hesvec (L. f, L. x, v) : num_hesvec (L. f, L. x, v)
264- end
239+ cache = (cache1, cache2,)
265240
266- function LinearAlgebra. mul! (dy:: AbstractVector , L:: HesVec , v:: AbstractVector )
267- if L. autodiff
268- numauto_hesvec! (dy, L. f, L. x, v, L. cache1, L. cache2, L. cache3)
269- else
270- num_hesvec! (dy, L. f, L. x, v, L. cache1, L. cache2, L. cache3)
241+ vecprod = autodiff ? auto_jacvec : num_jacvec
242+ vecprod! = autodiff ? auto_jacvec! : num_jacvec!
243+
244+ outofplace = static_hasmethod (f, typeof ((u,)))
245+ isinplace = static_hasmethod (f, typeof ((u, u,)))
246+
247+ if ! (isinplace) & ! (outofplace)
248+ error (" $f must have signature f(u), or f(du, u)." )
271249 end
272- end
273250
274- struct HesVecGrad{G, T1, T2, uType}
275- g :: G
276- cache1 :: T1
277- cache2 :: T2
278- x :: uType
279- autodiff :: Bool
251+ L = FwdModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!)
252+
253+ FunctionOperator (L, u, u;
254+ isinplace = isinplace, outofplace = outofplace,
255+ p = p, t = t, islinear = true ,
256+ )
280257end
281258
282- function HesVecGrad (g, x:: AbstractArray , tag = DeivVecTag (); autodiff = false )
259+ function HesVec (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
260+
283261 if autodiff
284- cache1 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
285- }. (x, ForwardDiff. Partials .(tuple .(x)))
286- cache2 = Dual{typeof (ForwardDiff. Tag (tag, eltype (x))), eltype (x), 1
287- }. (x, ForwardDiff. Partials .(tuple .(x)))
262+ cache1 = ForwardDiff. GradientConfig (f, u)
263+ cache2 = similar (u)
264+ cache3 = similar (u)
288265 else
289- cache1 = similar (x)
290- cache2 = similar (x)
266+ cache1 = similar (u)
267+ cache2 = similar (u)
268+ cache3 = similar (u)
291269 end
292- HesVecGrad (g, cache1, cache2, x, autodiff)
293- end
294270
295- Base. size (L:: HesVecGrad ) = (length (L. cache2), length (L. cache2))
296- Base. size (L:: HesVecGrad , i:: Int ) = length (L. cache2)
297- function Base.:* (L:: HesVecGrad , v:: AbstractVector )
298- L. autodiff ? auto_hesvecgrad (L. g, L. x, v) : num_hesvecgrad (L. g, L. x, v)
271+ cache = (cache1, cache2, cache3,)
272+
273+ vecprod = autodiff ? numauto_hesvec : num_hesvec
274+ vecprod! = autodiff ? numauto_hesvec! : num_hesvec!
275+
276+ outofplace = static_hasmethod (f, typeof ((u,)))
277+ isinplace = static_hasmethod (f, typeof ((u,)))
278+
279+ if ! (isinplace) & ! (outofplace)
280+ error (" $f must have signature f(u)." )
281+ end
282+
283+ L = FwdModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!)
284+
285+ FunctionOperator (L, u, u;
286+ isinplace = isinplace, outofplace = outofplace,
287+ p = p, t = t, islinear = true ,
288+ )
299289end
300290
301- function LinearAlgebra. mul! (dy:: AbstractVector ,
302- L:: HesVecGrad ,
303- v:: AbstractVector )
304- if L. autodiff
305- auto_hesvecgrad! (dy, L. g, L. x, v, L. cache1, L. cache2)
291+ function HesVecGrad (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = true )
292+
293+ if autodiff
294+ cache1 = Dual{
295+ typeof (ForwardDiff. Tag (DeivVecTag (), eltype (u))), eltype (u), 1
296+ }. (u, ForwardDiff. Partials .(tuple .(u)))
297+
298+ cache2 = copy (cache1)
306299 else
307- num_hesvecgrad! (dy, L. g, L. x, v, L. cache1, L. cache2)
300+ cache1 = similar (u)
301+ cache2 = similar (u)
302+ end
303+
304+ cache = (cache1, cache2,)
305+
306+ vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad
307+ vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad!
308+
309+ outofplace = static_hasmethod (f, typeof ((u,)))
310+ isinplace = static_hasmethod (f, typeof ((u, u,)))
311+
312+ if ! (isinplace) & ! (outofplace)
313+ error (" $f must have signature f(u), or f(du, u)." )
308314 end
315+
316+ L = FwdModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!)
317+
318+ FunctionOperator (L, u, u;
319+ isinplace = isinplace, outofplace = outofplace,
320+ p = p, t = t, islinear = true ,
321+ )
309322end
323+ #
0 commit comments