Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit bfd3a67

Browse files
committed
Add new Higher Level API
1 parent 8e1b626 commit bfd3a67

14 files changed

Lines changed: 638 additions & 15 deletions

Project.toml

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
3-
authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <contact@chrisrackauckas.com>"]
3+
authors = [
4+
"Pankaj Mishra <pankajmishra1511@gmail.com>",
5+
"Chris Rackauckas <contact@chrisrackauckas.com>",
6+
]
47
version = "2.5.0"
58

69
[deps]
@@ -21,12 +24,15 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2124
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
2225
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2326
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
27+
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2428
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
2529

2630
[weakdeps]
31+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
2732
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2833

2934
[extensions]
35+
SparseDiffToolsSymbolicsExt = "Symbolics"
3036
SparseDiffToolsZygoteExt = "Zygote"
3137

3238
[compat]
@@ -43,7 +49,9 @@ SciMLOperators = "0.2.11, 0.3"
4349
Setfield = "1"
4450
StaticArrayInterface = "1.3"
4551
StaticArrays = "1"
52+
Symbolics = "5.5"
4653
Tricks = "0.1.6"
54+
UnPack = "1"
4755
VertexSafeGraphs = "0.2"
4856
Zygote = "0.6"
4957
julia = "1.6"
@@ -52,6 +60,7 @@ julia = "1.6"
5260
ArrayInterfaceBandedMatrices = "2e50d22c-5be1-4042-81b1-c572ed69783d"
5361
ArrayInterfaceBlockBandedMatrices = "5331f1e9-51c7-46b0-a9b0-df4434785e0a"
5462
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
63+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
5564
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
5665
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
5766
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -63,4 +72,18 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6372
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6473

