Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 84542bb

Browse files
vpuri3ChrisRackauckas
authored andcommitted
redefine JacVec, HesVec, HesVecGrad, VecJac
1 parent e1d8be1 commit 84542bb

4 files changed

Lines changed: 159 additions & 275 deletions

File tree

src/SparseDiffTools.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ using SciMLOperators
2323
import SciMLOperators: update_coefficients, update_coefficients!
2424
using Tricks: static_hasmethod
2525

26+
abstract type AbstractAutoDiffVecProd end
27+
2628
export contract_color,
2729
greedy_d1,
2830
greedy_star1_coloring,
@@ -46,8 +48,7 @@ export contract_color,
4648
autonum_hesvec, autonum_hesvec!,
4749
num_hesvecgrad, num_hesvecgrad!,
4850
auto_hesvecgrad, auto_hesvecgrad!,
49-
JacVec, HesVec, HesVecGrad,
50-
JacVecProd, HesVecProd, HesVecGradProd, VecJacProd,
51+
JacVec, HesVec, HesVecGrad, VecJac,
5152
update_coefficients, update_coefficients!,
5253
value!
5354

@@ -64,8 +65,6 @@ include("differentiation/compute_hessian_ad.jl")
6465
include("differentiation/jaches_products.jl")
6566
include("differentiation/vecjac_products.jl")
6667

67-
include("differentiation/operators.jl")
68-
6968
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
7069
parameterless_type(x) = parameterless_type(typeof(x))
7170
parameterless_type(x::Type) = __parameterless_type(x)

src/differentiation/jaches_products.jl

Lines changed: 74 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -198,112 +198,105 @@ end
198198

199199
### Operator Forms
200200

201-
struct JacVec{F, T1, T2, xType}
201+
mutable struct FwdModeAutoDiffVecProd{F,U,C,V,V!} <: AbstractAutoDiffVecProd
202202
f::F
203-
cache1::T1
204-
cache2::T2
205-
x::xType
206-
autodiff::Bool
203+
u::U
204+
cache::C
205+
vecprod::V
206+
vecprod!::V!
207207
end
208208

209-
function JacVec(f, x::AbstractArray, tag = DeivVecTag(); autodiff = true)
210-
if autodiff
211-
cache1 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
212-
}.(x, ForwardDiff.Partials.(tuple.(x)))
213-
cache2 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
214-
}.(x, ForwardDiff.Partials.(tuple.(x)))
215-
else
216-
cache1 = similar(x)
217-
cache2 = similar(x)
218-
end
219-
JacVec(f, cache1, cache2, x, autodiff)
209+
function update_coefficients(L::FwdModeAutoDiffVecProd, u, p, t)
210+
FwdModeAutoDiffVecProd(L.f, u, L.vecprod, L.vecprod!, L.cache)
220211
end
221212

222-
Base.eltype(L::JacVec) = eltype(L.x)
223-
Base.size(L::JacVec) = (length(L.cache1), length(L.cache1))
224-
Base.size(L::JacVec, i::Int) = length(L.cache1)
225-
function Base.:*(L::JacVec, v::AbstractVector)
226-
L.autodiff ? auto_jacvec(_x -> L.f(_x), L.x, v) :
227-
num_jacvec(_x -> L.f(_x), L.x, v)
213+
function update_coefficients!(L::FwdModeAutoDiffVecProd, u, p, t)
214+
L.u .= u
215+
L
228216
end
229217

230-
function LinearAlgebra.mul!(dy::AbstractVector, L::JacVec, v::AbstractVector)
231-
if L.autodiff
232-
auto_jacvec!(dy, (_y, _x) -> L.f(_y, _x), L.x, v, L.cache1, L.cache2)
233-
else
234-
num_jacvec!(dy, (_y, _x) -> L.f(_y, _x), L.x, v, L.cache1, L.cache2)
235-
end
218+
function (L::FwdModeAutoDiffVecProd)(v, p, t)
219+
L.vecprod(L.f, L.u, v)
236220
end
237221

238-
struct HesVec{F, T1, T2, xType}
239-
f::F
240-
cache1::T1
241-
cache2::T2
242-
cache3::T2
243-
x::xType
244-
autodiff::Bool
222+
function (L::FwdModeAutoDiffVecProd)(dv, v, p, t)
223+
L.vecprod!(dv, L.f, L.u, v, L.cache...)
245224
end
246225

