Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 21a9a1e

Browse files
committed
overload functions from main
1 parent 5cc3111 commit 21a9a1e

1 file changed

Lines changed: 14 additions & 12 deletions

File tree

ext/SparseDiffToolsZygote.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ if isdefined(Base, :get_extension)
44
import Zygote
55
using LinearAlgebra
66
using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
7-
using ForwardDiff: ForwardDiff, Dual
7+
using ForwardDiff: ForwardDiff, Dual, partials
88
using SciMLOperators: FunctionOperator
99
using Tricks: static_hasmethod
1010
else
@@ -18,7 +18,7 @@ end
1818

1919
### Jac, Hes products
2020

21-
function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
21+
function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
2222
g = let f = f
2323
(dx, x) -> dx .= first(Zygote.gradient(f, x))
2424
end
@@ -32,7 +32,7 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
3232
@. dy = (cache1 - cache2) / (2ϵ)
3333
end
3434

35-
function numback_hesvec(f, x, v)
35+
function SparseDiffTools.numback_hesvec(f, x, v)
3636
g = x -> first(Zygote.gradient(f, x))
3737
T = eltype(x)
3838
# Should it be min? max? mean?
@@ -44,7 +44,7 @@ function numback_hesvec(f, x, v)
4444
(gxp - gxm) / (2ϵ)
4545
end
4646

47-
function autoback_hesvec!(dy, f, x, v,
47+
function SparseDiffTools.autoback_hesvec!(dy, f, x, v,
4848
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
4949
eltype(x), 1
5050
}.(x,
@@ -62,7 +62,7 @@ function autoback_hesvec!(dy, f, x, v,
6262
dy .= partials.(cache2, 1)
6363
end
6464

65-
function autoback_hesvec(f, x, v)
65+
function SparseDiffTools.autoback_hesvec(f, x, v)
6666
g = x -> first(Zygote.gradient(f, x))
6767
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
6868
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
@@ -71,7 +71,7 @@ end
7171

7272
# Operator Forms
7373

74-
function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
74+
function SparseDiffTools.ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
7575

7676
if autodiff
7777
cache1 = Dual{
@@ -85,8 +85,8 @@ function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff =
8585

8686
cache = (cache1, cache2,)
8787

88-
vecprod = autodiff ? autoback_hesvec : numback_hesvec
89-
vecprod! = autodiff ? autoback_hesvec! : numback_hesvec!
88+
vecprod = autodiff ? SparseDiffTools.autoback_hesvec : SparseDiffTools.numback_hesvec
89+
vecprod! = autodiff ? SparseDiffTools.autoback_hesvec! : SparseDiffTools.numback_hesvec!
9090

9191
outofplace = static_hasmethod(f, typeof((u,)))
9292
isinplace = static_hasmethod(f, typeof((u,)))
@@ -105,16 +105,18 @@ end
105105

106106
## VecJac products
107107

108-
function auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing)
108+
function SparseDiffTools.auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing)
109109
!hasmethod(f, (typeof(x),)) && error("For inplace function use autodiff = false")
110-
du .= reshape(auto_vecjac(f, x, v), size(du))
110+
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
111111
end
112112

113-
function auto_vecjac(f, x, v)
113+
function SparseDiffTools.auto_vecjac(f, x, v)
114114
vv, back = Zygote.pullback(f, x)
115115
return vec(back(reshape(v, size(vv)))[1])
116116
end
117117

118-
ZygoteVecJac(args...; autodiff = true, kwargs...) = VecJac(args...; autodiff = autodiff, kwargs...)
118+
function SparseDiffTools.ZygoteVecJac(args...; autodiff = true, kwargs...)
119+
VecJac(args...; autodiff = autodiff, kwargs...)
120+
end
119121

120122
end # module

0 commit comments

Comments
 (0)