Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit f322ec7

Browse files
committed
Zygote
1 parent 4c760e8 commit f322ec7

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

src/differentiation/vecjac_products.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,13 @@ function (L::RevModeAutoDiffVecProd{ad,true,false})(dv, v, p, t) where{ad}
8686
L.vecprod!(dv, (_du, _u) -> L.f(_du, _u, p, t), L.u, v, L.cache...)
8787
end
8888

89-
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true,
89+
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = false,
9090
ishermitian = false, opnrom = true)
9191

92+
if autodiff
93+
@assert isdefined(SparseDiffTools, :auto_vecjac) "Please load Zygote with `using Zygote`, or `import Zygote` to use VecJac with `autodiff = true`."
94+
end
95+
9296
cache = (similar(u), similar(u),)
9397

9498
vecprod = autodiff ? auto_vecjac : num_vecjac

src/differentiation/vecjac_products_zygote.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,7 @@ function auto_vecjac(f, x, v)
88
return vec(back(reshape(v, size(vv)))[1])
99
end
1010

11-
const ZygoteVecJac = VecJac
11+
#ZygoteVecJac = VecJac
12+
ZygoteVecJac(args...; autodiff = true, kwargs...) = VecJac(args...; autodiff = autodiff, kwargs...)
13+
1214
#

0 commit comments

Comments
 (0)