247-
function HesVec(f, x::AbstractArray; autodiff = true)
226+
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
227+
248228
if autodiff
249-
cache1 = ForwardDiff.GradientConfig(f, x)
250-
cache2 = similar(x)
251-
cache3 = similar(x)
229+
cache1 = Dual{
230+
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
231+
}.(u, ForwardDiff.Partials.(tuple.(u)))
232+
233+
cache2 = copy(cache1)
252234
else
253-
cache1 = similar(x)
254-
cache2 = similar(x)
255-
cache3 = similar(x)
235+
cache1 = similar(u)
236+
cache2 = similar(u)
256237
end
257-
HesVec(f, cache1, cache2, cache3, x, autodiff)
258-
end
259238

260-
Base.size(L::HesVec) = (length(L.cache2), length(L.cache2))
261-
Base.size(L::HesVec, i::Int) = length(L.cache2)
262-
function Base.:*(L::HesVec, v::AbstractVector)
263-
L.autodiff ? numauto_hesvec(L.f, L.x, v) : num_hesvec(L.f, L.x, v)
264-
end
239+
cache = (cache1, cache2,)
265240

266-
function LinearAlgebra.mul!(dy::AbstractVector, L::HesVec, v::AbstractVector)
267-
if L.autodiff
268-
numauto_hesvec!(dy, L.f, L.x, v, L.cache1, L.cache2, L.cache3)
269-
else
270-
num_hesvec!(dy, L.f, L.x, v, L.cache1, L.cache2, L.cache3)
271-
end
272-
end
241+
vecprod = autodiff ? auto_jacvec : num_jacvec
242+
vecprod! = autodiff ? auto_jacvec! : num_jacvec!
273243

274-
struct HesVecGrad{G, T1, T2, uType}
275-
g::G
276-
cache1::T1
277-
cache2::T2
278-
x::uType
279-
autodiff::Bool
244+
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
245+
246+
FunctionOperator(L, u, u;
247+
isinplace = true, outofplace = true,
248+
p = p, t = t, islinear = true,
249+
)
280250
end
281251

282-
function HesVecGrad(g, x::AbstractArray, tag = DeivVecTag(); autodiff = false)
252+
function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
253+
283254
if autodiff
284-
cache1 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
285-
}.(x, ForwardDiff.Partials.(tuple.(x)))
286-
cache2 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
287-
}.(x, ForwardDiff.Partials.(tuple.(x)))
255+
cache1 = ForwardDiff.GradientConfig(f, u)
256+
cache2 = similar(u)
257+
cache3 = similar(u)
288258
else
289-
cache1 = similar(x)
290-
cache2 = similar(x)
259+
cache1 = similar(u)
260+
cache2 = similar(u)
261+
cache3 = similar(u)
291262
end
292-
HesVecGrad(g, cache1, cache2, x, autodiff)
293-
end
294263

295-
Base.size(L::HesVecGrad) = (length(L.cache2), length(L.cache2))
296-
Base.size(L::HesVecGrad, i::Int) = length(L.cache2)
297-
function Base.:*(L::HesVecGrad, v::AbstractVector)
298-
L.autodiff ? auto_hesvecgrad(L.g, L.x, v) : num_hesvecgrad(L.g, L.x, v)
264+
cache = (cache1, cache2, cache3,)
265+
266+
vecprod = autodiff ? numauto_hesvec : num_hesvec
267+
vecprod! = autodiff ? numauto_hesvec! : num_hesvec!
268+
269+
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
270+
271+
FunctionOperator(L, u, u;
272+
isinplace = true, outofplace = true,
273+
p = p, t = t, islinear = true,
274+
)
299275
end
300276

301-
function LinearAlgebra.mul!(dy::AbstractVector,
302-
L::HesVecGrad,
303-
v::AbstractVector)
304-
if L.autodiff
305-
auto_hesvecgrad!(dy, L.g, L.x, v, L.cache1, L.cache2)
277+
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
278+
279+
if autodiff
280+
cache1 = Dual{
281+
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
282+
}.(u, ForwardDiff.Partials.(tuple.(u)))
283+
284+
cache2 = copy(cache1)
306285
else
307-
num_hesvecgrad!(dy, L.g, L.x, v, L.cache1, L.cache2)
286+
cache1 = similar(u)
287+
cache2 = similar(u)
308288
end
289+
290+
cache = (cache1, cache2,)
291+
292+
vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad
293+
vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad!
294+
295+
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
296+
297+
FunctionOperator(L, u, u;
298+
isinplace = true, outofplace = true,
299+
p = p, t = t, islinear = true,
300+
)
309301
end
302+
#

0 commit comments

Comments
 (0)