Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 7cb7918

Browse files
Merge pull request #241 from vpuri3/tag
add `tag` kwarg to JacVec, HesVec, HesVecGrad
2 parents c831947 + e77e83f commit 7cb7918

3 files changed

Lines changed: 43 additions & 11 deletions

File tree

ext/SparseDiffToolsZygote.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ end
6868

6969
function SparseDiffTools.autoback_hesvec(f, x, v)
7070
g = x -> first(Zygote.gradient(f, x))
71-
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
71+
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))), eltype(x), 1
7272
}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
7373
ForwardDiff.partials.(g(y), 1)
7474
end

src/differentiation/jaches_products.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,16 +228,16 @@ function (L::FwdModeAutoDiffVecProd)(dv, v, p, t)
228228
L.vecprod!(dv, L.f, L.u, v, L.cache...)
229229
end
230230

231-
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff(),
232-
kwargs...)
231+
function JacVec(f, u::AbstractArray, p = nothing, t = nothing;
232+
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
233233
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
234234
cache1 = similar(u)
235235
cache2 = similar(u)
236236

237237
(cache1, cache2), num_jacvec, num_jacvec!
238238
elseif autodiff isa AutoForwardDiff
239239
cache1 = Dual{
240-
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
240+
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
241241
}.(u, ForwardDiff.Partials.(tuple.(u)))
242242

243243
cache2 = copy(cache1)
@@ -262,8 +262,8 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
262262
kwargs...)
263263
end
264264

265-
function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff(),
266-
kwargs...)
265+
function HesVec(f, u::AbstractArray, p = nothing, t = nothing;
266+
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
267267
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
268268
cache1 = similar(u)
269269
cache2 = similar(u)
@@ -280,7 +280,7 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
280280
@assert static_hasmethod(autoback_hesvec, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
281281

282282
cache1 = Dual{
283-
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
283+
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
284284
}.(u, ForwardDiff.Partials.(tuple.(u)))
285285
cache2 = copy(cache1)
286286

@@ -305,16 +305,15 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
305305
end
306306

307307
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing;
308-
autodiff = AutoForwardDiff(),
309-
kwargs...)
308+
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
310309
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
311310
cache1 = similar(u)
312311
cache2 = similar(u)
313312

314313
(cache1, cache2), num_hesvecgrad, num_hesvecgrad!
315314
elseif autodiff isa AutoForwardDiff
316315
cache1 = Dual{
317-
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
316+
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
318317
}.(u, ForwardDiff.Partials.(tuple.(u)))
319318
cache2 = copy(cache1)
320319

test/test_jaches_products.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
using SparseDiffTools, ForwardDiff, FiniteDiff, Zygote, IterativeSolvers
22
using LinearAlgebra, Test
3+
using SparseDiffTools: get_tag, DeivVecTag
34

45
using Random
56
Random.seed!(123)
6-
N = 300
77

8+
struct MyTag end
9+
10+
N = 300
811
x = rand(N)
912
v = rand(N)
1013

@@ -104,6 +107,10 @@ _dy = copy(dy);
104107
update_coefficients!(f, v, 5.0, 6.0)
105108
@test L(dy, v, 5.0, 6.0) auto_jacvec(f, v, v)
106109

110+
# GMRES test
111+
out = similar(v)
112+
@test_nowarn gmres!(out, L, v)
113+
107114
L = JacVec(f, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff())
108115
update_coefficients!(f, x, 1.0, 1.0)
109116
@test L * x num_jacvec(f, x, x)
@@ -121,9 +128,16 @@ _dy = copy(dy);
121128
update_coefficients!(f, v, 5.0, 6.0)
122129
@test L(dy, v, 5.0, 6.0)num_jacvec(f, v, v) rtol=1e-6
123130

131+
# GMRES test
124132
out = similar(v)
125133
@test_nowarn gmres!(out, L, v)
126134

135+
# Tag test
136+
L = JacVec(f, copy(x), 1.0, 1.0)
137+
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{DeivVecTag, eltype(x)}
138+
L = JacVec(f, copy(x), 1.0, 1.0; tag = MyTag())
139+
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{MyTag, eltype(x)}
140+
127141
@info "HesVec"
128142

129143
L = HesVec(g, copy(x), 1.0, 1.0, autodiff = AutoFiniteDiff())
@@ -159,6 +173,7 @@ _dy = copy(dy);
159173
update_coefficients!(g, v, 5.0, 6.0)
160174
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g, v, v)
161175

176+
# GMRES test
162177
out = similar(v)
163178
gmres!(out, L, v)
164179

@@ -179,9 +194,16 @@ _dy = copy(dy);
179194
update_coefficients!(g, v, 5.0, 6.0)
180195
@test L(dy, v, 5.0, 6.0) autoback_hesvec(g, v, v)
181196

197+
# GMRES test
182198
out = similar(v)
183199
gmres!(out, L, v)
184200

201+
# Tag test
202+
L = HesVec(g, copy(x), 1.0, 1.0; autodiff = AutoZygote())
203+
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{DeivVecTag, eltype(x)}
204+
L = HesVec(g, copy(x), 1.0, 1.0; autodiff = AutoZygote(), tag = MyTag())
205+
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{MyTag, eltype(x)}
206+
185207
@info "HesVecGrad"
186208

187209
L = HesVecGrad(h, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff())
@@ -203,6 +225,10 @@ _dy = copy(dy);
203225
update_coefficients!(g, v, 5.0, 6.0)
204226
@test L(dy, v, 5.0, 6.0)num_hesvec(g, v, v) rtol=1e-2
205227

228+
# GMRES test
229+
out = similar(v)
230+
gmres!(out, L, v)
231+
206232
L = HesVecGrad(h, copy(x), 1.0, 1.0)
207233
update_coefficients!(g, x, 1.0, 1.0)
208234
update_coefficients!(h, x, 1.0, 1.0)
@@ -223,6 +249,7 @@ update_coefficients!(g, v, 5.0, 6.0)
223249
update_coefficients!(h, v, 5.0, 6.0)
224250
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g, v, v)
225251

252+
# GMRES test
226253
out = similar(v)
227254
gmres!(out, L, v)
228255

@@ -231,4 +258,10 @@ gmres!(out, L, v)
231258
@test x x0
232259
@test v v0
233260

261+
# Tag test
262+
L = HesVecGrad(g, copy(x), 1.0, 1.0)
263+
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{DeivVecTag, eltype(x)}
264+
L = HesVecGrad(g, copy(x), 1.0, 1.0; tag = MyTag())
265+
@test get_tag(L.op.cache[1]) === ForwardDiff.Tag{MyTag, eltype(x)}
266+
234267
#

0 commit comments

Comments
 (0)