Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit f876266

Browse files
committed
zygote ext file
1 parent d989fa8 commit f876266

3 files changed

Lines changed: 47 additions & 26 deletions

File tree

src/differentiation/jaches_products_zygote.jl renamed to ext/SparseDiffToolsZygote.jl

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,24 @@
1+
module SparseDiffToolsZygote
2+
3+
import Zygote
4+
5+
using SparseDiffTools
6+
using SparseDiffTools: DeviVecTag, FwdModeAutoDiffVecProd
7+
8+
using ForwardDiff
9+
using ForwardDiff: Dual, Tag
10+
11+
using SciMLOperators: FunctionOperator
12+
using Tricks: static_hasmethod
13+
14+
export
15+
numback_hesvec, numback_hesvec!,
16+
autoback_hesvec, autoback_hesvec!,
17+
auto_vecjac, auto_vecjac!,
18+
ZygoteVecJac, ZygoteHesVec
19+
20+
### Jac, Hes products
21+
122
function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
223
g = let f = f
324
(dx, x) -> dx .= first(Zygote.gradient(f, x))
@@ -49,7 +70,7 @@ function autoback_hesvec(f, x, v)
4970
ForwardDiff.partials.(g(y), 1)
5071
end
5172

52-
### Operator Forms
73+
# Operator Forms
5374

5475
function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
5576

@@ -82,4 +103,19 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff =
82103
p = p, t = t, islinear = true,
83104
)
84105
end
85-
#
106+
107+
## VecJac products
108+
109+
function auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing)
110+
!hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = false")
111+
du .= reshape(auto_vecjac(f, x, v), size(du))
112+
end
113+
114+
function auto_vecjac(f, x, v)
115+
vv, back = Zygote.pullback(f, x)
116+
return vec(back(reshape(v, size(vv)))[1])
117+
end
118+
119+
ZygoteVecJac(args...; autodiff = true, kwargs...) = VecJac(args...; autodiff = autodiff, kwargs...)
120+
121+
end # module

src/SparseDiffTools.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ using FiniteDiff
55
using ForwardDiff
66
using Graphs
77
using Graphs: SimpleGraph
8-
using Requires
98
using VertexSafeGraphs
109
using Adapt
1110

@@ -69,16 +68,16 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
6968
parameterless_type(x) = parameterless_type(typeof(x))
7069
parameterless_type(x::Type) = __parameterless_type(x)
7170

72-
function __init__()
73-
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
74-
export numback_hesvec, numback_hesvec!,
75-
autoback_hesvec, autoback_hesvec!,
76-
auto_vecjac, auto_vecjac!,
77-
ZygoteVecJac, ZygoteHesVec
71+
#if !isdefined(Base, :get_extension)
72+
using Requires
73+
#end
7874

79-
include("differentiation/vecjac_products_zygote.jl")
80-
include("differentiation/jaches_products_zygote.jl")
81-
end
75+
function __init__()
76+
#@static if !isdefined(Base, :get_extension)
77+
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
78+
include("../ext/SparseDiffToolsZygote.jl")
79+
end
80+
#end
8281
end
8382

8483
end # module

src/differentiation/vecjac_products_zygote.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

0 commit comments

Comments
 (0)