@@ -267,24 +267,16 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
267267
268268 fill! (J, zero (eltype (J)))
269269
270- # fast path if J and sparsity are both SparseMatrixCSC and have the same number of columns and stored values
271- sparseCSC_common = (J isa SparseMatrixCSC &&
272- sparsity isa SparseMatrixCSC &&
273- length (J. colptr) == length (sparsity. colptr) &&
274- length (J. nzval) == length (sparsity. nzval))
275-
276- if sparseCSC_common
277- J. colptr .= sparsity. colptr
278- J. rowval .= sparsity. rowval
279- end
280-
281270 if FiniteDiff. _use_findstructralnz (sparsity)
282271 rows_index, cols_index = ArrayInterface. findstructralnz (sparsity)
283272 else
284273 rows_index = 1 : size (J,1 )
285274 cols_index = 1 : size (J,2 )
286275 end
287276
277+ # fast path if J and sparsity are both SparseMatrixCSC and have the same number of columns and stored values
278+ sparseCSC_common_sparsity = FiniteDiff. _use_sparseCSC_common_sparsity! (J, sparsity)
279+
288280 vecx = vec (x)
289281 vect = vec (t)
290282 vecfx= vec (fx)
@@ -302,8 +294,8 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
302294
303295 if ArrayInterface. fast_scalar_indexing (dx)
304296 # dx is implicitly used in vecdx
305- if sparseCSC_common
306- _colorediteration! (J,vecdx,colorvec,color_i,ncols)
297+ if sparseCSC_common_sparsity
298+ FiniteDiff . _colorediteration! (J,vecdx,colorvec,color_i,ncols)
307299 else
308300 FiniteDiff. _colorediteration! (J,sparsity,rows_index,cols_index,vecdx,colorvec,color_i,ncols)
309301 end
@@ -334,16 +326,5 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
334326 return J
335327end
336328
337- # fast version of FiniteDiff._colorediteration! for the case where J and sparsity have the same sparsity pattern
338- @inline function _colorediteration! (Jsparsity:: SparseMatrixCSC ,vfx,colorvec,color_i,ncols)
339- @inbounds for col_index in 1 : ncols
340- if colorvec[col_index] == color_i
341- @inbounds for spidx in nzrange (Jsparsity, col_index)
342- row_index = Jsparsity. rowval[spidx]
343- Jsparsity. nzval[spidx]= vfx[row_index]
344- end
345- end
346- end
347- end
348329
349330
0 commit comments