1111getsize (:: Val{N} ) where N = N
1212getsize (N:: Integer ) = N
1313void_setindex! (args... ) = (setindex! (args... ); return )
14+ gettag (:: Type{ForwardDiff.Dual{T}} ) where {T} = T
1415
1516const default_chunk_size = ForwardDiff. pickchunksize
17+ const SMALLTAG = ForwardDiff. Tag (missing ,Float64)
1618
1719function ForwardColorJacCache (f:: F ,x,_chunksize = nothing ;
1820 dx = nothing ,
21+ tag = nothing ,
1922 colorvec= 1 : length (x),
2023 sparsity:: Union{AbstractArray,Nothing} = nothing ) where {F}
2124
@@ -25,15 +28,21 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing;
2528 chunksize = _chunksize
2629 end
2730
31+ if tag === nothing
32+ T = typeof (ForwardDiff. Tag (f,eltype (vec (x))))
33+ else
34+ T = tag
35+ end
36+
2837 if x isa Array
2938 p = generate_chunked_partials (x,colorvec,chunksize)
30- t = similar (x,Dual{typeof (ForwardDiff . Tag (f, eltype ( vec (x)))), eltype (x), length ( first ( first (p))) })
39+ t = similar (x,Dual{T })
3140 for i in eachindex (t)
32- t[i] = Dual {typeof(ForwardDiff.Tag(f,eltype(vec(x)))) ,eltype(x),length(first(first(p)))} (x[i],ForwardDiff. Partials (first (p)[i]))
41+ t[i] = Dual {T ,eltype(x),length(first(first(p)))} (x[i],ForwardDiff. Partials (first (p)[i]))
3342 end
3443 else
3544 p = adapt .(parameterless_type (x),generate_chunked_partials (x,colorvec,chunksize))
36- _t = Dual {typeof(ForwardDiff.Tag(f ,eltype(vec(x)))) } .(vec (x),first (p))
45+ _t = Dual {T ,eltype(x),getsize(chunksize) } .(vec (x),ForwardDiff . Partials .( first (p) ))
3746 t = ArrayInterface. restructure (x,_t)
3847 end
3948
@@ -44,7 +53,7 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing;
4453 else
4554 tup = ArrayInterface. allowed_getindex (ArrayInterface. allowed_getindex (p,1 ),1 ) .* false
4655 _pi = adapt (parameterless_type (dx),[tup for i in 1 : length (dx)])
47- fx = reshape (Dual {typeof(ForwardDiff.Tag(f ,eltype(vec(x)))) } .(vec (dx),_pi),size (dx)... )
56+ fx = reshape (Dual {T ,eltype(dx),length(tup) } .(vec (dx),ForwardDiff . Partials .( _pi) ),size (dx)... )
4857 _dx = dx
4958 end
5059
@@ -162,7 +171,7 @@ function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number},f::F,x::Abstract
162171
163172 for i in eachindex (p)
164173 partial_i = p[i]
165- t = reshape (Dual {typeof(ForwardDiff.Tag(f, eltype(vecx)))} .(vecx, partial_i),size (t))
174+ t = reshape (eltype (t) .(vecx, ForwardDiff . Partials .( partial_i) ),size (t))
166175 fx = f (t)
167176 if ! (sparsity isa Nothing)
168177 for j in 1 : chunksize
@@ -230,7 +239,7 @@ function forwarddiff_color_jacobian_immutable(f,x::AbstractArray{<:Number},jac_c
230239
231240 for i in eachindex (p)
232241 partial_i = p[i]
233- t = reshape (Dual {typeof(ForwardDiff.Tag(f, eltype(vecx)))} .(vecx, partial_i),size (t))
242+ t = reshape (eltype (t) .(vecx, ForwardDiff . Partials .( partial_i) ),size (t))
234243 fx = f (t)
235244 if ! (sparsity isa Nothing)
236245 for j in 1 : chunksize
@@ -311,10 +320,10 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
311320
312321 if vect isa Array
313322 @inbounds @simd ivdep for j in eachindex (vect)
314- vect[j] = Dual {typeof(ForwardDiff.Tag(f, eltype(vecx)))} (vecx[j], partial_i[j])
323+ vect[j] = eltype (t) (vecx[j], ForwardDiff . Partials ( partial_i[j]) )
315324 end
316325 else
317- vect .= Dual {typeof(ForwardDiff.Tag(f, eltype(vecx)))} .(vecx, partial_i)
326+ vect .= eltype (t) .(vecx, ForwardDiff . Partials .( partial_i) )
318327 end
319328
320329 f (fx,t)
0 commit comments