Forwards-mode is taking shape nicely in #389, but will initially only let you propagate a single tangent forwards at a time. Ideally we would be able to propagate a collection of tangents forwards at the same time for efficiency reasons, which is a common optimisation supported by many systems.
In order to do this, we will require three extensions to the existing system:
- design + implement a chunked tangent system,
- extend all frule!!s to handle it, and
- make some very minor tweaks to build_frule in order to handle this new system.
Points 2 and 3 are self-explanatory, but point 1 requires some discussion.
Chunked-Tangent System
The basic idea here is to pair each primal x with a collection of tangents ts. I'm going to assume that the chunk size is known statically, and bake it into the type of the chunks (where appropriate). We might do something like
struct ChunkedPrimitiveTangent{T,N}
    ts::NTuple{N,T}
end
tangent_type(::Type{Float64}, ::Val{N}) where {N} = ChunkedPrimitiveTangent{tangent_type(Float64),N}
i.e. the chunked tangent type for Float64 with chunk size N is the thing returned above.
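To make this concrete, here is a minimal self-contained sketch of how such a type might be used. It does not load Mooncake: the one-argument `tangent_type` below is a local stand-in (assuming the single tangent type of `Float64` is `Float64`), and `chunked_sin` is a hypothetical rule written only for illustration, not an existing `frule!!`.

```julia
# Local stand-in for the single-tangent `tangent_type` (an assumption, defined
# here only so that the sketch runs without Mooncake).
tangent_type(::Type{Float64}) = Float64

struct ChunkedPrimitiveTangent{T,N}
    ts::NTuple{N,T}
end

# Chunked variant: the chunk size N is carried in the type via Val.
tangent_type(::Type{Float64}, ::Val{N}) where {N} =
    ChunkedPrimitiveTangent{tangent_type(Float64),N}

# Hypothetical forward rule for sin: compute cos(x) once and reuse it for
# all N tangents in the chunk.
function chunked_sin(x::Float64, dx::ChunkedPrimitiveTangent{Float64,N}) where {N}
    c = cos(x)
    return sin(x), ChunkedPrimitiveTangent(map(t -> c * t, dx.ts))
end

T3 = tangent_type(Float64, Val(3))
dx = T3((1.0, 0.0, 2.0))
y, dy = chunked_sin(0.0, dx)
```

The efficiency gain comes from sharing the primal work (here, the single call to `cos`) across every tangent in the chunk.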
For Arrays, we probably want the chunked tangent type to be another Array with an extra dimension to contain the chunks. For example, the chunked tangent type for a Vector{Float64} would be a Matrix{Float64}, where the second dimension indexes over the chunks. We might want to put this inside a container to make it semantically meaningful, and to ensure that we can enforce statically that all active data has the same chunk size.
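One way the container idea could look is sketched below. The names `ChunkedArrayTangent`, `zero_chunked_tangent`, and `nth_tangent` are all hypothetical, chosen for illustration rather than taken from Mooncake. The chunk size N sits in the type, and the constructor checks that the trailing dimension agrees with it.

```julia
# Hypothetical container: the last dimension of `data` indexes the chunks,
# and the chunk size N is part of the type.
struct ChunkedArrayTangent{N,A<:AbstractArray}
    data::A
    function ChunkedArrayTangent{N}(data::A) where {N,A<:AbstractArray}
        size(data, ndims(data)) == N ||
            throw(DimensionMismatch("last dimension must have length $N"))
        return new{N,A}(data)
    end
end

# Hypothetical helper: a zero chunked tangent for a primal array `x`.
zero_chunked_tangent(x::Array{Float64}, ::Val{N}) where {N} =
    ChunkedArrayTangent{N}(zeros(Float64, size(x)..., N))

# View of the k-th tangent, with the same shape as the primal.
nth_tangent(t::ChunkedArrayTangent, k::Int) = selectdim(t.data, ndims(t.data), k)

x = [1.0, 2.0, 3.0]
t = zero_chunked_tangent(x, Val(2))
nth_tangent(t, 2) .= x  # seed the second tangent in place
```

Because N is a type parameter, two chunked tangents with different chunk sizes have different types, so a mismatch can be caught by dispatch rather than by a runtime size check deep inside a rule.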
For composite types like Tuples, NamedTuples, structs, and mutable structs, we should push the chunks down to the leaves of the objects. For example, the chunked tangent type for Tuple{Vector{Float64}, Float64} should be a Tuple{Matrix{Float64}, ChunkedPrimitiveTangent{...}} rather than a Vector{Tuple{Vector{Float64},Float64}}.
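Pushing chunks down to the leaves amounts to a recursion over field types. The sketch below covers only Tuples, and uses a hypothetical function name `chunked_tangent_type` so that it stays self-contained. The rule for arrays here returns a bare Array with one extra dimension, without the optional container.

```julia
struct ChunkedPrimitiveTangent{T,N}
    ts::NTuple{N,T}
end

# Hypothetical name, used so the sketch does not depend on Mooncake.
chunked_tangent_type(::Type{Float64}, ::Val{N}) where {N} =
    ChunkedPrimitiveTangent{Float64,N}

# Arrays gain one trailing dimension that indexes the chunks.
chunked_tangent_type(::Type{Array{Float64,D}}, ::Val{N}) where {D,N} =
    Array{Float64,D + 1}

# Tuples recurse into their field types, so chunking happens at the leaves.
function chunked_tangent_type(::Type{P}, v::Val) where {P<:Tuple}
    return Tuple{map(F -> chunked_tangent_type(F, v), fieldtypes(P))...}
end

P = Tuple{Vector{Float64},Float64}
CT = chunked_tangent_type(P, Val(4))
```

Here `CT` is a Tuple whose first field is a Matrix{Float64} and whose second field is a ChunkedPrimitiveTangent with chunk size 4. NamedTuples and structs would presumably follow the same pattern, recursing over their fields.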
Discussed in https://github.com/chalk-lab/Mooncake.jl/discussions/533