|
| 1 | +module SparseDiffToolsPolyesterForwardDiffExt |
| 2 | + |
| 3 | +using ADTypes, SparseDiffTools, PolyesterForwardDiff |
| 4 | +import ForwardDiff |
| 5 | +import SparseDiffTools: AbstractMaybeSparseJacobianCache, AbstractMaybeSparsityDetection, |
| 6 | + ForwardColorJacCache, NoMatrixColoring, sparse_jacobian_cache, sparse_jacobian!, |
| 7 | + sparse_jacobian_static_array, __standard_tag, __chunksize |
| 8 | + |
| 9 | +struct PolyesterForwardDiffJacobianCache{CO, CA, J, FX, X} <: |
| 10 | + AbstractMaybeSparseJacobianCache |
| 11 | + coloring::CO |
| 12 | + cache::CA |
| 13 | + jac_prototype::J |
| 14 | + fx::FX |
| 15 | + x::X |
| 16 | +end |
| 17 | + |
| 18 | +function sparse_jacobian_cache(ad::Union{AutoSparsePolyesterForwardDiff, |
| 19 | + AutoPolyesterForwardDiff}, sd::AbstractMaybeSparsityDetection, f::F, x; |
| 20 | + fx = nothing) where {F} |
| 21 | + coloring_result = sd(ad, f, x) |
| 22 | + fx = fx === nothing ? similar(f(x)) : fx |
| 23 | + if coloring_result isa NoMatrixColoring |
| 24 | + cache = __chunksize(ad, x) |
| 25 | + jac_prototype = nothing |
| 26 | + else |
| 27 | + @warn """Currently PolyesterForwardDiff does not support sparsity detection |
| 28 | + natively. Falling back to using ForwardDiff.jl""" maxlog=1 |
| 29 | + tag = __standard_tag(nothing, x) |
| 30 | + # Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof` |
| 31 | + cache = ForwardColorJacCache(f, x, __chunksize(ad); coloring_result.colorvec, |
| 32 | + dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag)) |
| 33 | + jac_prototype = coloring_result.jacobian_sparsity |
| 34 | + end |
| 35 | + return PolyesterForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x) |
| 36 | +end |
| 37 | + |
| 38 | +function sparse_jacobian_cache(ad::Union{AutoSparsePolyesterForwardDiff, |
| 39 | + AutoPolyesterForwardDiff}, sd::AbstractMaybeSparsityDetection, f!::F, fx, |
| 40 | + x) where {F} |
| 41 | + coloring_result = sd(ad, f!, fx, x) |
| 42 | + if coloring_result isa NoMatrixColoring |
| 43 | + cache = __chunksize(ad, x) |
| 44 | + jac_prototype = nothing |
| 45 | + else |
| 46 | + @warn """Currently PolyesterForwardDiff does not support sparsity detection |
| 47 | + natively. Falling back to using ForwardDiff.jl""" maxlog=1 |
| 48 | + tag = __standard_tag(nothing, x) |
| 49 | + # Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof` |
| 50 | + cache = ForwardColorJacCache(f!, x, __chunksize(ad); coloring_result.colorvec, |
| 51 | + dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag)) |
| 52 | + jac_prototype = coloring_result.jacobian_sparsity |
| 53 | + end |
| 54 | + return PolyesterForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x) |
| 55 | +end |
| 56 | + |
| 57 | +function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache, |
| 58 | + f::F, x) where {F} |
| 59 | + if cache.cache isa ForwardColorJacCache |
| 60 | + forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff |
| 61 | + else |
| 62 | + PolyesterForwardDiff.threaded_jacobian!(f, J, x, cache.cache) # Don't try to exploit sparsity |
| 63 | + end |
| 64 | + return J |
| 65 | +end |
| 66 | + |
| 67 | +function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache, |
| 68 | + f!::F, fx, x) where {F} |
| 69 | + if cache.cache isa ForwardColorJacCache |
| 70 | + forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff |
| 71 | + else |
| 72 | + PolyesterForwardDiff.threaded_jacobian!(f!, fx, J, x, cache.cache) # Don't try to exploit sparsity |
| 73 | + end |
| 74 | + return J |
| 75 | +end |
| 76 | + |
| 77 | +end |
0 commit comments