Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit c831947

Browse files
format
1 parent f1e4f85 commit c831947

9 files changed

Lines changed: 141 additions & 108 deletions

ext/SparseDiffToolsZygote.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ end
1414

1515
### Jac, Hes products
1616

17-
function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
17+
function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v),
18+
cache2 = similar(v))
1819
g = let f = f
1920
(dx, x) -> dx .= first(Zygote.gradient(f, x))
2021
end
@@ -42,14 +43,20 @@ function SparseDiffTools.numback_hesvec(f, x, v)
4243
end
4344

4445
function SparseDiffTools.autoback_hesvec!(dy, f, x, v,
45-
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
46-
eltype(x), 1
47-
}.(x,
48-
ForwardDiff.Partials.(tuple.(reshape(v, size(x))))),
49-
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag(), eltype(x))),
50-
eltype(x), 1
51-
}.(x,
52-
ForwardDiff.Partials.(tuple.(reshape(v, size(x))))))
46+
cache1 = Dual{
47+
typeof(ForwardDiff.Tag(DeivVecTag(),
48+
eltype(x))),
49+
eltype(x), 1
50+
}.(x,
51+
ForwardDiff.Partials.(tuple.(reshape(v,
52+
size(x))))),
53+
cache2 = Dual{
54+
typeof(ForwardDiff.Tag(DeivVecTag(),
55+
eltype(x))),
56+
eltype(x), 1
57+
}.(x,
58+
ForwardDiff.Partials.(tuple.(reshape(v,
59+
size(x))))))
5360
g = let f = f
5461
(dx, x) -> dx .= first(Zygote.gradient(f, x))
5562
end

src/SparseDiffTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ function auto_vecjac! end
8181

8282
@static if !isdefined(Base, :get_extension)
8383
function __init__()
84-
Requires.@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
84+
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
8585
include("../ext/SparseDiffToolsZygote.jl")
8686
@reexport using .SparseDiffToolsZygote
8787
end

