Skip to content

Commit 7d83628

Browse files
Support immutable output types in out-of-place scalar→vector gradient
`finite_difference_gradient` (out-of-place, cached) for scalar `x` previously delegated to `finite_difference_gradient!`, which uses `@. df = result / epsilon`. This in-place broadcast fails when the output buffer contains immutable array types (e.g. `ArrayPartition{SVector}` from `SecondOrderODEProblem` in OrdinaryDiffEq.jl) because `setindex!` is not defined for `SVector`.

Extract the scalar→vector case into `_scalar_gradient_oop`, which computes the finite-difference result purely out-of-place using `@. (a - b) / h` (which allocates a new array via `copy` rather than mutating via `copyto!`).

Fixes SciML/OrdinaryDiffEq.jl#3444

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f37d23d commit 7d83628

2 files changed

Lines changed: 72 additions & 4 deletions

File tree

src/gradients.jl

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,12 +256,46 @@ function finite_difference_gradient(
256256
dir = true) where {T1, T2, T3, T4, fdtype, returntype, inplace}
257257
if typeof(x) <: AbstractArray
258258
df = zero(returntype) .* x
259+
finite_difference_gradient!(
260+
df, f, x, cache, relstep = relstep, absstep = absstep, dir = dir)
261+
df
259262
else
260-
df = zero(cache.c1)
263+
# Scalar x: compute out-of-place to support immutable output types
264+
# (e.g. ArrayPartition{SVector} from SecondOrderODEProblem).
265+
_scalar_gradient_oop(f, x, cache, fdtype, returntype, inplace;
266+
relstep = relstep, absstep = absstep, dir = dir)
267+
end
268+
end
269+
270+
# Out-of-place scalar→vector gradient that never mutates the result,
271+
# so it works even when f returns immutable arrays (SVector, etc.).
272+
"""
    _scalar_gradient_oop(f, x::Number, cache, fdtype, returntype, inplace;
        relstep, absstep, dir)

Finite-difference derivative of `f` at the scalar `x`, returned as the value of
a non-assigning broadcast. No output buffer is written by this function itself.

# Arguments
- `f`: called as `f(x)` when `inplace != Val(true)`, or as `f(buf, x)` when
  `inplace == Val(true)` (in which case `buf` is used as the function value).
- `x`: scalar evaluation point.
- `cache`: object providing the fields `fx`, `c1` and `c2`. `c1`/`c2` are only
  used as work buffers for in-place `f`; `fx`, when not `nothing`, is used as
  the value of `f(x)` in the forward scheme.
- `fdtype`: `Val(:forward)`, `Val(:central)` or `Val(:complex)`.
- `returntype`: must be `<: Real` for the `Val(:complex)` scheme.
- `inplace`: `Val(true)` selects the in-place calling convention for `f`.

# Keywords
- `relstep`, `absstep`, `dir`: forwarded unchanged to `compute_epsilon`
  (unused by the complex-step scheme).

Any other `fdtype`/`returntype` combination is handed to `fdtype_error`.
"""
function _scalar_gradient_oop(
        f, x::Number, cache, fdtype, returntype, inplace;
        relstep, absstep, dir)
    fx, buf1, buf2 = cache.fx, cache.c1, cache.c2

    # Evaluate `f` at `xi`. For in-place `f` the filled buffer is the result;
    # otherwise `buf` is ignored and the return value of `f` is used.
    evaluate(buf, xi) = inplace == Val(true) ? (f(buf, xi); buf) : f(xi)

    if fdtype == Val(:forward)
        h = compute_epsilon(Val(:forward), x, relstep, absstep, dir)
        ahead = evaluate(buf1, x + h)
        # Only evaluate f(x) when the cache does not already carry it.
        base = fx === nothing ? evaluate(buf2, x) : fx
        return @. (ahead - base) / h
    elseif fdtype == Val(:central)
        h = compute_epsilon(Val(:central), x, relstep, absstep, dir)
        ahead = evaluate(buf1, x + h)
        behind = evaluate(buf2, x - h)
        return @. (ahead - behind) / (2 * h)
    elseif fdtype == Val(:complex) && returntype <: Real
        # Complex-step scheme: perturb along the imaginary axis by machine
        # epsilon of the real type of `x` and read off the imaginary part.
        h = eps(real(eltype(x)))
        shifted = evaluate(buf1, x + im * h)
        return @. imag(shifted) / h
    else
        return fdtype_error(returntype)
    end
end
266300

267301
# vector of derivatives of a vector->scalar map by each component of a vector x

test/finitedifftests.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,40 @@ complex_cache = FiniteDiff.GradientCache(df, x, Val{:complex})
290290
@test err_func(FiniteDiff.finite_difference_gradient!(df, f, x, complex_cache), df_ref) < 1e-15
291291
end
292292

293+
# Read-only wrapper around a `Vector`: reads work, `setindex!` always errors.
# It simulates immutable array types (like StaticArrays.SVector or an
# ArrayPartition containing SVectors) without adding a test dependency.
#
# The type is declared at top level, outside the testset, because `struct`
# definitions are rejected inside the local scope that `@testset` introduces
# ("struct" expression not at top level).
struct ReadOnlyVec{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(v::ReadOnlyVec) = size(v.data)
Base.getindex(v::ReadOnlyVec, i::Int) = v.data[i]
function Base.setindex!(::ReadOnlyVec, _, ::Int)
    error("ReadOnlyVec does not support setindex!")
end
Base.similar(v::ReadOnlyVec) = ReadOnlyVec(zeros(eltype(v), length(v)))
Base.zero(v::ReadOnlyVec) = ReadOnlyVec(zeros(eltype(v), length(v)))
# Out-of-place broadcasts over ReadOnlyVec materialize as a plain Vector.
Base.BroadcastStyle(::Type{<:ReadOnlyVec}) = Broadcast.DefaultArrayStyle{1}()

@time @testset "Gradient of f:scalar->vector with immutable output" begin
    # Regression test: finite_difference_gradient with scalar x must work
    # when f returns immutable arrays, since the out-of-place API should
    # never need to mutate the result. This failed previously because the
    # cached version called finite_difference_gradient!, which used
    # `@. df = ...` to write into an immutable buffer.
    g(t) = ReadOnlyVec([sin(t), cos(t)])
    t0 = 1.0
    g_ref = [cos(t0), -sin(t0)]

    # Out-of-place cached version (the bug path)
    df_template = similar(g(t0))
    for fdtype in (Val(:forward), Val(:central))
        cache = FiniteDiff.GradientCache(df_template, t0, fdtype, Float64, Val(false))
        result = FiniteDiff.finite_difference_gradient(g, t0, cache)
        @test err_func(collect(result), g_ref) < 1e-4
    end
end
326+
293327
f(df, x) = (df[1] = sin(x); df[2] = cos(x); df)
294328
x = (2π * rand()) * (1 + im)
295329
fx = fill(zero(typeof(x)), 2)

0 commit comments

Comments
 (0)