Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 93a8287

Browse files
Allow users to set the tag for the configs
1 parent cb76902 commit 93a8287

1 file changed

Lines changed: 16 additions & 8 deletions

File tree

src/differentiation/compute_jacobian_ad.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ end
1111
getsize(::Val{N}) where N = N
1212
getsize(N::Integer) = N
1313
void_setindex!(args...) = (setindex!(args...); return)
14+
gettag(::Type{ForwardDiff.Dual{T}}) where {T} = T
1415

1516
const default_chunk_size = ForwardDiff.pickchunksize
1617

1718
function ForwardColorJacCache(f::F,x,_chunksize = nothing;
1819
dx = nothing,
20+
tag = nothing,
1921
colorvec=1:length(x),
2022
sparsity::Union{AbstractArray,Nothing}=nothing) where {F}
2123

@@ -25,15 +27,21 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing;
2527
chunksize = _chunksize
2628
end
2729

30+
if tag === nothing
31+
T = typeof(ForwardDiff.Tag(f,eltype(vec(x))))
32+
else
33+
T = tag
34+
end
35+
2836
if x isa Array
2937
p = generate_chunked_partials(x,colorvec,chunksize)
30-
t = similar(x,Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x)))),eltype(x),length(first(first(p)))})
38+
t = similar(x,Dual{T})
3139
for i in eachindex(t)
32-
t[i] = Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x)))),eltype(x),length(first(first(p)))}(x[i],ForwardDiff.Partials(first(p)[i]))
40+
t[i] = Dual{T,eltype(x),length(first(first(p)))}(x[i],ForwardDiff.Partials(first(p)[i]))
3341
end
3442
else
3543
p = adapt.(parameterless_type(x),generate_chunked_partials(x,colorvec,chunksize))
36-
_t = Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))}.(vec(x),first(p))
44+
_t = Dual{T}.(vec(x),first(p))
3745
t = ArrayInterface.restructure(x,_t)
3846
end
3947

@@ -44,7 +52,7 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing;
4452
else
4553
tup = ArrayInterface.allowed_getindex(ArrayInterface.allowed_getindex(p,1),1) .* false
4654
_pi = adapt(parameterless_type(dx),[tup for i in 1:length(dx)])
47-
fx = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))}.(vec(dx),_pi),size(dx)...)
55+
fx = reshape(Dual{T}.(vec(dx),_pi),size(dx)...)
4856
_dx = dx
4957
end
5058

@@ -162,7 +170,7 @@ function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number},f::F,x::Abstract
162170

163171
for i in eachindex(p)
164172
partial_i = p[i]
165-
t = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i),size(t))
173+
t = reshape(Dual{gettag(eltype(t))}.(vecx, partial_i),size(t))
166174
fx = f(t)
167175
if !(sparsity isa Nothing)
168176
for j in 1:chunksize
@@ -230,7 +238,7 @@ function forwarddiff_color_jacobian_immutable(f,x::AbstractArray{<:Number},jac_c
230238

231239
for i in eachindex(p)
232240
partial_i = p[i]
233-
t = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i),size(t))
241+
t = reshape(Dual{gettag(eltype(t))}.(vecx, partial_i),size(t))
234242
fx = f(t)
235243
if !(sparsity isa Nothing)
236244
for j in 1:chunksize
@@ -311,10 +319,10 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
311319

312320
if vect isa Array
313321
@inbounds @simd ivdep for j in eachindex(vect)
314-
vect[j] = Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}(vecx[j], partial_i[j])
322+
vect[j] = Dual{gettag(eltype(t))}(vecx[j], partial_i[j])
315323
end
316324
else
317-
vect .= Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i)
325+
vect .= Dual{gettag(eltype(t))}.(vecx, partial_i)
318326
end
319327

320328
f(fx,t)

0 commit comments

Comments
 (0)