Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 5a02552

Browse files
committed
fix VecJac constructor, clean auto_vecjac<bang>, add setfield
1 parent 65dbe8d commit 5a02552

4 files changed

Lines changed: 53 additions & 33 deletions

File tree

Project.toml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,19 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1616
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1717
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1818
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
19+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1920
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2021
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
2122
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2223
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
2324
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
2425

26+
[weakdeps]
27+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
28+
29+
[extensions]
30+
SparseDiffToolsZygote = "Zygote"
31+
2532
[compat]
2633
ADTypes = "0.1"
2734
Adapt = "3.0"
@@ -41,9 +48,6 @@ VertexSafeGraphs = "0.2"
4148
Zygote = "0.6"
4249
julia = "1.6"
4350

44-
[extensions]
45-
SparseDiffToolsZygote = "Zygote"
46-
4751
[extras]
4852
ArrayInterfaceBandedMatrices = "2e50d22c-5be1-4042-81b1-c572ed69783d"
4953
ArrayInterfaceBlockBandedMatrices = "5331f1e9-51c7-46b0-a9b0-df4434785e0a"
@@ -60,6 +64,3 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6064

6165
[targets]
6266
test = ["Test", "ArrayInterfaceBandedMatrices", "ArrayInterfaceBlockBandedMatrices", "BandedMatrices", "BlockBandedMatrices", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"]
63-
64-
[weakdeps]
65-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

src/SparseDiffTools.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Graphs
77
using Graphs: SimpleGraph
88
using VertexSafeGraphs
99
using Adapt
10+
1011
using Reexport
1112
@reexport using ADTypes
1213

@@ -23,6 +24,7 @@ using ArrayInterface: matrix_colors
2324
using SciMLOperators
2425
import SciMLOperators: update_coefficients, update_coefficients!
2526
using Tricks: Tricks, static_hasmethod
27+
using Setfield: @set!
2628

2729
abstract type AbstractAutoDiffVecProd end
2830

src/differentiation/jaches_products.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ struct FwdModeAutoDiffVecProd{F, U, C, V, V!} <: AbstractAutoDiffVecProd
210210
end
211211

212212
function update_coefficients(L::FwdModeAutoDiffVecProd, u, p, t)
213-
f = update_coefficients(L.f, u, p, t)
214-
FwdModeAutoDiffVecProd(f, u, L.cache, L.vecprod, L.vecprod!)
213+
@set! L.f = update_coefficients(L.f, u, p, t)
214+
@set! L.u = u
215215
end
216216

217217
function update_coefficients!(L::FwdModeAutoDiffVecProd, u, p, t)
@@ -248,7 +248,7 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing;
248248
elseif autodiff isa AutoForwardDiff
249249
cache1 = Dual{
250250
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
251-
}.(u, ForwardDiff.Partials.(tuple.(u)))
251+
}.(u, ForwardDiff.Partials.(tuple.(u)))
252252

253253
cache2 = copy(cache1)
254254

src/differentiation/vecjac_products.jl

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,17 @@ struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffV
4343
cache::C
4444
vecprod::V
4545
vecprod!::V!
46+
autodiff::ad
4647

47-
function RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!;
48-
autodiff = AutoFiniteDiff(),
49-
isinplace = false, outofplace = true)
50-
@assert isinplace || outofplace
48+
function RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!, autodiff)
49+
50+
outofplace = static_hasmethod(f, typeof((u,)))
51+
isinplace = static_hasmethod(f, typeof((u, u)))
52+
53+
if !(isinplace) & !(outofplace)
54+
msg = "$f must have signature f(u), or f(du, u)"
55+
throw(ArgumentError(msg))
56+
end
5157

5258
new{
5359
typeof(autodiff),
@@ -58,13 +64,19 @@ struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffV
5864
typeof(cache),
5965
typeof(vecprod),
6066
typeof(vecprod!)
61-
}(f, u, cache, vecprod, vecprod!)
67+
}(
68+
f, u, cache, vecprod, vecprod!, autodiff,
69+
)
6270
end
6371
end
6472

73+
function get_iip_oop(::RevModeAutoDiffVecProd{ad, iip, oop}) where{ad, iip, oop}
74+
iip, oop
75+
end
76+
6577
function update_coefficients(L::RevModeAutoDiffVecProd, u, p, t)
66-
f = update_coefficients(L.f, u, p, t)
67-
RevModeAutoDiffVecProd(f, u, L.vecprod, L.vecprod!, L.cache)
78+
@set! L.f = update_coefficients(L.f, u, p, t)
79+
@set! L.u = u
6880
end
6981

7082
function update_coefficients!(L::RevModeAutoDiffVecProd, u, p, t)
@@ -97,31 +109,36 @@ function Base.resize!(L::RevModeAutoDiffVecProd, n::Integer)
97109
end
98110
end
99111

100-
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff(),
101-
kwargs...)
102-
vecprod, vecprod! = if autodiff isa AutoFiniteDiff
103-
num_vecjac, num_vecjac!
112+
"""
113+
VecJac(f, u, [p, t]; autodiff = AutoFiniteDiff())
114+
115+
Returns FunctionOperator that computes
116+
"""
117+
function VecJac(f, u::AbstractArray, p = nothing, t = nothing;
118+
autodiff = AutoFiniteDiff(), kwargs...)
119+
120+
vecprod, vecprod!, cache = if autodiff isa AutoFiniteDiff
121+
num_vecjac, num_vecjac!, (similar(u), similar(u))
104122
elseif autodiff isa AutoZygote
105123
@assert static_hasmethod(auto_vecjac, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
106124

107-
auto_vecjac, auto_vecjac!
125+
auto_vecjac, auto_vecjac!, ()
108126
end
109127

110-
cache = (similar(u), similar(u))
128+
L = RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!, autodiff)
111129

112-
outofplace = static_hasmethod(f, typeof((u,)))
113-
isinplace = static_hasmethod(f, typeof((u, u)))
130+
iip, oop = get_iip_oop(L)
114131

115-
if !(isinplace) & !(outofplace)
116-
error("$f must have signature f(u), or f(du, u)")
117-
end
132+
FunctionOperator(L, u, u; isinplace = iip, outofplace = oop,
133+
p = p, t = t, islinear = true, kwargs...)
134+
end
118135

119-
L = RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = autodiff,
120-
isinplace = isinplace, outofplace = outofplace)
121136

122-
FunctionOperator(L, u, u;
123-
isinplace = isinplace, outofplace = outofplace,
124-
p = p, t = t, islinear = true,
125-
kwargs...)
137+
function FixedVecJac(f, u::AbstractArray, p = nothing, t = nothing;
138+
autodiff = AutoFiniteDiff(), kwargs...)
139+
_fixedvecjac(f, u, p, t, autodiff, kwargs)
140+
end
141+
142+
function _fixedvecjac(f, u, p, t, ad::AutoFiniteDiff, kwargs)
126143
end
127144
#

0 commit comments

Comments
 (0)