This repository was archived by the owner on Aug 22, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 43
Expand file tree
/
Copy pathreverse_mode.jl
More file actions
73 lines (67 loc) · 2.82 KB
/
reverse_mode.jl
File metadata and controls
73 lines (67 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Cache for (possibly sparse) Jacobian evaluation with a reverse-mode AD backend.
#
# Fields, as populated by the reverse-mode `sparse_jacobian_cache` methods:
#   - `coloring`:      result of the sparsity-detection call `sd(ad, ...)`; compared
#                      against `NoMatrixColoring` in `sparse_jacobian!` to choose
#                      between the dense and the colored path.
#   - `cache`:         always `nothing` in the constructors visible in this file.
#   - `jac_prototype`: the `:jacobian_sparsity` entry extracted from `coloring`.
#   - `fx`:            output buffer (caller-supplied, or `similar(f(x))`).
#   - `x`:             the input the cache was built for.
#   - `idx_vec`:       `collect(1:length(fx))`; reused as scratch space that the
#                      colored kernel overwrites with a 0/1 mask for each color.
struct ReverseModeJacobianCache{C, CC, JP, Y, X, V} <: AbstractMaybeSparseJacobianCache
    coloring::C
    cache::CC
    jac_prototype::JP
    fx::Y
    x::X
    idx_vec::V
end
# Backend-availability check for `AutoSparseReverseDiff`: this method is a no-op.
# NOTE(review): presumably other methods of `__test_backend_loaded` error when
# the backend package is not loaded — confirm against their definitions.
function __test_backend_loaded(::ADTypes.AutoSparseReverseDiff)
    return nothing
end
# Accessor hook: return the stored Jacobian prototype (sparsity pattern) of a
# reverse-mode cache.
function __getfield(cache::ReverseModeJacobianCache, ::Val{:jac_prototype})
    return cache.jac_prototype
end
# Build a `ReverseModeJacobianCache` for an out-of-place function `f(x)`.
#
# If the keyword `fx` is `nothing`, `f` is evaluated once at `x` and an output
# buffer is allocated with `similar`. The sparsity detector `sd` produces the
# coloring, whose `:jacobian_sparsity` entry becomes the Jacobian prototype.
function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
        sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
    if fx === nothing
        out = similar(f(x))
    else
        out = fx
    end
    colored = sd(ad, f, x)
    prototype = __getfield(colored, Val(:jacobian_sparsity))
    # One slot per output entry; later reused as a per-color 0/1 mask.
    scratch = collect(1:length(out))
    return ReverseModeJacobianCache(colored, nothing, prototype, out, x, scratch)
end
# Build a `ReverseModeJacobianCache` for an in-place function `f!(fx, x)`.
#
# The caller-supplied output buffer `fx` is stored directly. The sparsity detector
# `sd` produces the coloring, whose `:jacobian_sparsity` entry becomes the Jacobian
# prototype.
function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
        sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
    colored = sd(ad, f!, fx, x)
    prototype = __getfield(colored, Val(:jacobian_sparsity))
    # One slot per output entry; later reused as a per-color 0/1 mask.
    scratch = collect(1:length(fx))
    return ReverseModeJacobianCache(colored, nothing, prototype, fx, x, scratch)
end
# Fill `J` with the Jacobian using the reverse-mode backend `ad` and `cache`.
#
# `args...` is forwarded unchanged: `(f, x)` for out-of-place or `(f!, fx, x)` for
# in-place functions. With a real coloring, the colored reverse-mode kernel runs
# using the cache's scratch `idx_vec`. With `NoMatrixColoring`, the backend check
# runs and the dense `__jacobian!` path is used. Returns what the chosen path returns.
function sparse_jacobian!(J::AbstractMatrix, ad, cache::ReverseModeJacobianCache, args...)
    coloring = cache.coloring
    if !(coloring isa NoMatrixColoring)
        return __sparse_jacobian_reverse_impl!(J, ad, cache.idx_vec, coloring, args...)
    end
    # Dense fallback.
    __test_backend_loaded(ad)
    return __jacobian!(J, ad, args...)
end
# Out-of-place entry point `(f, x)`: forward to the general method with a
# `nothing` output buffer, which makes it treat `f` as not in-place.
function __sparse_jacobian_reverse_impl!(J::AbstractMatrix, ad, idx_vec,
        cache::MatrixColoringResult, f::F, x) where {F}
    no_buffer = nothing
    return __sparse_jacobian_reverse_impl!(J, ad, idx_vec, cache, f, no_buffer, x)
end
# Colored reverse-mode Jacobian kernel: fill the structural nonzeros of `J`.
#
# Arguments:
#   - `J`:       output matrix; only entries at `(nz_rows[k], nz_cols[k])` are
#                written, so all other entries are left untouched.
#   - `ad`:      AD backend.
#   - `idx_vec`: scratch vector (one slot per output), overwritten with a 0/1 mask
#                of the rows having the current color.
#   - `cache`:   coloring result providing `colorvec` (a color per row) and the
#                structural nonzero coordinates `nz_rows` / `nz_cols`.
#   - `f`, `fx`, `x`: if `fx === nothing`, `f` is called out-of-place; otherwise
#                `f` is treated as in-place, writing into (a copy of) `fx`.
# Returns `J`.
#
# For each color `c`, one reverse pass is seeded with the mask of rows colored
# `c`. NOTE(review): `__gradient`/`__gradient!` are assumed to return the gradient
# of `dot(f(x), idx_vec)`, i.e. the sum of the Jacobian rows of color `c`; confirm
# against their definitions. In a valid row coloring, no two rows of the same
# color share a nonzero column, so for every nonzero `(row, col)` with
# `colorvec[row] == c` the value is exactly `gs[col]`.
function __sparse_jacobian_reverse_impl!(J::AbstractMatrix, ad, idx_vec,
        cache::MatrixColoringResult, f::F, fx, x) where {F}
    __test_backend_loaded(ad)
    x_ = __maybe_copy_x(ad, x)
    fx_ = __maybe_copy_x(ad, fx)
    @unpack colorvec, nz_rows, nz_cols = cache
    # `maximum` throws on an empty collection; with no rows there is nothing to fill.
    isempty(colorvec) && return J
    for c in 1:maximum(colorvec)
        @. idx_vec = colorvec == c
        gs = fx === nothing ? __gradient(ad, f, x_, idx_vec) :
             __gradient!(ad, f, fx_, x_, idx_vec)
        # Scatter directly into the structural nonzeros of this color: O(nnz) per
        # color instead of scanning every (i, j) of `J`.
        # FIXME: Assumes fast scalar indexing currently. Very easy to write a kernel
        # to do this in parallel using KA.jl.
        for k in eachindex(nz_rows, nz_cols)
            row = nz_rows[k]
            colorvec[row] == c || continue
            col = nz_cols[k]
            J[row, col] = gs[col]
        end
    end
    return J
end