@@ -2,36 +2,36 @@ module SparseDiffToolsEnzymeExt
22
33import ArrayInterface: fast_scalar_indexing
44import SparseDiffTools: __f̂, __maybe_copy_x, __jacobian!, __gradient, __gradient!,
5- AutoSparseEnzyme , __test_backend_loaded
5+ AutoSparse{ <: AutoEnzyme } , __test_backend_loaded
66# FIXME : For Enzyme we currently assume reverse mode
7- import ADTypes: AutoEnzyme
7+ import ADTypes: AutoSparse, AutoEnzyme
88using Enzyme
99
1010using ForwardDiff
1111
12- @inline __test_backend_loaded (:: Union{AutoSparseEnzyme , AutoEnzyme} ) = nothing
12+ @inline __test_backend_loaded (:: Union{AutoSparse{<:AutoEnzyme} , AutoEnzyme} ) = nothing
1313
1414# # Satisfying High-Level Interface for Sparse Jacobians
15- function __gradient (:: Union{AutoSparseEnzyme , AutoEnzyme} , f, x, cols)
15+ function __gradient (:: Union{AutoSparse{<:AutoEnzyme} , AutoEnzyme} , f, x, cols)
1616 dx = zero (x)
1717 autodiff (Reverse, __f̂, Const (f), Duplicated (x, dx), Const (cols))
1818 return vec (dx)
1919end
2020
21- function __gradient! (:: Union{AutoSparseEnzyme , AutoEnzyme} , f!, fx, x, cols)
21+ function __gradient! (:: Union{AutoSparse{<:AutoEnzyme} , AutoEnzyme} , f!, fx, x, cols)
2222 dx = zero (x)
2323 dfx = zero (fx)
2424 autodiff (Reverse, __f̂, Active, Const (f!), Duplicated (fx, dfx), Duplicated (x, dx),
2525 Const (cols))
2626 return dx
2727end
2828
29- function __jacobian! (J:: AbstractMatrix , :: Union{AutoSparseEnzyme , AutoEnzyme} , f, x)
29+ function __jacobian! (J:: AbstractMatrix , :: Union{AutoSparse{<:AutoEnzyme} , AutoEnzyme} , f, x)
3030 J .= jacobian (Reverse, f, x, Val (size (J, 1 )))
3131 return J
3232end
3333
34- @views function __jacobian! (J, ad:: Union{AutoSparseEnzyme , AutoEnzyme} , f!, fx, x)
34+ @views function __jacobian! (J, ad:: Union{AutoSparse{<:AutoEnzyme} , AutoEnzyme} , f!, fx, x)
3535 # This version is slowish not sure how to do jacobians for inplace functions
3636 @warn " Current code for computing jacobian for inplace functions in Enzyme is slow." maxlog= 1
3737 dfx = zero (fx)
5858 return J
5959end
6060
61- __maybe_copy_x (:: Union{AutoSparseEnzyme , AutoEnzyme} , x:: SubArray ) = copy (x)
61+ __maybe_copy_x (:: Union{AutoSparse{<:AutoEnzyme} , AutoEnzyme} , x:: SubArray ) = copy (x)
6262
6363end
0 commit comments