Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit a8f374b

Browse files
vpuri3ChrisRackauckas
authored andcommitted
iip,oop
1 parent ec9e0f4 commit a8f374b

3 files changed

Lines changed: 29 additions & 8 deletions

File tree

src/differentiation/jaches_products.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,10 +241,17 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
241241
vecprod = autodiff ? auto_jacvec : num_jacvec
242242
vecprod! = autodiff ? auto_jacvec! : num_jacvec!
243243

244+
outofplace = static_hasmethod(f, typeof((u,)))
245+
isinplace = static_hasmethod(f, typeof((u, u,)))
246+
247+
if !(isinplace) & !(outofplace)
248+
error("$f must have signature f(u), or f(du, u).")
249+
end
250+
244251
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
245252

246253
FunctionOperator(L, u, u;
247-
isinplace = true, outofplace = true,
254+
isinplace = isinplace, outofplace = outofplace,
248255
p = p, t = t, islinear = true,
249256
)
250257
end
@@ -266,10 +273,17 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
266273
vecprod = autodiff ? numauto_hesvec : num_hesvec
267274
vecprod! = autodiff ? numauto_hesvec! : num_hesvec!
268275

276+
outofplace = static_hasmethod(f, typeof((u,)))
277+
isinplace = static_hasmethod(f, typeof((u,)))
278+
279+
if !(isinplace) & !(outofplace)
280+
error("$f must have signature f(u).")
281+
end
282+
269283
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
270284

271285
FunctionOperator(L, u, u;
272-
isinplace = true, outofplace = true,
286+
isinplace = isinplace, outofplace = outofplace,
273287
p = p, t = t, islinear = true,
274288
)
275289
end
@@ -292,10 +306,17 @@ function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = tr
292306
vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad
293307
vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad!
294308

309+
outofplace = static_hasmethod(f, typeof((u,)))
310+
isinplace = static_hasmethod(f, typeof((u, u,)))
311+
312+
if !(isinplace) & !(outofplace)
313+
error("$f must have signature f(u), or f(du, u).")
314+
end
315+
295316
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
296317

297318
FunctionOperator(L, u, u;
298-
isinplace = true, outofplace = true,
319+
isinplace = isinplace, outofplace = outofplace,
299320
p = p, t = t, islinear = true,
300321
)
301322
end

src/differentiation/vecjac_products.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true,
9494
vecprod = autodiff ? auto_vecjac : num_vecjac
9595
vecprod! = autodiff ? auto_vecjac! : num_vecjac!
9696

97-
isinplace = static_hasmethod(f, typeof((u, p, t)))
98-
outofplace = static_hasmethod(f, typeof((u, u, p, t)))
97+
outofplace = static_hasmethod(f, typeof((u, p, t)))
98+
isinplace = static_hasmethod(f, typeof((u, u, p, t)))
9999

100100
if !(isinplace) & !(outofplace)
101101
error("$f must have signature f(u, p, t), or f(du, u, p, t)")

test/test_jaches_products.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x)
6565
@test auto_hesvecgrad!(dy, h, x, v, cache1, cache2)ForwardDiff.hessian(g, x) * v rtol=1e-2
6666
@test auto_hesvecgrad(h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
6767

68-
### JacVec
68+
@info "JacVec"
6969

7070
L = JacVec(f, x)
7171
@test L * x auto_jacvec(f, x, x)
@@ -88,7 +88,7 @@ dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*num_jacvec(f,x,v) + b*_dy r
8888
out = similar(v)
8989
gmres!(out, L, v)
9090

91-
### HesVec
91+
@info "HesVec"
9292

9393
x = rand(N)
9494
v = rand(N)
@@ -113,7 +113,7 @@ dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy r
113113
out = similar(v)
114114
gmres!(out, L, v)
115115

116-
### HesVecGrad
116+
@info "HesVecGrad"
117117

118118
x = rand(N)
119119
v = rand(N)

0 commit comments

Comments
 (0)