6868 jac_prototype = nothing ,
6969 chunksize = nothing ,
7070 dx = sparsity === nothing && jac_prototype === nothing ? nothing : copy (x)) # if dx is nothing, we will estimate dx at the cost of a function call
71+ @show typeof (x)
72+
7173 if sparsity === nothing && jac_prototype === nothing || ! ArrayInterface. ismutable (x)
7274 cfg = chunksize === nothing ? ForwardDiff. JacobianConfig (f, x) : ForwardDiff. JacobianConfig (f, x, ForwardDiff. Chunk (getsize (chunksize)))
7375 return ForwardDiff. jacobian (f, x, cfg)
7476 end
7577 if dx isa Nothing
7678 dx = f (x)
7779 end
80+ @show " Line 80"
7881 forwarddiff_color_jacobian (f,x,ForwardColorJacCache (f,x,chunksize,dx= dx,colorvec= colorvec,sparsity= sparsity),jac_prototype)
7982end
8083
8487 sparsity = nothing ,
8588 jac_prototype = nothing ,
8689 chunksize = nothing ,
87- dx = similar (x, size (J, 1 ))) # if dx is nothing, we will estimate dx at the cost of a function call
90+ dx = similar (x, size (J, 1 ))) # dx kwarg can be used to avoid re-allocating dx every time
8891 if sparsity === nothing && jac_prototype === nothing || ! ArrayInterface. ismutable (x)
8992 cfg = chunksize === nothing ? ForwardDiff. JacobianConfig (f, x) : ForwardDiff. JacobianConfig (f, x, ForwardDiff. Chunk (getsize (chunksize)))
9093 return ForwardDiff. jacobian (f, x, cfg)
9194 end
95+ @show " Line 95"
9296 forwarddiff_color_jacobian (J,f,x,ForwardColorJacCache (f,x,chunksize,dx= dx,colorvec= colorvec,sparsity= sparsity),jac_prototype)
9397end
9498
@@ -99,10 +103,18 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
99103
100104 J = jac_prototype isa Nothing ? (sparsity isa Nothing ? false .* vec (dx) .* vecx' : zeros (eltype (x),size (sparsity))) : zero (jac_prototype)
101105
102- forwarddiff_color_jacobian (J, f, x, jac_cache, jac_prototype)
106+ @show typeof (J)
107+ if ArrayInterface. ismutable (J) # Whenever J is mutable, we mutate it to avoid allocations
108+ @show " Line 108"
109+ forwarddiff_color_jacobian (J, f, x, jac_cache, jac_prototype)
110+ else
111+ @show " Line 111"
112+ forwarddiff_color_jacobian_immutable (J, f, x, jac_cache, jac_prototype)
113+ end
103114end
104115
105- function forwarddiff_color_jacobian (J:: AbstractArray{<:Number} ,f,x:: AbstractArray{<:Number} ,jac_cache:: ForwardColorJacCache ,jac_prototype= nothing )
116+ # When J is mutable, this version of forwarddiff_color_jacobian will mutate J to avoid allocations
117+ function forwarddiff_color_jacobian (J:: AbstractMatrix{<:Number} ,f,x:: AbstractArray{<:Number} ,jac_cache:: ForwardColorJacCache )
106118 t = jac_cache. t
107119 dx = jac_cache. dx
108120 p = jac_cache. p
@@ -141,7 +153,11 @@ function forwarddiff_color_jacobian(J::AbstractArray{<:Number},f,x::AbstractArra
141153 cols_index_c = vcat (cols_index_c,zeros (Int,nrows- len_rows))[perm_rows]
142154 Ji = [j== cols_index_c[i] ? dx[i] : false for i in 1 : nrows, j in 1 : ncols]
143155 end
144- J = J + Ji
156+ if j == 1 && i == 1
157+ J .= Ji # overwrite pre-allocated matrix
158+ else
159+ J .+ = Ji
160+ end
145161 color_i += 1
146162 (color_i > maxcolor) && return J
147163 end
@@ -150,14 +166,19 @@ function forwarddiff_color_jacobian(J::AbstractArray{<:Number},f,x::AbstractArra
150166 col_index = (i- 1 )* chunksize + j
151167 (col_index > ncols) && return J
152168 Ji = mapreduce (i -> i== col_index ? partials .(vec (fx), j) : adapt (parameterless_type (J),zeros (eltype (J),nrows)), hcat, 1 : ncols)
153- J = J + (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
169+ if j == 1 && i == 1
170+ J .= (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # overwrite pre-allocated matrix
171+ else
172+ J .+ = (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
173+ end
154174 end
155175 end
156176 end
157177 J
158178end
159179
160- function forwarddiff_color_jacobian (J:: SparseMatrixCSC{<:Number} ,f,x:: AbstractArray{<:Number} ,jac_cache:: ForwardColorJacCache ,jac_prototype= nothing )
180+ # When J is immutable, this version of forwarddiff_color_jacobian will avoid mutating J
181+ function forwarddiff_color_jacobian_immutable (J:: AbstractArray{<:Number} ,f,x:: AbstractArray{<:Number} ,jac_cache:: ForwardColorJacCache )
161182 t = jac_cache. t
162183 dx = jac_cache. dx
163184 p = jac_cache. p
@@ -187,13 +208,16 @@ function forwarddiff_color_jacobian(J::SparseMatrixCSC{<:Number},f,x::AbstractAr
187208 pick_inds = [i for i in 1 : length (rows_index) if colorvec[cols_index[i]] == color_i]
188209 rows_index_c = rows_index[pick_inds]
189210 cols_index_c = cols_index[pick_inds]
190- Ji = sparse (rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols)
191- # J = J + Ji
192- if j == 1 && i == 1
193- J .= Ji # overwrite pre-allocated matrix
211+ if J isa SparseMatrixCSC
212+ Ji = sparse (rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols)
194213 else
195- J .+ = Ji
214+ len_rows = length (pick_inds)
215+ unused_rows = setdiff (1 : nrows,rows_index_c)
216+ perm_rows = sortperm (vcat (rows_index_c,unused_rows))
217+ cols_index_c = vcat (cols_index_c,zeros (Int,nrows- len_rows))[perm_rows]
218+ Ji = [j== cols_index_c[i] ? dx[i] : false for i in 1 : nrows, j in 1 : ncols]
196219 end
220+ J = J + Ji
197221 color_i += 1
198222 (color_i > maxcolor) && return J
199223 end
@@ -202,12 +226,7 @@ function forwarddiff_color_jacobian(J::SparseMatrixCSC{<:Number},f,x::AbstractAr
202226 col_index = (i- 1 )* chunksize + j
203227 (col_index > ncols) && return J
204228 Ji = mapreduce (i -> i== col_index ? partials .(vec (fx), j) : adapt (parameterless_type (J),zeros (eltype (J),nrows)), hcat, 1 : ncols)
205- # J = J + (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
206- if j == 1 && i == 1
207- J .= (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # overwrite pre-allocated matrix
208- else
209- J .+ = (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
210- end
229+ J = J + (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
211230 end
212231 end
213232 end
0 commit comments