6574
[targets]
66-
test = ["Test", "ArrayInterfaceBandedMatrices", "ArrayInterfaceBlockBandedMatrices", "BandedMatrices", "BlockBandedMatrices", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"]
75+
test = [
76+
"Test",
77+
"ArrayInterfaceBandedMatrices",
78+
"ArrayInterfaceBlockBandedMatrices",
79+
"BandedMatrices",
80+
"BlockBandedMatrices",
81+
"IterativeSolvers",
82+
"Pkg",
83+
"Random",
84+
"SafeTestsets",
85+
"Symbolics",
86+
"Zygote",
87+
"StaticArrays",
88+
"BenchmarkTools",
89+
]

README.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,67 @@ function g(x) # out-of-place
4747
end
4848
```
4949

50+
## High Level API
51+
52+
We need to perform the following steps to utilize SparseDiffTools:
53+
54+
1. Specify a Sparsity Detection Algorithm. There are 3 possible choices currently:
55+
1. `NoSparsityDetection`: This will ignore any AD choice and compute the dense Jacobian
56+
2. `JacPrototypeSparsityDetection`: If you already know the sparsity pattern, you can
57+
specify it as `JacPrototypeSparsityDetection(; jac_prototype=<sparsity pattern>)`.
58+
3. `SymbolicsSparsityDetection`: This will use `Symbolics.jl` to automatically detect
59+
the sparsity pattern. (Note that `Symbolics.jl` must be explicitly loaded before
60+
using this functionality.)
61+
2. Now choose an AD backend from `ADTypes.jl`:
62+
1. If using a Non `*Sparse*` type, then we will not use sparsity detection.
63+
2. All other sparse AD types will internally compute the proper sparsity pattern, and
64+
try to exploit that.
65+
3. Now there are 2 options:
66+
1. Precompute the cache using `sparse_jacobian_cache` and use the `sparse_jacobian` or
67+
`sparse_jacobian!` functions to compute the Jacobian. This option is recommended if
68+
you are repeatedly computing the Jacobian for the same function.
69+
2. Directly use `sparse_jacobian` or `sparse_jacobian!` to compute the Jacobian. This
70+
option should be used if you are only computing the Jacobian once.
71+
72+
```julia
73+
using Symbolics
74+
75+
sd = SymbolicsSparsityDetection()
76+
adtype = AutoSparseFiniteDiff()
77+
x = rand(30)
78+
y = similar(x)
79+
80+
# Option 1
81+
## OOP Function
82+
cache = sparse_jacobian_cache(adtype, sd, g, x; fx=y) # Passing `fx` is needed if size(y) != size(x)
83+
J = sparse_jacobian(adtype, cache, g, x)
84+
### Or
85+
J_preallocated = similar(J)
86+
sparse_jacobian!(J_preallocated, adtype, cache, g, x)
87+
88+
## IIP Function
89+
cache = sparse_jacobian_cache(adtype, sd, f, y, x)
90+
J = sparse_jacobian(adtype, cache, f, y, x)
91+
### Or
92+
J_preallocated = similar(J)
93+
sparse_jacobian!(J_preallocated, adtype, cache, f, y, x)
94+
95+
# Option 2
96+
## OOP Function
97+
J = sparse_jacobian(adtype, sd, g, x)
98+
### Or
99+
J_preallocated = similar(J)
100+
sparse_jacobian!(J_preallocated, adtype, sd, g, x)
101+
102+
## IIP Function
103+
J = sparse_jacobian(adtype, sd, f, y, x)
104+
### Or
105+
J_preallocated = similar(J)
106+
sparse_jacobian!(J_preallocated, adtype, sd, f, y, x)
107+
```
108+
109+
## Lower Level API
110+
50111
For this function, we know that the sparsity pattern of the Jacobian is a
51112
`Tridiagonal` matrix. However, if we didn't know the sparsity pattern for
52113
the Jacobian, we could use the `Symbolics.jacobian_sparsity` function to automatically

ext/SparseDiffToolsSymbolicsExt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module SparseDiffToolsSymbolicsExt
2+
3+
using SparseDiffTools, Symbolics
4+
import SparseDiffTools: AbstractSparseADType
5+
6+
function (alg::SymbolicsSparsityDetection)(ad::AbstractSparseADType, f, x; fx=nothing,
7+
kwargs...)
8+
fx = fx === nothing ? similar(f(x)) : dx
9+
f!(y, x) = (y .= f(x))
10+
J = Symbolics.jacobian_sparsity(f!, fx, x)
11+
_alg = JacPrototypeSparsityDetection(J, alg.alg)
12+
return _alg(ad, f, x; fx, kwargs...)
13+
end
14+
15+
function (alg::SymbolicsSparsityDetection)(ad::AbstractSparseADType, f!, fx, x; kwargs...)
16+
J = Symbolics.jacobian_sparsity(f!, fx, x)
17+
_alg = JacPrototypeSparsityDetection(J, alg.alg)
18+
return _alg(ad, f!, fx, x; kwargs...)
19+
end
20+
21+
end

ext/SparseDiffToolsZygoteExt.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,36 @@ import Tricks: static_hasmethod
99

1010
import SparseDiffTools: numback_hesvec!,
1111
numback_hesvec, autoback_hesvec!, autoback_hesvec, auto_vecjac!, auto_vecjac
12+
import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient!
13+
import ADTypes: AutoZygote, AutoSparseZygote
14+
15+
## Satisfying High-Level Interface for Sparse Jacobians
16+
function __gradient(::Union{AutoSparseZygote, AutoZygote}, f, x, cols)
17+
_, ∂x, _ = Zygote.gradient(__f̂, f, x, cols)
18+
return vec(∂x)
19+
end
20+
21+
function __gradient!(::Union{AutoSparseZygote, AutoZygote}, f!, fx, x, cols)
22+
return error("Zygote.jl cannot differentiate in-place (mutating) functions.")
23+
end
24+
25+
# Zygote doesn't provide a way to accumulate directly into `J`. So we modify the code from
26+
# https://github.com/FluxML/Zygote.jl/blob/82c7a000bae7fb0999275e62cc53ddb61aed94c7/src/lib/grad.jl#L140-L157C4
27+
import Zygote: _jvec, _eyelike, _gradcopy!
28+
29+
@views function __jacobian!(J::AbstractMatrix, ::Union{AutoSparseZygote, AutoZygote}, f, x)
30+
y, back = Zygote.pullback(_jvec f, x)
31+
δ = _eyelike(y)
32+
for k in LinearIndices(y)
33+
grad = only(back(δ[:, k]))
34+
_gradcopy!(J[k, :], grad)
35+
end
36+
return J
37+
end
38+
39+
function __jacobian!(J, ::Union{AutoSparseZygote, AutoZygote}, f!, fx, x)
40+
return error("Zygote.jl cannot differentiate in-place (mutating) functions.")
41+
end
1242

1343
### Jac, Hes products
1444

@@ -117,7 +147,7 @@ function (L::AutoDiffVJP{<:AutoZygote, IIP, true})(dv, v, p, t;
117147
copy!(dv, _dv)
118148
end
119149

120-
function (L::AutoDiffVJP{<:AutoZygote, true, false})(_, _, _, _; VJP_input = nothing)
150+
function (L::AutoDiffVJP{<:AutoZygote, true, false})(args...; kwargs...)
121151
error("Zygote requires an out of place method with signature f(u).")
122152
end
123153

src/SparseDiffTools.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ module SparseDiffTools
22

33
# QoL/Helper Packages
44
using Adapt, Compat, Reexport
5+
import UnPack: @unpack
56
# Graph Coloring
67
using Graphs, VertexSafeGraphs
78
import Graphs: SimpleGraph
89
# Differentiation
910
using FiniteDiff, ForwardDiff
1011
@reexport using ADTypes
12+
import ADTypes: AbstractADType, AutoSparseZygote, AbstractSparseForwardMode,
13+
AbstractSparseReverseMode, AbstractSparseFiniteDifferences, AbstractReverseMode
1114
import ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD
1215
# Array Packages
1316
using ArrayInterface, SparseArrays
@@ -41,6 +44,13 @@ include("differentiation/compute_hessian_ad.jl")
4144
include("differentiation/jaches_products.jl")
4245
include("differentiation/vecjac_products.jl")
4346

47+
# High Level Interface
48+
include("highlevel/common.jl")
49+
include("highlevel/coloring.jl")
50+
include("highlevel/forward_mode.jl")
51+
include("highlevel/reverse_mode.jl")
52+
include("highlevel/finite_diff.jl")
53+
4454
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
4555
parameterless_type(x) = parameterless_type(typeof(x))
4656
parameterless_type(x::Type) = __parameterless_type(x)
@@ -67,11 +77,18 @@ export auto_jacvec, auto_jacvec!, num_jacvec, num_jacvec!
6777
export num_vecjac, num_vecjac!, auto_vecjac, auto_vecjac!
6878
# HesVec Products
6979
export numauto_hesvec,
70-
numauto_hesvec!, autonum_hesvec, autonum_hesvec!, numback_hesvec, numback_hesvec!
80+
numauto_hesvec!, autonum_hesvec, autonum_hesvec!, numback_hesvec, numback_hesvec!,
81+
num_hesvec, num_hesvec!, autoback_hesvec, autoback_hesvec!
7182
# HesVecGrad Products
7283
export num_hesvecgrad, num_hesvecgrad!, auto_hesvecgrad, auto_hesvecgrad!
7384
# Operators
7485
export JacVec, HesVec, HesVecGrad, VecJac
7586
export update_coefficients, update_coefficients!, value!
7687

88+
# High Level Interface: sparse_jacobian
89+
export AutoSparseZygote # FIXME: Remove once https://github.com/SciML/ADTypes.jl/pull/16 is merged
90+
export NoSparsityDetection,
91+
SymbolicsSparsityDetection, JacPrototypeSparsityDetection, AutoSparsityDetection
92+
export sparse_jacobian, sparse_jacobian_cache, sparse_jacobian!
93+
7794
end # module

src/coloring/high_level.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,31 @@ struct GreedyStar2Color <: SparseDiffToolsColoringAlgorithm end
77
struct AcyclicColoring <: SparseDiffToolsColoringAlgorithm end
88

99
"""
10-
matrix_colors(A, alg::ColoringAlgorithm = GreedyD1Color())
10+
matrix_colors(A, alg::ColoringAlgorithm = GreedyD1Color();
11+
partition_by_rows::Bool = false)
1112
1213
Return the colorvec vector for the matrix A using the chosen coloring
1314
algorithm. If a known analytical solution exists, that is used instead.
1415
The coloring defaults to a greedy distance-1 coloring.
1516
1617
Note that if A isa SparseMatrixCSC, the sparsity pattern is defined by structural nonzeroes,
1718
ie includes explicitly stored zeros.
19+
20+
If `ArrayInterface.fast_matrix_colors(A)` is true, then uses
21+
`ArrayInterface.matrix_colors(A)` to compute the matrix colors.
1822
"""
1923
function ArrayInterface.matrix_colors(A::AbstractMatrix,
2024
alg::SparseDiffToolsColoringAlgorithm = GreedyD1Color();
2125
partition_by_rows::Bool = false)
26+
27+
# If fast algorithm for matrix coloring exists use that
28+
if !partition_by_rows
29+
ArrayInterface.fast_matrix_colors(A) && return ArrayInterface.matrix_colors(A)
30+
else
31+
A_ = A'
32+
ArrayInterface.fast_matrix_colors(A_) && return ArrayInterface.matrix_colors(A_)
33+
end
34+
2235
_A = A isa SparseMatrixCSC ? A : sparse(A) # Avoid the copy
2336
A_graph = matrix2graph(_A, partition_by_rows)
2437
return color_graph(A_graph, alg)

src/differentiation/compute_jacobian_ad.jl

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -312,16 +312,10 @@ function forwarddiff_color_jacobian_immutable(f, x::AbstractArray{<:Number},
312312
return J
313313
end
314314

315-
function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
316-
f,
317-
x::AbstractArray{<:Number};
318-
dx = similar(x, size(J, 1)),
319-
colorvec = 1:length(x),
320-
sparsity = ArrayInterface.has_sparsestruct(J) ? J :
321-
nothing)
322-
forwarddiff_color_jacobian!(J, f, x,
323-
ForwardColorJacCache(f, x, dx = dx, colorvec = colorvec,
324-
sparsity = sparsity))
315+
function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, f,
316+
x::AbstractArray{<:Number}; dx = similar(x, size(J, 1)), colorvec = 1:length(x),
317+
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing)
318+
forwarddiff_color_jacobian!(J, f, x, ForwardColorJacCache(f, x; dx, colorvec, sparsity))
325319
end
326320

327321
function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},

src/highlevel/coloring.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
struct MatrixColoringResult{C, J, NR, NC}
2+
colorvec::C
3+
jacobian_sparsity::J
4+
nz_rows::NR
5+
nz_cols::NC
6+
end
7+
8+
struct NoMatrixColoring end
9+
10+
# Using Non-Sparse AD / NoSparsityDetection results in NoMatrixColoring
11+
(::NoSparsityDetection)(::AbstractADType, args...; kwargs...) = NoMatrixColoring()
12+
13+
## If no specialization is available, we don't perform sparsity detection
14+
(::AbstractMaybeSparsityDetection)(::AbstractADType, args...; kws...) = NoMatrixColoring()
15+
16+
# Prespecified Jacobian Structure
17+
function (alg::JacPrototypeSparsityDetection)(ad::AbstractSparseADType, args...; kwargs...)
18+
J = alg.jac_prototype
19+
reverse_mode = ad isa AbstractSparseReverseMode
20+
colorvec = matrix_colors(J, alg.alg; partition_by_rows = reverse_mode)
21+
(nz_rows, nz_cols) = collect.(ArrayInterface.findstructralnz(J))
22+
return MatrixColoringResult(colorvec, J, nz_rows, nz_cols)
23+
end
24+
25+
# TODO: Heuristics to decide whether to use Sparse Differentiation or not
26+
# Simple Idea: Check min(max(colorvec_cols), max(colorvec_rows))

0 commit comments

Comments
 (0)