1+ module SparseDiffToolsEnzymeExt
2+
3+ import ArrayInterface: fast_scalar_indexing
4+ import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient!, AutoSparseEnzyme
5+ # FIXME : For Enzyme we currently assume reverse mode
6+ import ADTypes: AutoEnzyme
7+ using Enzyme
8+
9+ using ForwardDiff
10+
11+ # # Satisfying High-Level Interface for Sparse Jacobians
12+ function __gradient (:: Union{AutoSparseEnzyme, AutoEnzyme} , f, x, cols)
13+ dx = zero (x)
14+ autodiff (Reverse, __f̂, Const (f), Duplicated (x, dx), Const (cols))
15+ return vec (dx)
16+ end
17+
18+ function __gradient! (:: Union{AutoSparseEnzyme, AutoEnzyme} , f!, fx, x, cols)
19+ dx = zero (x)
20+ dfx = zero (fx)
21+ autodiff (Reverse, __f̂, Active, Const (f!), Duplicated (fx, dfx), Duplicated (x, dx),
22+ Const (cols))
23+ return dx
24+ end
25+
26+ function __jacobian! (J:: AbstractMatrix , :: Union{AutoSparseEnzyme, AutoEnzyme} , f, x)
27+ J .= jacobian (Reverse, f, x, Val (size (J, 1 )))
28+ return J
29+ end
30+
31+ @views function __jacobian! (J, ad:: Union{AutoSparseEnzyme, AutoEnzyme} , f!, fx, x)
32+ # This version is slowish not sure how to do jacobians for inplace functions
33+ @warn " Current code for computing jacobian for inplace functions in Enzyme is slow." maxlog= 1
34+ dfx = zero (fx)
35+ dx = zero (x)
36+
37+ function __f_row_idx (f!, fx, x, row_idx)
38+ f! (fx, x)
39+ if fast_scalar_indexing (fx)
40+ return fx[row_idx]
41+ else
42+ # Avoid the GPU Arrays scalar indexing error
43+ return sum (selectdim (fx, 1 , row_idx: row_idx))
44+ end
45+ end
46+
47+ for row_idx in 1 : size (J, 1 )
48+ autodiff (Reverse, __f_row_idx, Const (f!), DuplicatedNoNeed (fx, dfx),
49+ Duplicated (x, dx), Const (row_idx))
50+ J[row_idx, :] .= dx
51+ fill! (dfx, 0 )
52+ fill! (dx, 0 )
53+ end
54+
55+ return J
56+ end
57+
58+ end
0 commit comments