Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 9169f5b

Browse files
committed
Implement ability to autodiff oop function and store in user-provided Jacobian
1 parent 9637980 commit 9169f5b

2 files changed

Lines changed: 36 additions & 3 deletions

File tree

src/differentiation/compute_jacobian_ad.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,30 @@ end
7878
forwarddiff_color_jacobian(f,x,ForwardColorJacCache(f,x,chunksize,dx=dx,colorvec=colorvec,sparsity=sparsity),jac_prototype)
7979
end
8080

81+
@inline function forwarddiff_color_jacobian(J::AbstractArray{<:Number}, f,
82+
x::AbstractArray{<:Number};
83+
colorvec = 1:length(x),
84+
sparsity = nothing,
85+
jac_prototype = nothing,
86+
chunksize = nothing) # Note no dx keyword b/c can infer Jacobian's size via J
87+
if sparsity === nothing && jac_prototype === nothing || !ArrayInterface.ismutable(x)
88+
cfg = chunksize === nothing ? ForwardDiff.JacobianConfig(f, x) : ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk(getsize(chunksize)))
89+
return ForwardDiff.jacobian(f, x, cfg)
90+
end
91+
forwarddiff_color_jacobian(J,f,x,ForwardColorJacCache(f,x,chunksize,dx=similar(x,size(J, 1)),colorvec=colorvec,sparsity=sparsity),jac_prototype)
92+
end
93+
8194
function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache,jac_prototype=nothing)
95+
dx = jac_cache.dx
96+
vecx = vec(x)
97+
sparsity = jac_cache.sparsity
98+
99+
J = jac_prototype isa Nothing ? (sparsity isa Nothing ? false .* vec(dx) .* vecx' : zeros(eltype(x),size(sparsity))) : zero(jac_prototype)
100+
101+
forwarddiff_color_jacobian(J, f, x, jac_cache, jac_prototype)
102+
end
103+
104+
function forwarddiff_color_jacobian(J::AbstractArray{<:Number},f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache,jac_prototype=nothing)
82105
t = jac_cache.t
83106
dx = jac_cache.dx
84107
p = jac_cache.p
@@ -90,7 +113,6 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
90113

91114
vecx = vec(x)
92115

93-
J = jac_prototype isa Nothing ? (sparsity isa Nothing ? false .* vec(dx) .* vecx' : zeros(eltype(x),size(sparsity))) : zero(jac_prototype)
94116
nrows,ncols = size(J)
95117

96118
if !(sparsity isa Nothing)
@@ -118,7 +140,8 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
118140
cols_index_c = vcat(cols_index_c,zeros(Int,nrows-len_rows))[perm_rows]
119141
Ji = [j==cols_index_c[i] ? dx[i] : false for i in 1:nrows, j in 1:ncols]
120142
end
121-
J = J + Ji
143+
# J = J + Ji
144+
J .+= Ji
122145
color_i += 1
123146
(color_i > maxcolor) && return J
124147
end
@@ -127,7 +150,8 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
127150
col_index = (i-1)*chunksize + j
128151
(col_index > ncols) && return J
129152
Ji = mapreduce(i -> i==col_index ? partials.(vec(fx), j) : adapt(parameterless_type(J),zeros(eltype(J),nrows)), hcat, 1:ncols)
130-
J = J + (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
153+
# J = J + (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
154+
J .+= (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
131155
end
132156
end
133157
end

test/test_ad.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,12 @@ f(x) = x
238238
J = forwarddiff_color_jacobian(f,x)
239239
@test J isa SArray
240240
@test J SMatrix{1,1}([1.])
241+
242+
@info "6"
243+
#oop with in-place Jacobian
244+
fcalls = 0
245+
_oop_jacout = spzeros(size(J)...)
246+
forwarddiff_color_jacobian(_oop_jacout, oopf, x; colorvec = repeat(1:3,10), sparsity = _J, jac_prototype = _J)
247+
@test _oop_jacout J
248+
@test typeof(_oop_jacout) == typeof(_J)
249+
@test fcalls == 1

0 commit comments

Comments
 (0)