Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit c0846a7

Browse files
proper duals for JacVec
1 parent 5152045 commit c0846a7

2 files changed

Lines changed: 10 additions & 5 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <contact@chrisrackauckas.com>"]
4-
version = "1.19.1"
4+
version = "1.19.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/differentiation/jaches_products.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,21 @@ function auto_jacvec!(
66
f,
77
x,
88
v,
9-
cache1 = Dual{DeivVecTag}.(x, reshape(v, size(x))),
9+
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials(reshape(v, size(x)))),
1010
cache2 = similar(cache1),
1111
)
12-
cache1 .= Dual{DeivVecTag}.(x, reshape(v, size(x)))
12+
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials(reshape(v, size(x))))
1313
f(cache2, cache1)
14-
dy .= partials.(cache2, 1)
14+
vecdy = _vec(dy)
15+
vecdy .= partials.(vec(cache2), 1)
1516
end
1617

18+
_vec(v) = vec(v)
19+
_vec(v::AbstractVector) = v
20+
1721
function auto_jacvec(f, x, v)
1822
vv = reshape(v, axes(x))
19-
vec(partials.(vec(f(ForwardDiff.Dual{DeivVecTag}.(x, vv))), 1))
23+
vec(partials.(vec(f(ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, vv))), 1))
2024
end
2125

2226
function num_jacvec!(
@@ -197,6 +201,7 @@ function JacVec(f, x::AbstractArray; autodiff = true)
197201
JacVec(f, cache1, cache2, x, autodiff)
198202
end
199203

204+
Base.eltype(L::JacVec) = eltype(L.x)
200205
Base.size(L::JacVec) = (length(L.cache1), length(L.cache1))
201206
Base.size(L::JacVec, i::Int) = length(L.cache1)
202207
Base.:*(L::JacVec, v::AbstractVector) =

0 commit comments

Comments
 (0)