@@ -256,12 +256,46 @@ function finite_difference_gradient(
256256 dir = true ) where {T1, T2, T3, T4, fdtype, returntype, inplace}
257257 if typeof (x) <: AbstractArray
258258 df = zero (returntype) .* x
259+ finite_difference_gradient! (
260+ df, f, x, cache, relstep = relstep, absstep = absstep, dir = dir)
261+ df
259262 else
260- df = zero (cache. c1)
263+ # Scalar x: compute out-of-place to support immutable output types
264+ # (e.g. ArrayPartition{SVector} from SecondOrderODEProblem).
265+ _scalar_gradient_oop (f, x, cache, fdtype, returntype, inplace;
266+ relstep = relstep, absstep = absstep, dir = dir)
267+ end
268+ end
269+
270+ # Out-of-place scalar→vector gradient that never mutates the result,
271+ # so it works even when f returns immutable arrays (SVector, etc.).
272+ function _scalar_gradient_oop (
273+ f, x:: Number , cache, fdtype, returntype, inplace;
274+ relstep, absstep, dir)
275+ fx, c1, c2 = cache. fx, cache. c1, cache. c2
276+
277+ if fdtype == Val (:forward )
278+ epsilon = compute_epsilon (Val (:forward ), x, relstep, absstep, dir)
279+ _c1 = inplace == Val (true ) ? (f (c1, x + epsilon); c1) : f (x + epsilon)
280+ if typeof (fx) != Nothing
281+ @. (_c1 - fx) / epsilon
282+ else
283+ _c2 = inplace == Val (true ) ? (f (c2, x); c2) : f (x)
284+ @. (_c1 - _c2) / epsilon
285+ end
286+ elseif fdtype == Val (:central )
287+ epsilon = compute_epsilon (Val (:central ), x, relstep, absstep, dir)
288+ _c1 = inplace == Val (true ) ? (f (c1, x + epsilon); c1) : f (x + epsilon)
289+ _c2 = inplace == Val (true ) ? (f (c2, x - epsilon); c2) : f (x - epsilon)
290+ @. (_c1 - _c2) / (2 * epsilon)
291+ elseif fdtype == Val (:complex ) && returntype <: Real
292+ epsilon_complex = eps (real (eltype (x)))
293+ _c1 = inplace == Val (true ) ?
294+ (f (c1, x + im * epsilon_complex); c1) : f (x + im * epsilon_complex)
295+ @. imag (_c1) / epsilon_complex
296+ else
297+ fdtype_error (returntype)
261298 end
262- finite_difference_gradient! (
263- df, f, x, cache, relstep = relstep, absstep = absstep, dir = dir)
264- df
265299end
266300
267301# vector of derivatives of a vector->scalar map by each component of a vector x
0 commit comments