src/coloring/high_level.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ Note that if A isa SparseMatrixCSC, the sparsity pattern is defined by structura
1717
ie includes explicitly stored zeros.
1818
"""
1919
function ArrayInterface.matrix_colors(A::AbstractMatrix,
20-
alg::SparseDiffToolsColoringAlgorithm = GreedyD1Color();
21-
partition_by_rows::Bool = false)
20+
alg::SparseDiffToolsColoringAlgorithm = GreedyD1Color();
21+
partition_by_rows::Bool = false)
2222
_A = A isa SparseMatrixCSC ? A : sparse(A) # Avoid the copy
2323
A_graph = matrix2graph(_A, partition_by_rows)
2424
return color_graph(A_graph, alg)

src/differentiation/compute_jacobian_ad.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ function ForwardColorJacCache(f::F, x, _chunksize = nothing;
5252
_dx = similar(x)
5353
else
5454
tup = ArrayInterface.allowed_getindex(ArrayInterface.allowed_getindex(p, 1),
55-
1) .* false
55+
1) .* false
5656
_pi = adapt(parameterless_type(dx), [tup for i in 1:length(dx)])
5757
fx = reshape(Dual{T, eltype(dx), length(tup)}.(vec(dx), ForwardDiff.Partials.(_pi)),
5858
size(dx)...)

src/differentiation/jaches_products.jl

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ end
201201

202202
### Operator Forms
203203

204-
struct FwdModeAutoDiffVecProd{F,U,C,V,V!} <: AbstractAutoDiffVecProd
204+
struct FwdModeAutoDiffVecProd{F, U, C, V, V!} <: AbstractAutoDiffVecProd
205205
f::F
206206
u::U
207207
cache::C
@@ -230,16 +230,15 @@ end
230230

231231
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff(),
232232
kwargs...)
233-
234233
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
235234
cache1 = similar(u)
236235
cache2 = similar(u)
237236

238237
(cache1, cache2), num_jacvec, num_jacvec!
239238
elseif autodiff isa AutoForwardDiff
240239
cache1 = Dual{
241-
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
242-
}.(u, ForwardDiff.Partials.(tuple.(u)))
240+
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
241+
}.(u, ForwardDiff.Partials.(tuple.(u)))
243242

244243
cache2 = copy(cache1)
245244

@@ -249,7 +248,7 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
249248
end
250249

251250
outofplace = static_hasmethod(f, typeof((u,)))
252-
isinplace = static_hasmethod(f, typeof((u, u,)))
251+
isinplace = static_hasmethod(f, typeof((u, u)))
253252

254253
if !(isinplace) & !(outofplace)
255254
error("$f must have signature f(u), or f(du, u).")
@@ -260,13 +259,11 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
260259
FunctionOperator(L, u, u;
261260
isinplace = isinplace, outofplace = outofplace,
262261
p = p, t = t, islinear = true,
263-
kwargs...,
264-
)
262+
kwargs...)
265263
end
266264

267265
function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff(),
268266
kwargs...)
269-
270267
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
271268
cache1 = similar(u)
272269
cache2 = similar(u)
@@ -283,8 +280,8 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
283280
@assert static_hasmethod(autoback_hesvec, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
284281

285282
cache1 = Dual{
286-
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
287-
}.(u, ForwardDiff.Partials.(tuple.(u)))
283+
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
284+
}.(u, ForwardDiff.Partials.(tuple.(u)))
288285
cache2 = copy(cache1)
289286

290287
(cache1, cache2), autoback_hesvec, autoback_hesvec!
@@ -293,7 +290,7 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
293290
end
294291

295292
outofplace = static_hasmethod(f, typeof((u,)))
296-
isinplace = static_hasmethod(f, typeof((u,)))
293+
isinplace = static_hasmethod(f, typeof((u,)))
297294

298295
if !(isinplace) & !(outofplace)
299296
error("$f must have signature f(u).")
@@ -304,13 +301,12 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFo
304301
FunctionOperator(L, u, u;
305302
isinplace = isinplace, outofplace = outofplace,
306303
p = p, t = t, islinear = true,
307-
kwargs...,
308-
)
304+
kwargs...)
309305
end
310306

311-
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff(),
307+
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing;
308+
autodiff = AutoForwardDiff(),
312309
kwargs...)
313-
314310
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
315311
cache1 = similar(u)
316312
cache2 = similar(u)
@@ -319,7 +315,7 @@ function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = Au
319315
elseif autodiff isa AutoForwardDiff
320316
cache1 = Dual{
321317
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
322-
}.(u, ForwardDiff.Partials.(tuple.(u)))
318+
}.(u, ForwardDiff.Partials.(tuple.(u)))
323319
cache2 = copy(cache1)
324320

325321
(cache1, cache2), auto_hesvecgrad, auto_hesvecgrad!
@@ -328,7 +324,7 @@ function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = Au
328324
end
329325

330326
outofplace = static_hasmethod(f, typeof((u,)))
331-
isinplace = static_hasmethod(f, typeof((u, u,)))
327+
isinplace = static_hasmethod(f, typeof((u, u)))
332328

333329
if !(isinplace) & !(outofplace)
334330
error("$f must have signature f(u), or f(du, u).")
@@ -339,7 +335,6 @@ function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = Au
339335
FunctionOperator(L, u, u;
340336
isinplace = isinplace, outofplace = outofplace,
341337
p = p, t = t, islinear = true,
342-
kwargs...,
343-
)
338+
kwargs...)
344339
end
345340
#

src/differentiation/vecjac_products.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ end
3737

3838
### Operator Forms
3939

40-
struct RevModeAutoDiffVecProd{ad,iip,oop,F,U,C,V,V!} <: AbstractAutoDiffVecProd
40+
struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffVecProd
4141
f::F
4242
u::U
4343
cache::C
@@ -57,10 +57,8 @@ struct RevModeAutoDiffVecProd{ad,iip,oop,F,U,C,V,V!} <: AbstractAutoDiffVecProd
5757
typeof(u),
5858
typeof(cache),
5959
typeof(vecprod),
60-
typeof(vecprod!),
61-
}(
62-
f, u, cache, vecprod, vecprod!,
63-
)
60+
typeof(vecprod!)
61+
}(f, u, cache, vecprod, vecprod!)
6462
end
6563
end
6664

@@ -81,17 +79,16 @@ function (L::RevModeAutoDiffVecProd)(v, p, t)
8179
end
8280

8381
# prefer non in-place method
84-
function (L::RevModeAutoDiffVecProd{ad,iip,true})(dv, v, p, t) where{ad,iip}
82+
function (L::RevModeAutoDiffVecProd{ad, iip, true})(dv, v, p, t) where {ad, iip}
8583
L.vecprod!(dv, L.f, L.u, v, L.cache...)
8684
end
8785

88-
function (L::RevModeAutoDiffVecProd{ad,true,false})(dv, v, p, t) where{ad}
86+
function (L::RevModeAutoDiffVecProd{ad, true, false})(dv, v, p, t) where {ad}
8987
L.vecprod!(dv, L.f, L.u, v, L.cache...)
9088
end
9189

9290
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff(),
9391
kwargs...)
94-
9592
vecprod, vecprod! = if autodiff isa AutoFiniteDiff
9693
num_vecjac, num_vecjac!
9794
elseif autodiff isa AutoZygote
@@ -100,10 +97,10 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFi
10097
auto_vecjac, auto_vecjac!
10198
end
10299

103-
cache = (similar(u), similar(u),)
100+
cache = (similar(u), similar(u))
104101

105102
outofplace = static_hasmethod(f, typeof((u,)))
106-
isinplace = static_hasmethod(f, typeof((u, u,)))
103+
isinplace = static_hasmethod(f, typeof((u, u)))
107104

108105
if !(isinplace) & !(outofplace)
109106
error("$f must have signature f(u), or f(du, u)")
@@ -115,7 +112,6 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFi
115112
FunctionOperator(L, u, u;
116113
isinplace = isinplace, outofplace = outofplace,
117114
p = p, t = t, islinear = true,
118-
kwargs...
119-
)
115+
kwargs...)
120116
end
121117
#

0 commit comments

Comments
 (0)