Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 079a63b

Browse files
Merge pull request #218 from vpuri3/zygote-ext
Zygote ext
2 parents b75289e + 21a9a1e commit 079a63b

4 files changed

Lines changed: 82 additions & 35 deletions

File tree

Project.toml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <contact@chrisrackauckas.com>"]
4-
version = "2.00.0"
4+
version = "2.0.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -12,16 +12,24 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1212
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1313
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1516
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1617
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
1718
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1819
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
20+
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
1921
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
2022
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
2123

24+
[weakdeps]
25+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
26+
27+
[extensions]
28+
SparseDiffToolsZygote = "Zygote"
29+
2230
[compat]
2331
Adapt = "3.0"
24-
ArrayInterface = "6, 7"
32+
ArrayInterface = "7"
2533
Compat = "4"
2634
DataStructures = "0.18"
2735
FiniteDiff = "2.8.1"
@@ -30,8 +38,10 @@ Graphs = "1"
3038
Requires = "1"
3139
SciMLOperators = "0.1.19, 0.2"
3240
StaticArrays = "1"
41+
StaticArrayInterface = "1.3"
3342
Tricks = "0.1.6"
3443
VertexSafeGraphs = "0.2"
44+
Zygote = "0.6"
3545
julia = "1.6"
3646

3747
[extras]

src/differentiation/jaches_products_zygote.jl renamed to ext/SparseDiffToolsZygote.jl

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,24 @@
1-
function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
1+
module SparseDiffToolsZygote
2+
3+
if isdefined(Base, :get_extension)
4+
import Zygote
5+
using LinearAlgebra
6+
using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
7+
using ForwardDiff: ForwardDiff, Dual, partials
8+
using SciMLOperators: FunctionOperator
9+
using Tricks: static_hasmethod
10+
else
11+
import ..Zygote
12+
using ..LinearAlgebra
13+
using ..SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
14+
using ..ForwardDiff: ForwardDiff, Dual, partials
15+
using ..SciMLOperators: FunctionOperator
16+
using ..Tricks: static_hasmethod
17+
end
18+
19+
### Jac, Hes products
20+
21+
function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
222
g = let f = f
323
(dx, x) -> dx .= first(Zygote.gradient(f, x))
424
end
@@ -12,7 +32,7 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
1232
@. dy = (cache1 - cache2) / (2ϵ)
1333
end
1434

15-
function numback_hesvec(f, x, v)
35+
function SparseDiffTools.numback_hesvec(f, x, v)
1636
g = x -> first(Zygote.gradient(f, x))
1737
T = eltype(x)
1838
# Should it be min? max? mean?
@@ -24,7 +44,7 @@ function numback_hesvec(f, x, v)
2444
(gxp - gxm) / (2ϵ)
2545
end
2646

27-
function autoback_hesvec!(dy, f, x, v,
47+
function SparseDiffTools.autoback_hesvec!(dy, f, x, v,
2848
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
2949
eltype(x), 1
3050
}.(x,
@@ -42,16 +62,16 @@ function autoback_hesvec!(dy, f, x, v,
4262
dy .= partials.(cache2, 1)
4363
end
4464

45-
function autoback_hesvec(f, x, v)
65+
function SparseDiffTools.autoback_hesvec(f, x, v)
4666
g = x -> first(Zygote.gradient(f, x))
4767
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
4868
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
4969
ForwardDiff.partials.(g(y), 1)
5070
end
5171

52-
### Operator Forms
72+
# Operator Forms
5373

54-
function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
74+
function SparseDiffTools.ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
5575

5676
if autodiff
5777
cache1 = Dual{
@@ -65,8 +85,8 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff =
6585

6686
cache = (cache1, cache2,)
6787

68-
vecprod = autodiff ? autoback_hesvec : numback_hesvec
69-
vecprod! = autodiff ? autoback_hesvec! : numback_hesvec!
88+
vecprod = autodiff ? SparseDiffTools.autoback_hesvec : SparseDiffTools.numback_hesvec
89+
vecprod! = autodiff ? SparseDiffTools.autoback_hesvec! : SparseDiffTools.numback_hesvec!
7090

7191
outofplace = static_hasmethod(f, typeof((u,)))
7292
isinplace = static_hasmethod(f, typeof((u,)))
@@ -82,4 +102,21 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff =
82102
p = p, t = t, islinear = true,
83103
)
84104
end
85-
#
105+
106+
## VecJac products
107+
108+
function SparseDiffTools.auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing)
109+
!hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = false")
110+
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
111+
end
112+
113+
function SparseDiffTools.auto_vecjac(f, x, v)
114+
vv, back = Zygote.pullback(f, x)
115+
return vec(back(reshape(v, size(vv)))[1])
116+
end
117+
118+
function SparseDiffTools.ZygoteVecJac(args...; autodiff = true, kwargs...)
119+
VecJac(args...; autodiff = autodiff, kwargs...)
120+
end
121+
122+
end # module

src/SparseDiffTools.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ using FiniteDiff
55
using ForwardDiff
66
using Graphs
77
using Graphs: SimpleGraph
8-
using Requires
98
using VertexSafeGraphs
109
using Adapt
1110

@@ -21,7 +20,7 @@ using ArrayInterface: matrix_colors
2120

2221
using SciMLOperators
2322
import SciMLOperators: update_coefficients, update_coefficients!
24-
using Tricks: static_hasmethod
23+
using Tricks: Tricks, static_hasmethod
2524

2625
abstract type AbstractAutoDiffVecProd end
2726

@@ -69,16 +68,31 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
6968
parameterless_type(x) = parameterless_type(typeof(x))
7069
parameterless_type(x::Type) = __parameterless_type(x)
7170

72-
function __init__()
73-
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
74-
export numback_hesvec, numback_hesvec!,
75-
autoback_hesvec, autoback_hesvec!,
76-
auto_vecjac, auto_vecjac!,
77-
ZygoteVecJac, ZygoteHesVec
71+
import Requires
72+
import Reexport
7873

79-
include("differentiation/vecjac_products_zygote.jl")
80-
include("differentiation/jaches_products_zygote.jl")
74+
function numback_hesvec end
75+
function numback_hesvec! end
76+
function autoback_hesvec end
77+
function autoback_hesvec! end
78+
function auto_vecjac end
79+
function auto_vecjac! end
80+
function ZygoteVecJac end
81+
function ZygoteHesVec end
82+
83+
@static if !isdefined(Base, :get_extension)
84+
function __init__()
85+
Requires.@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
86+
include("../ext/SparseDiffToolsZygote.jl")
87+
Reexport.@reexport using .SparseDiffToolsZygote
88+
end
8189
end
8290
end
8391

92+
export
93+
numback_hesvec, numback_hesvec!,
94+
autoback_hesvec, autoback_hesvec!,
95+
auto_vecjac, auto_vecjac!,
96+
ZygoteVecJac, ZygoteHesVec
97+
8498
end # module

src/differentiation/vecjac_products_zygote.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

0 commit comments

Comments
 (0)