@@ -43,11 +43,17 @@ struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffV
4343 cache:: C
4444 vecprod:: V
4545 vecprod!:: V!
46+ autodiff:: ad
4647
47- function RevModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!;
48- autodiff = AutoFiniteDiff (),
49- isinplace = false , outofplace = true )
50- @assert isinplace || outofplace
48+ function RevModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!, autodiff)
49+
50+ outofplace = static_hasmethod (f, typeof ((u,)))
51+ isinplace = static_hasmethod (f, typeof ((u, u)))
52+
53+ if ! (isinplace) & ! (outofplace)
54+ msg = " $f must have signature f(u), or f(du, u)"
55+ throw (ArgumentError (msg))
56+ end
5157
5258 new{
5359 typeof (autodiff),
@@ -58,13 +64,19 @@ struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffV
5864 typeof (cache),
5965 typeof (vecprod),
6066 typeof (vecprod!)
61- }(f, u, cache, vecprod, vecprod!)
67+ }(
68+ f, u, cache, vecprod, vecprod!, autodiff,
69+ )
6270 end
6371end
6472
73+ function get_iip_oop (:: RevModeAutoDiffVecProd{ad, iip, oop} ) where {ad, iip, oop}
74+ iip, oop
75+ end
76+
6577function update_coefficients (L:: RevModeAutoDiffVecProd , u, p, t)
66- f = update_coefficients (L. f, u, p, t)
67- RevModeAutoDiffVecProd (f, u, L . vecprod, L . vecprod!, L . cache)
78+ @set! L . f = update_coefficients (L. f, u, p, t)
79+ @set! L . u = u
6880end
6981
7082function update_coefficients! (L:: RevModeAutoDiffVecProd , u, p, t)
@@ -97,31 +109,36 @@ function Base.resize!(L::RevModeAutoDiffVecProd, n::Integer)
97109 end
98110end
99111
100- function VecJac (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = AutoFiniteDiff (),
101- kwargs... )
102- vecprod, vecprod! = if autodiff isa AutoFiniteDiff
103- num_vecjac, num_vecjac!
112+ """
113+ VecJac(f, u, [p, t]; autodiff = AutoFiniteDiff())
114+
115+ Returns FunctionOperator that computes
116+ """
117+ function VecJac (f, u:: AbstractArray , p = nothing , t = nothing ;
118+ autodiff = AutoFiniteDiff (), kwargs... )
119+
120+ vecprod, vecprod!, cache = if autodiff isa AutoFiniteDiff
121+ num_vecjac, num_vecjac!, (similar (u), similar (u))
104122 elseif autodiff isa AutoZygote
105123 @assert static_hasmethod (auto_vecjac, typeof ((f, u, u))) " To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
106124
107- auto_vecjac, auto_vecjac!
125+ auto_vecjac, auto_vecjac!, ()
108126 end
109127
110- cache = ( similar (u), similar (u) )
128+ L = RevModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!, autodiff )
111129
112- outofplace = static_hasmethod (f, typeof ((u,)))
113- isinplace = static_hasmethod (f, typeof ((u, u)))
130+ iip, oop = get_iip_oop (L)
114131
115- if ! ( isinplace) & ! ( outofplace)
116- error ( " $f must have signature f(u), or f(du, u) " )
117- end
132+ FunctionOperator (L, u, u; isinplace = iip, outofplace = oop,
133+ p = p, t = t, islinear = true , kwargs ... )
134+ end
118135
119- L = RevModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!; autodiff = autodiff,
120- isinplace = isinplace, outofplace = outofplace)
121136
122- FunctionOperator (L, u, u;
123- isinplace = isinplace, outofplace = outofplace,
124- p = p, t = t, islinear = true ,
125- kwargs... )
137+ function FixedVecJac (f, u:: AbstractArray , p = nothing , t = nothing ;
138+ autodiff = AutoFiniteDiff (), kwargs... )
139+ _fixedvecjac (f, u, p, t, autodiff, kwargs)
140+ end
141+
142+ function _fixedvecjac (f, u, p, t, ad:: AutoFiniteDiff , kwargs)
126143end
127144#
0 commit comments