Skip to content

Commit 689b74a

Browse files
Add batch/vectorized finite difference Jacobian evaluation
Implements feature requested in #210: allows computing the full Jacobian in a single batched function call instead of N sequential calls. This is useful for GPU-parallelized functions that can evaluate multiple inputs simultaneously. Adds `batch=true` keyword to `finite_difference_jacobian` and `finite_difference_jacobian!`. When enabled, `f` receives a matrix where each column is an input point and returns a matrix of outputs. Supports forward, central, and complex step methods. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f37d23d commit 689b74a

2 files changed

Lines changed: 252 additions & 2 deletions

File tree

src/jacobians.jl

Lines changed: 174 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,167 @@ function _make_Ji(::AbstractArray, xtype, dx, color_i, nrows, ncols)
186186
size(Ji) != (nrows, ncols) ? reshape(Ji, (nrows, ncols)) : Ji #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
187187
end
188188

189+
"""
190+
_finite_difference_jacobian_batch(f, x, fdtype, returntype, f_in; relstep, absstep, dir)
191+
192+
Internal function implementing vectorized/batched finite difference Jacobian computation.
193+
194+
When `batch=true` is passed to `finite_difference_jacobian`, this function is called instead
195+
of the standard column-by-column approach. The function `f` is expected to accept a matrix
196+
where each column is an input point, and return a matrix where each column is the
197+
corresponding output. This allows GPU-parallelized or otherwise vectorized functions to
198+
evaluate all perturbations in a single call.
199+
200+
For forward differences, a single call to `f` is made with `n+1` columns (base point + `n`
201+
perturbations) if `f_in` is not provided, or `n` columns if `f_in` is provided.
202+
For central differences, `2n` columns are used (forward and backward perturbations).
203+
For complex step, `n` columns are used with complex perturbations.
204+
"""
205+
function _finite_difference_jacobian_batch(f, x, fdtype, returntype, f_in;
206+
relstep, absstep, dir)
207+
fdtype isa Type && (fdtype = fdtype())
208+
n = length(x)
209+
vecx = _vec(x)
210+
211+
if fdtype == Val(:forward)
212+
epsilons = [compute_epsilon(Val(:forward), vecx[i], relstep, absstep, dir) for i in 1:n]
213+
214+
if f_in isa Nothing
215+
# Include x as the first column so we only call f once
216+
X = repeat(vecx, 1, n + 1)
217+
for i in 1:n
218+
X[i, i + 1] += epsilons[i]
219+
end
220+
FX = f(X)
221+
fx_col = @view FX[:, 1]
222+
J = similar(FX, size(FX, 1), n)
223+
for i in 1:n
224+
@. J[:, i] = (FX[:, i + 1] - fx_col) / epsilons[i]
225+
end
226+
else
227+
X = repeat(vecx, 1, n)
228+
for i in 1:n
229+
X[i, i] += epsilons[i]
230+
end
231+
FX = f(X)
232+
vfx = _vec(f_in)
233+
J = similar(FX, size(FX, 1), n)
234+
for i in 1:n
235+
@. J[:, i] = (FX[:, i] - vfx) / epsilons[i]
236+
end
237+
end
238+
return J
239+
240+
elseif fdtype == Val(:central)
241+
epsilons = [compute_epsilon(Val(:central), vecx[i], relstep, absstep, dir) for i in 1:n]
242+
243+
# Build matrix with 2n columns: [x+eps1*e1, x-eps1*e1, x+eps2*e2, x-eps2*e2, ...]
244+
X = repeat(vecx, 1, 2n)
245+
for i in 1:n
246+
X[i, 2i - 1] += epsilons[i]
247+
X[i, 2i] -= epsilons[i]
248+
end
249+
FX = f(X)
250+
J = similar(FX, size(FX, 1), n)
251+
for i in 1:n
252+
@. J[:, i] = (FX[:, 2i - 1] - FX[:, 2i]) / (2 * epsilons[i])
253+
end
254+
return J
255+
256+
elseif fdtype == Val(:complex) && returntype <: Real
257+
epsilon = eps(eltype(x))
258+
259+
# Build complex matrix with n columns
260+
X = repeat(complex.(vecx), 1, n)
261+
for i in 1:n
262+
X[i, i] += im * epsilon
263+
end
264+
FX = f(X)
265+
J = similar(FX, real(eltype(FX)), size(FX, 1), n)
266+
for i in 1:n
267+
@. J[:, i] = imag(FX[:, i]) / epsilon
268+
end
269+
return J
270+
else
271+
fdtype_error(returntype)
272+
end
273+
end
274+
275+
"""
276+
_finite_difference_jacobian_batch!(J, f, x, fdtype, returntype, f_in; relstep, absstep, dir)
277+
278+
Internal in-place function implementing vectorized/batched finite difference Jacobian computation.
279+
280+
When `batch=true` is passed to `finite_difference_jacobian!`, this function is called instead
281+
of the standard column-by-column approach. The function `f` is expected to accept two matrix
282+
arguments `f(FX, X)` where `X` has columns of input points and `FX` is filled with the
283+
corresponding outputs.
284+
"""
285+
function _finite_difference_jacobian_batch!(J, f, x, fdtype, returntype, f_in;
286+
relstep, absstep, dir)
287+
fdtype isa Type && (fdtype = fdtype())
288+
m, n = size(J)
289+
vecx = _vec(x)
290+
291+
if fdtype == Val(:forward)
292+
epsilons = [compute_epsilon(Val(:forward), vecx[i], relstep, absstep, dir) for i in 1:n]
293+
294+
if f_in isa Nothing
295+
# n+1 columns: base point + n perturbations
296+
X = repeat(vecx, 1, n + 1)
297+
for i in 1:n
298+
X[i, i + 1] += epsilons[i]
299+
end
300+
FX = similar(x, m, n + 1)
301+
f(FX, X)
302+
for i in 1:n
303+
@. J[:, i] = (FX[:, i + 1] - FX[:, 1]) / epsilons[i]
304+
end
305+
else
306+
X = repeat(vecx, 1, n)
307+
for i in 1:n
308+
X[i, i] += epsilons[i]
309+
end
310+
FX = similar(x, m, n)
311+
f(FX, X)
312+
vfx = _vec(f_in)
313+
for i in 1:n
314+
@. J[:, i] = (FX[:, i] - vfx) / epsilons[i]
315+
end
316+
end
317+
318+
elseif fdtype == Val(:central)
319+
epsilons = [compute_epsilon(Val(:central), vecx[i], relstep, absstep, dir) for i in 1:n]
320+
321+
X = repeat(vecx, 1, 2n)
322+
for i in 1:n
323+
X[i, 2i - 1] += epsilons[i]
324+
X[i, 2i] -= epsilons[i]
325+
end
326+
FX = similar(x, m, 2n)
327+
f(FX, X)
328+
for i in 1:n
329+
@. J[:, i] = (FX[:, 2i - 1] - FX[:, 2i]) / (2 * epsilons[i])
330+
end
331+
332+
elseif fdtype == Val(:complex) && returntype <: Real
333+
epsilon = eps(eltype(x))
334+
335+
X = repeat(complex.(vecx), 1, n)
336+
for i in 1:n
337+
X[i, i] += im * epsilon
338+
end
339+
FX = similar(X, Complex{eltype(x)}, m, n)
340+
f(FX, X)
341+
for i in 1:n
342+
@. J[:, i] = imag(FX[:, i]) / epsilon
343+
end
344+
else
345+
fdtype_error(returntype)
346+
end
347+
nothing
348+
end
349+
189350
"""
190351
FiniteDiff.finite_difference_jacobian(
191352
f,
@@ -246,7 +407,12 @@ function finite_difference_jacobian(f, x,
246407
colorvec = 1:length(x),
247408
sparsity = nothing,
248409
jac_prototype = nothing,
249-
dir = true)
410+
dir = true,
411+
batch = false)
412+
if batch
413+
return _finite_difference_jacobian_batch(f, x, fdtype, returntype, f_in;
414+
relstep = relstep, absstep = absstep, dir = dir)
415+
end
250416
if f_in isa Nothing
251417
fx = f(x)
252418
else
@@ -443,7 +609,13 @@ function finite_difference_jacobian!(J,
443609
relstep = default_relstep(fdtype, eltype(x)),
444610
absstep = relstep,
445611
colorvec = 1:length(x),
446-
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing)
612+
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing,
613+
batch = false)
614+
if batch
615+
_finite_difference_jacobian_batch!(J, f, x, fdtype, returntype, f_in;
616+
relstep = relstep, absstep = absstep, dir = true)
617+
return nothing
618+
end
447619
if f_in isa Nothing && fdtype == Val(:forward)
448620
if size(J, 1) == length(x)
449621
fx = zero(x)

test/finitedifftests.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,3 +578,81 @@ end
578578
@test FiniteDiff.finite_difference_hessian(f, x1, FiniteDiff.HessianCache(x1)) == Diagonal(2*ones(4))
579579
@test FiniteDiff.finite_difference_hessian(f, x1, FiniteDiff.HessianCache(x2)) == Diagonal(2*ones(4))
580580
end
581+
582+
# Batched Jacobian tests (issue #210)
583+
@time @testset "Batched Jacobian tests" begin
584+
# Out-of-place batched function: f(X) where X is n×k, returns m×k
585+
function oopf_scalar(x)
586+
[(x[1] + 3) * (x[2]^3 - 7) + 18,
587+
sin(x[2] * exp(x[1]) - 1)]
588+
end
589+
function oopf_batch(X::AbstractMatrix)
590+
hcat([oopf_scalar(X[:, j]) for j in 1:size(X, 2)]...)
591+
end
592+
# Also handle single vector input for the batch function
593+
oopf_batch(x::AbstractVector) = oopf_scalar(x)
594+
595+
# In-place batched function: f(FX, X) where X is n×k, FX is m×k
596+
function iipf_batch(FX::AbstractMatrix, X::AbstractMatrix)
597+
for j in 1:size(X, 2)
598+
FX[1, j] = (X[1, j] + 3) * (X[2, j]^3 - 7) + 18
599+
FX[2, j] = sin(X[2, j] * exp(X[1, j]) - 1)
600+
end
601+
end
602+
603+
x = [1.5, 0.7]
604+
J_ref = [[-7 + x[2]^3 3 * (3 + x[1]) * x[2]^2];
605+
[exp(x[1]) * x[2] * cos(1 - exp(x[1]) * x[2]) exp(x[1]) * cos(1 - exp(x[1]) * x[2])]]
606+
607+
@testset "Out-of-place batch" begin
608+
J_fwd = FiniteDiff.finite_difference_jacobian(oopf_batch, x, Val{:forward}; batch=true)
609+
@test err_func(J_fwd, J_ref) < 1e-6
610+
611+
# With f_in provided
612+
f_in = oopf_scalar(x)
613+
J_fwd2 = FiniteDiff.finite_difference_jacobian(oopf_batch, x, Val{:forward}, eltype(x), f_in; batch=true)
614+
@test err_func(J_fwd2, J_ref) < 1e-6
615+
616+
J_cen = FiniteDiff.finite_difference_jacobian(oopf_batch, x, Val{:central}; batch=true)
617+
@test err_func(J_cen, J_ref) < 1e-8
618+
619+
J_cpx = FiniteDiff.finite_difference_jacobian(oopf_batch, x, Val{:complex}; batch=true)
620+
@test err_func(J_cpx, J_ref) < 1e-14
621+
end
622+
623+
@testset "In-place batch" begin
624+
J = zero(J_ref)
625+
FiniteDiff.finite_difference_jacobian!(J, iipf_batch, x, Val{:forward}; batch=true)
626+
@test err_func(J, J_ref) < 1e-6
627+
628+
# With f_in provided
629+
f_in = oopf_scalar(x)
630+
J .= 0
631+
FiniteDiff.finite_difference_jacobian!(J, iipf_batch, x, Val{:forward}, eltype(x), f_in; batch=true)
632+
@test err_func(J, J_ref) < 1e-6
633+
634+
J .= 0
635+
FiniteDiff.finite_difference_jacobian!(J, iipf_batch, x, Val{:central}; batch=true)
636+
@test err_func(J, J_ref) < 1e-8
637+
638+
J .= 0
639+
FiniteDiff.finite_difference_jacobian!(J, iipf_batch, x, Val{:complex}; batch=true)
640+
@test err_func(J, J_ref) < 1e-14
641+
end
642+
643+
@testset "Batch matches non-batch" begin
644+
# Test on a larger function to make sure batch and non-batch agree
645+
f_oop(x) = [x[1]^2 + x[2]*x[3], sin(x[1]) + x[3]^2, x[1]*x[2]*x[3]]
646+
function f_batch(X::AbstractMatrix)
647+
hcat([f_oop(X[:, j]) for j in 1:size(X, 2)]...)
648+
end
649+
f_batch(x::AbstractVector) = f_oop(x)
650+
651+
x3 = [2.0, 3.0, 1.5]
652+
for fdtype in (Val{:forward}, Val{:central}, Val{:complex})
653+
J_std = FiniteDiff.finite_difference_jacobian(f_oop, x3, fdtype)
654+
J_bat = FiniteDiff.finite_difference_jacobian(f_batch, x3, fdtype; batch=true)
655+
@test J_std J_bat
656+
end
657+
end
658+
end

0 commit comments

Comments
 (0)