diff --git a/src/codual.jl b/src/codual.jl index b92ed63349..adbc4d0881 100644 --- a/src/codual.jl +++ b/src/codual.jl @@ -15,7 +15,6 @@ end primal(x::CoDual) = x.x tangent(x::CoDual) = x.dx Base.copy(x::CoDual) = CoDual(copy(primal(x)), copy(tangent(x))) -# CoDual is immutable and can be safely shared without copying _copy(x::P) where {P<:CoDual} = x """ @@ -100,7 +99,6 @@ struct NoPullback{R<:Tuple} r::R end -# Recursively copy the contained reverse data _copy(x::P) where {P<:NoPullback} = P(_copy(x.r)) """ diff --git a/src/debug_mode.jl b/src/debug_mode.jl index ac9ae1793b..338d0f9c4c 100644 --- a/src/debug_mode.jl +++ b/src/debug_mode.jl @@ -78,7 +78,6 @@ struct DebugRRule{Trule} rule::Trule end -# Recursively copy the wrapped rule _copy(x::P) where {P<:DebugRRule} = P(_copy(x.rule)) """ diff --git a/src/dual.jl b/src/dual.jl index 26f0a25935..400ae347ac 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -15,7 +15,6 @@ end primal(x::Dual) = x.primal tangent(x::Dual) = x.tangent Base.copy(x::Dual) = Dual(copy(primal(x)), copy(tangent(x))) -# Dual is immutable and can be safely shared without copying _copy(x::P) where {P<:Dual} = x """ diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl index e5b46319d0..e4b402a42f 100644 --- a/src/fwds_rvs_data.jl +++ b/src/fwds_rvs_data.jl @@ -22,7 +22,6 @@ struct FData{T<:NamedTuple} data::T end -# Recursively copy the wrapped data _copy(x::P) where {P<:FData} = P(_copy(x.data)) fields_type(::Type{FData{T}}) where {T<:NamedTuple} = T @@ -406,7 +405,6 @@ struct RData{T<:NamedTuple} data::T end -# Recursively copy the wrapped data _copy(x::P) where {P<:RData} = P(_copy(x.data)) fields_type(::Type{RData{T}}) where {T<:NamedTuple} = T @@ -833,7 +831,6 @@ struct LazyZeroRData{P,Tdata} data::Tdata end -# Recursively copy the wrapped data _copy(x::P) where {P<:LazyZeroRData} = P(_copy(x.data)) # Returns the type which must be output by LazyZeroRData whenever it is passed a `P`. diff --git a/src/interpreter/forward_mode.jl b/src/interpreter/forward_mode.jl index 6d0b32361c..37774aa588 100644 --- a/src/interpreter/forward_mode.jl +++ b/src/interpreter/forward_mode.jl @@ -71,7 +71,6 @@ end return fwd.fwd_oc(__unflatten_dual_varargs(isva, args, Val(nargs))...) end -# Copy forward rule with recursively copied captures function _copy(x::P) where {P<:DerivedFRule} return P(replace_captures(x.fwd_oc, _copy(x.fwd_oc.oc.captures))) end @@ -399,7 +398,6 @@ mutable struct LazyFRule{primal_sig,Trule} end end -# Create new lazy rule with same method instance and debug mode _copy(x::P) where {P<:LazyFRule} = P(x.mi, x.debug_mode) @inline function (rule::LazyFRule)(args::Vararg{Any,N}) where {N} @@ -440,7 +438,6 @@ end DynamicFRule(debug_mode::Bool) = DynamicFRule(Dict{Any,Any}(), debug_mode) -# Create new dynamic rule with empty cache and same debug mode _copy(x::P) where {P<:DynamicFRule} = P(Dict{Any,Any}(), x.debug_mode) function (dynamic_rule::DynamicFRule)(args::Vararg{Dual,N}) where {N} diff --git a/src/interpreter/reverse_mode.jl b/src/interpreter/reverse_mode.jl index 0bc23c8695..bb7fb941f6 100644 --- a/src/interpreter/reverse_mode.jl +++ b/src/interpreter/reverse_mode.jl @@ -289,7 +289,6 @@ struct RRuleZeroWrapper{Trule} rule::Trule end -# Recursively copy the wrapped rule _copy(x::P) where {P<:RRuleZeroWrapper} = P(_copy(x.rule)) struct RRuleWrapperPb{Tpb!!,Tl} @@ -952,6 +951,8 @@ function verify_args(r::DerivedRule{sig}, x) where {sig} throw(ArgumentError("Arguments with sig $Tx do not subtype rule signature, $sig")) end +_copy(::Nothing) = nothing + function _copy(x::P) where {P<:DerivedRule} new_captures = _copy(x.fwds_oc.oc.captures) new_fwds_oc = replace_captures(x.fwds_oc, new_captures) @@ -959,6 +960,18 @@ function _copy(x::P) where {P<:DerivedRule} return P(new_fwds_oc, new_pb_oc_ref, x.nargs) end +_copy(x::Symbol) = x + +_copy(x::Tuple) = map(_copy, x) + +_copy(x::NamedTuple) = map(_copy, x) + +_copy(x::Ref{T}) where {T} = isassigned(x) ? Ref{T}(_copy(x[])) : Ref{T}() + +_copy(x::Type) = x + +_copy(x) = copy(x) + @inline function (fwds::DerivedRule{sig})(args::Vararg{CoDual,N}) where {sig,N} uf_args = __unflatten_codual_varargs(_isva(fwds), args, fwds.nargs) pb = Pullback(sig, fwds.pb_oc_ref, _isva(fwds), N) @@ -1730,7 +1743,6 @@ end DynamicDerivedRule(debug_mode::Bool) = DynamicDerivedRule(Dict{Any,Any}(), debug_mode) -# Create new dynamic rule with empty cache and same debug mode _copy(x::P) where {P<:DynamicDerivedRule} = P(Dict{Any,Any}(), x.debug_mode) function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any,N}) where {N} @@ -1828,7 +1840,6 @@ mutable struct LazyDerivedRule{primal_sig,Trule} end end -# Create new lazy rule with same method instance and debug mode _copy(x::P) where {P<:LazyDerivedRule} = P(x.mi, x.debug_mode) @inline function (rule::LazyDerivedRule)(args::Vararg{Any,N}) where {N} diff --git a/src/rrules/misty_closures.jl b/src/rrules/misty_closures.jl index 22f415eff0..d5a3ea4bf4 100644 --- a/src/rrules/misty_closures.jl +++ b/src/rrules/misty_closures.jl @@ -77,7 +77,6 @@ struct MistyClosureRData{Tr} captures_rdata::Tr end -# Deep copy the captures data for misty closures _copy(r::MistyClosureRData) = MistyClosureRData(deepcopy(r.captures_rdata)) fdata_type(::Type{<:MistyClosureTangent}) = MistyClosureFData diff --git a/src/stack.jl b/src/stack.jl index 1a0716760f..a7259c71ba 100644 --- a/src/stack.jl +++ b/src/stack.jl @@ -11,7 +11,6 @@ mutable struct Stack{T} Stack{T}() where {T} = new{T}(Vector{T}(undef, 0), 0) end -# Create a new empty stack of the same type _copy(::Stack{T}) where {T} = Stack{T}() @inline function Base.push!(x::Stack{T}, val::T) where {T} diff --git a/src/tangents.jl b/src/tangents.jl index b8a13b1454..26a7aa53e0 100644 --- a/src/tangents.jl +++ b/src/tangents.jl @@ -20,7 +20,6 @@ struct PossiblyUninitTangent{T} PossiblyUninitTangent{T}() where {T} = new{T}() end -# Copy only if initialized, otherwise create new uninitialized instance _copy(x::P) where {P<:PossiblyUninitTangent} = is_init(x) ? P(_copy(x.tangent)) : P() @inline PossiblyUninitTangent(tangent::T) where {T} = PossiblyUninitTangent{T}(tangent) diff --git a/src/utils.jl b/src/utils.jl index 1797ba6cd3..c2218efea7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -418,75 +418,3 @@ function _copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar) end return B end - -""" - _copy(x) - -*Note:* this is not part of the public Mooncake.jl interface, and may change without warning. - -Internal protocol for creating copies of AD-related data structures. This function is used -throughout the automatic differentiation system to create appropriate copies of rules, -caches, and other internal data structures when building new AD contexts. - -# Semantics - -The `_copy` protocol defines how different types should be copied within the AD system: - -- For immutable AD types (like `CoDual`, `Dual`), typically returns the same object -- For mutable containers, creates new instances with copied contents -- For composite types, recursively applies `_copy` to fields -- Falls back to `Base.copy` for general types - -# Implementation Requirements - -When implementing `_copy` for a new type, consider: -- Whether the type represents mutable state that needs actual copying -- Whether fields should be recursively copied or can be shared -- Performance implications of copying vs. sharing immutable data - -# Examples - -```julia -# For immutable AD types - no copying needed -_copy(x::CoDual) = x - -# For `Stack` type - create new empty instance -_copy(::Stack{T}) where {T} = Stack{T}() - -# For composite types - recursive copying -_copy(x::Tuple) = map(_copy, x) - -# For rule types - create new instances with appropriately copied captures/caches -function _copy(x::DerivedRule) - new_captures = _copy(x.fwds_oc.oc.captures) - new_fwds_oc = replace_captures(x.fwds_oc, new_captures) - new_pb_oc_ref = Ref(replace_captures(x.pb_oc_ref[], new_captures)) - return typeof(x)(new_fwds_oc, new_pb_oc_ref, x.nargs) -end - -# For misty closure reverse data - deep copy the captures data -_copy(r::MistyClosureRData) = MistyClosureRData(deepcopy(r.captures_rdata)) - -# For tangent types - copy conditionally based on initialization state -_copy(x::PossiblyUninitTangent) = is_init(x) ? typeof(x)(_copy(x.tangent)) : typeof(x)() - -# For forwards/reverse data types - recursively copy wrapped data -_copy(x::FData) = typeof(x)(_copy(x.data)) -_copy(x::RData) = typeof(x)(_copy(x.data)) -_copy(x::LazyZeroRData) = typeof(x)(_copy(x.data)) - -# Fallback to Base.copy -_copy(x) = copy(x) -``` -""" - -# Generic implementations that work with any type -_copy(::Nothing) = nothing -_copy(x::Symbol) = x -_copy(x::Tuple) = map(_copy, x) -_copy(x::NamedTuple) = map(_copy, x) -_copy(x::Ref{T}) where {T} = isassigned(x) ? Ref{T}(_copy(x[])) : Ref{T}() -_copy(x::Type) = x - -# Fallback to Base.copy for all other types -_copy(x) = copy(x)