1111getsize (:: Val{N} ) where N = N
1212getsize (N:: Integer ) = N
1313void_setindex! (args... ) = (setindex! (args... ); return )
14+ gettag (:: Type{ForwardDiff.Dual{T}} ) where {T} = T
1415
1516const default_chunk_size = ForwardDiff. pickchunksize
1617
1718function ForwardColorJacCache (f:: F ,x,_chunksize = nothing ;
1819 dx = nothing ,
20+ tag = nothing ,
1921 colorvec= 1 : length (x),
2022 sparsity:: Union{AbstractArray,Nothing} = nothing ) where {F}
2123
@@ -25,15 +27,21 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing;
2527 chunksize = _chunksize
2628 end
2729
30+ if tag === nothing
31+ T = typeof (ForwardDiff. Tag (f,eltype (vec (x))))
32+ else
33+ T = tag
34+ end
35+
2836 if x isa Array
2937 p = generate_chunked_partials (x,colorvec,chunksize)
30- t = similar (x,Dual{typeof (ForwardDiff . Tag (f, eltype ( vec (x)))), eltype (x), length ( first ( first (p))) })
38+ t = similar (x,Dual{T })
3139 for i in eachindex (t)
32- t[i] = Dual {typeof(ForwardDiff.Tag(f,eltype(vec(x)))) ,eltype(x),length(first(first(p)))} (x[i],ForwardDiff. Partials (first (p)[i]))
40+ t[i] = Dual {T ,eltype(x),length(first(first(p)))} (x[i],ForwardDiff. Partials (first (p)[i]))
3341 end
3442 else
3543 p = adapt .(parameterless_type (x),generate_chunked_partials (x,colorvec,chunksize))
36- _t = Dual {typeof(ForwardDiff.Tag(f,eltype(vec(x)))) } .(vec (x),first (p))
44+ _t = Dual {T } .(vec (x),first (p))
3745 t = ArrayInterface. restructure (x,_t)
3846 end
3947
@@ -44,7 +52,7 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing;
4452 else
4553 tup = ArrayInterface. allowed_getindex (ArrayInterface. allowed_getindex (p,1 ),1 ) .* false
4654 _pi = adapt (parameterless_type (dx),[tup for i in 1 : length (dx)])
47- fx = reshape (Dual {typeof(ForwardDiff.Tag(f,eltype(vec(x)))) } .(vec (dx),_pi),size (dx)... )
55+ fx = reshape (Dual {T } .(vec (dx),_pi),size (dx)... )
4856 _dx = dx
4957 end
5058
@@ -162,7 +170,7 @@ function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number},f::F,x::Abstract
162170
163171 for i in eachindex (p)
164172 partial_i = p[i]
165- t = reshape (Dual {typeof(ForwardDiff.Tag(f, eltype(vecx) ))} .(vecx, partial_i),size (t))
173+ t = reshape (Dual {gettag( eltype(t ))} .(vecx, partial_i),size (t))
166174 fx = f (t)
167175 if ! (sparsity isa Nothing)
168176 for j in 1 : chunksize
@@ -230,7 +238,7 @@ function forwarddiff_color_jacobian_immutable(f,x::AbstractArray{<:Number},jac_c
230238
231239 for i in eachindex (p)
232240 partial_i = p[i]
233- t = reshape (Dual {typeof(ForwardDiff.Tag(f, eltype(vecx) ))} .(vecx, partial_i),size (t))
241+ t = reshape (Dual {gettag( eltype(t ))} .(vecx, partial_i),size (t))
234242 fx = f (t)
235243 if ! (sparsity isa Nothing)
236244 for j in 1 : chunksize
@@ -311,10 +319,10 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
311319
312320 if vect isa Array
313321 @inbounds @simd ivdep for j in eachindex (vect)
314- vect[j] = Dual {typeof(ForwardDiff.Tag(f, eltype(vecx) ))} (vecx[j], partial_i[j])
322+ vect[j] = Dual {gettag( eltype(t ))} (vecx[j], partial_i[j])
315323 end
316324 else
317- vect .= Dual {typeof(ForwardDiff.Tag(f, eltype(vecx) ))} .(vecx, partial_i)
325+ vect .= Dual {gettag( eltype(t ))} .(vecx, partial_i)
318326 end
319327
320328 f (fx,t)
0 commit comments