Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 9a2f088

Browse files
committed
Write separate forwarddiff_color_jacobian to autodiff oop f to an allocated matrix (rather than allocate new matrix for Jacobian)
1 parent 9169f5b commit 9a2f088

2 files changed

Lines changed: 57 additions & 9 deletions

File tree

src/differentiation/compute_jacobian_ad.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,53 @@ function forwarddiff_color_jacobian(J::AbstractArray{<:Number},f,x::AbstractArra
140140
cols_index_c = vcat(cols_index_c,zeros(Int,nrows-len_rows))[perm_rows]
141141
Ji = [j==cols_index_c[i] ? dx[i] : false for i in 1:nrows, j in 1:ncols]
142142
end
143+
J = J + Ji
144+
color_i += 1
145+
(color_i > maxcolor) && return J
146+
end
147+
else
148+
for j in 1:chunksize
149+
col_index = (i-1)*chunksize + j
150+
(col_index > ncols) && return J
151+
Ji = mapreduce(i -> i==col_index ? partials.(vec(fx), j) : adapt(parameterless_type(J),zeros(eltype(J),nrows)), hcat, 1:ncols)
152+
J = J + (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
153+
end
154+
end
155+
end
156+
J
157+
end
158+
159+
function forwarddiff_color_jacobian(J::SparseMatrixCSC{<:Number},f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache,jac_prototype=nothing)
160+
t = jac_cache.t
161+
dx = jac_cache.dx
162+
p = jac_cache.p
163+
colorvec = jac_cache.colorvec
164+
sparsity = jac_cache.sparsity
165+
chunksize = jac_cache.chunksize
166+
color_i = 1
167+
maxcolor = maximum(colorvec)
168+
169+
vecx = vec(x)
170+
171+
nrows,ncols = size(J)
172+
173+
if !(sparsity isa Nothing)
174+
rows_index, cols_index = ArrayInterface.findstructralnz(sparsity)
175+
rows_index = [rows_index[i] for i in 1:length(rows_index)]
176+
cols_index = [cols_index[i] for i in 1:length(cols_index)]
177+
end
178+
179+
for i in eachindex(p)
180+
partial_i = p[i]
181+
t = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i),size(t))
182+
fx = f(t)
183+
if !(sparsity isa Nothing)
184+
for j in 1:chunksize
185+
dx = vec(partials.(fx, j))
186+
pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i]
187+
rows_index_c = rows_index[pick_inds]
188+
cols_index_c = cols_index[pick_inds]
189+
Ji = sparse(rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols)
143190
# J = J + Ji
144191
J .+= Ji
145192
color_i += 1

test/test_ad.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,16 @@ _J1 = forwarddiff_color_jacobian(oopf, x, colorvec = repeat(1:3,10), sparsity =
115115
@test _J1 J
116116
@test fcalls == 1
117117

118+
119+
#oop with in-place Jacobian
120+
x = rand(30)
121+
fcalls = 0
122+
_oop_jacout = spzeros(size(J)...)
123+
forwarddiff_color_jacobian(_oop_jacout, oopf, x; colorvec = repeat(1:3,10), sparsity = _J, jac_prototype = _J)
124+
@test _oop_jacout J
125+
@test typeof(_oop_jacout) == typeof(_J)
126+
@test fcalls == 1
127+
118128
@info "4th passed"
119129

120130
fcalls = 0
@@ -238,12 +248,3 @@ f(x) = x
238248
J = forwarddiff_color_jacobian(f,x)
239249
@test J isa SArray
240250
@test J SMatrix{1,1}([1.])
241-
242-
@info "6"
243-
#oop with in-place Jacobian
244-
fcalls = 0
245-
_oop_jacout = spzeros(size(J)...)
246-
forwarddiff_color_jacobian(_oop_jacout, oopf, x; colorvec = repeat(1:3,10), sparsity = _J, jac_prototype = _J)
247-
@test _oop_jacout J
248-
@test typeof(_oop_jacout) == typeof(_J)
249-
@test fcalls == 1

0 commit comments

Comments
 (0)