Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/codual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ end
# Field accessors: `primal` reads the `x` field, `tangent` reads the `dx` field.
function primal(x::CoDual)
    return x.x
end
function tangent(x::CoDual)
    return x.dx
end

# `Base.copy` applies `copy` to each half and wraps the results in a new CoDual.
function Base.copy(x::CoDual)
    primal_copy = copy(primal(x))
    tangent_copy = copy(tangent(x))
    return CoDual(primal_copy, tangent_copy)
end

# CoDual is immutable and can be safely shared, so `_copy` returns the same object.
function _copy(x::P) where {P<:CoDual}
    return x
end

"""
Expand Down Expand Up @@ -100,7 +99,6 @@ struct NoPullback{R<:Tuple}
r::R
end

# Rebuild the same concrete NoPullback type around a `_copy` of its `r` field.
function _copy(x::P) where {P<:NoPullback}
    copied_r = _copy(x.r)
    return P(copied_r)
end

"""
Expand Down
1 change: 0 additions & 1 deletion src/debug_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ struct DebugRRule{Trule}
rule::Trule
end

# Rebuild the same concrete DebugRRule type around a `_copy` of the wrapped rule.
function _copy(x::P) where {P<:DebugRRule}
    inner_rule = _copy(x.rule)
    return P(inner_rule)
end

"""
Expand Down
1 change: 0 additions & 1 deletion src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ end
# Field accessors: `primal` reads the `primal` field, `tangent` reads the `tangent` field.
function primal(x::Dual)
    return x.primal
end
function tangent(x::Dual)
    return x.tangent
end

# `Base.copy` applies `copy` to each half and wraps the results in a new Dual.
function Base.copy(x::Dual)
    primal_copy = copy(primal(x))
    tangent_copy = copy(tangent(x))
    return Dual(primal_copy, tangent_copy)
end

# Dual is immutable and can be safely shared, so `_copy` returns the same object.
function _copy(x::P) where {P<:Dual}
    return x
end

"""
Expand Down
3 changes: 0 additions & 3 deletions src/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ struct FData{T<:NamedTuple}
data::T
end

# Rebuild the same concrete FData type around a `_copy` of its `data` field.
function _copy(x::P) where {P<:FData}
    copied_data = _copy(x.data)
    return P(copied_data)
end

# Return the NamedTuple type parameter `T` of an `FData{T}` type.
function fields_type(::Type{FData{T}}) where {T<:NamedTuple}
    return T
end
Expand Down Expand Up @@ -406,7 +405,6 @@ struct RData{T<:NamedTuple}
data::T
end

# Rebuild the same concrete RData type around a `_copy` of its `data` field.
function _copy(x::P) where {P<:RData}
    copied_data = _copy(x.data)
    return P(copied_data)
end

# Return the NamedTuple type parameter `T` of an `RData{T}` type.
function fields_type(::Type{RData{T}}) where {T<:NamedTuple}
    return T
end
Expand Down Expand Up @@ -833,7 +831,6 @@ struct LazyZeroRData{P,Tdata}
data::Tdata
end

# Rebuild the same concrete LazyZeroRData type around a `_copy` of its `data` field.
function _copy(x::P) where {P<:LazyZeroRData}
    copied_data = _copy(x.data)
    return P(copied_data)
end

# Returns the type which must be output by LazyZeroRData whenever it is passed a `P`.
Expand Down
3 changes: 0 additions & 3 deletions src/interpreter/forward_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ end
return fwd.fwd_oc(__unflatten_dual_varargs(isva, args, Val(nargs))...)
end

# Copy a forward rule: `_copy` the captures of its opaque closure, then build a new
# closure carrying those captures and wrap it in the same concrete rule type.
function _copy(x::P) where {P<:DerivedFRule}
    fresh_captures = _copy(x.fwd_oc.oc.captures)
    fresh_oc = replace_captures(x.fwd_oc, fresh_captures)
    return P(fresh_oc)
end
Expand Down Expand Up @@ -399,7 +398,6 @@ mutable struct LazyFRule{primal_sig,Trule}
end
end

# Construct a new lazy rule of the same concrete type from the existing `mi` and
# `debug_mode` fields.
function _copy(x::P) where {P<:LazyFRule}
    return P(x.mi, x.debug_mode)
end

@inline function (rule::LazyFRule)(args::Vararg{Any,N}) where {N}
Expand Down Expand Up @@ -440,7 +438,6 @@ end

# Convenience constructor: pair an empty `Dict{Any,Any}` with the given debug flag.
function DynamicFRule(debug_mode::Bool)
    return DynamicFRule(Dict{Any,Any}(), debug_mode)
end

# A copy starts from an empty cache and keeps the original's debug mode.
function _copy(x::P) where {P<:DynamicFRule}
    empty_cache = Dict{Any,Any}()
    return P(empty_cache, x.debug_mode)
end

function (dynamic_rule::DynamicFRule)(args::Vararg{Dual,N}) where {N}
Expand Down
17 changes: 14 additions & 3 deletions src/interpreter/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,6 @@ struct RRuleZeroWrapper{Trule}
rule::Trule
end

# Rebuild the same concrete RRuleZeroWrapper type around a `_copy` of the wrapped rule.
function _copy(x::P) where {P<:RRuleZeroWrapper}
    inner_rule = _copy(x.rule)
    return P(inner_rule)
end

struct RRuleWrapperPb{Tpb!!,Tl}
Expand Down Expand Up @@ -952,13 +951,27 @@ function verify_args(r::DerivedRule{sig}, x) where {sig}
throw(ArgumentError("Arguments with sig $Tx do not subtype rule signature, $sig"))
end

# `nothing` is a singleton; copying it yields `nothing`.
function _copy(::Nothing)
    return nothing
end

# Copy a derived rule. The captures of the forwards closure are `_copy`-ed once, and
# that single copy is installed in both the forwards closure and the pullback closure
# (the latter re-wrapped in a fresh `Ref`). `nargs` is carried over unchanged.
function _copy(x::P) where {P<:DerivedRule}
    shared_captures = _copy(x.fwds_oc.oc.captures)
    fwds_copy = replace_captures(x.fwds_oc, shared_captures)
    pb_copy = replace_captures(x.pb_oc_ref[], shared_captures)
    return P(fwds_copy, Ref(pb_copy), x.nargs)
end

# Symbols are returned as-is.
function _copy(x::Symbol)
    return x
end

# Tuples are copied element-wise, recursing through `_copy`.
function _copy(x::Tuple)
    return map(_copy, x)
end

# NamedTuples are copied value-wise, recursing through `_copy`; `map` keeps the names.
function _copy(x::NamedTuple)
    return map(_copy, x)
end

# A `Ref` yields a new `Ref{T}`: unassigned if the source is unassigned, otherwise
# holding a `_copy` of the referenced value.
function _copy(x::Ref{T}) where {T}
    isassigned(x) || return Ref{T}()
    return Ref{T}(_copy(x[]))
end

# Types are returned as-is.
function _copy(x::Type)
    return x
end

# Everything without a more specific method falls back to `Base.copy`.
function _copy(x)
    return copy(x)
end

@inline function (fwds::DerivedRule{sig})(args::Vararg{CoDual,N}) where {sig,N}
uf_args = __unflatten_codual_varargs(_isva(fwds), args, fwds.nargs)
pb = Pullback(sig, fwds.pb_oc_ref, _isva(fwds), N)
Expand Down Expand Up @@ -1730,7 +1743,6 @@ end

# Convenience constructor: pair an empty `Dict{Any,Any}` with the given debug flag.
function DynamicDerivedRule(debug_mode::Bool)
    return DynamicDerivedRule(Dict{Any,Any}(), debug_mode)
end

# A copy starts from an empty cache and keeps the original's debug mode.
function _copy(x::P) where {P<:DynamicDerivedRule}
    empty_cache = Dict{Any,Any}()
    return P(empty_cache, x.debug_mode)
end

function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any,N}) where {N}
Expand Down Expand Up @@ -1828,7 +1840,6 @@ mutable struct LazyDerivedRule{primal_sig,Trule}
end
end

# Construct a new lazy rule of the same concrete type from the existing `mi` and
# `debug_mode` fields.
function _copy(x::P) where {P<:LazyDerivedRule}
    return P(x.mi, x.debug_mode)
end

@inline function (rule::LazyDerivedRule)(args::Vararg{Any,N}) where {N}
Expand Down
1 change: 0 additions & 1 deletion src/rrules/misty_closures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ struct MistyClosureRData{Tr}
captures_rdata::Tr
end

# Wrap a `deepcopy` of the captures rdata in a new MistyClosureRData.
function _copy(r::MistyClosureRData)
    captures_copy = deepcopy(r.captures_rdata)
    return MistyClosureRData(captures_copy)
end

# Any `MistyClosureTangent` type maps to `MistyClosureFData`.
function fdata_type(::Type{<:MistyClosureTangent})
    return MistyClosureFData
end
Expand Down
1 change: 0 additions & 1 deletion src/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ mutable struct Stack{T}
Stack{T}() where {T} = new{T}(Vector{T}(undef, 0), 0)
end

# Copying a stack ignores its contents and returns a new `Stack{T}()` with the same
# element type.
function _copy(::Stack{T}) where {T}
    return Stack{T}()
end

@inline function Base.push!(x::Stack{T}, val::T) where {T}
Expand Down
1 change: 0 additions & 1 deletion src/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ struct PossiblyUninitTangent{T}
PossiblyUninitTangent{T}() where {T} = new{T}()
end

# An uninitialised tangent copies to a new uninitialised instance of the same type;
# an initialised one is rebuilt around a `_copy` of its `tangent` field.
function _copy(x::P) where {P<:PossiblyUninitTangent}
    is_init(x) || return P()
    return P(_copy(x.tangent))
end

# Outer constructor: take the type parameter from the argument's type.
@inline function PossiblyUninitTangent(tangent::T) where {T}
    return PossiblyUninitTangent{T}(tangent)
end
Expand Down
72 changes: 0 additions & 72 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,75 +418,3 @@ function _copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar)
end
return B
end

"""
    _copy(x)

*Note:* this is not part of the public Mooncake.jl interface, and may change without warning.

Internal protocol for creating copies of AD-related data structures. This function is used
throughout the automatic differentiation system to create appropriate copies of rules,
caches, and other internal data structures when building new AD contexts.

# Semantics

The `_copy` protocol defines how different types should be copied within the AD system:

- For immutable AD types (like `CoDual`, `Dual`), typically returns the same object
- For mutable containers, creates new instances with copied contents
- For composite types, recursively applies `_copy` to fields
- Falls back to `Base.copy` for general types

# Implementation Requirements

When implementing `_copy` for a new type, consider:
- Whether the type represents mutable state that needs actual copying
- Whether fields should be recursively copied or can be shared
- Performance implications of copying vs. sharing immutable data

# Examples

```julia
# For immutable AD types - no copying needed
_copy(x::CoDual) = x

# For `Stack` type - create new empty instance
_copy(::Stack{T}) where {T} = Stack{T}()

# For composite types - recursive copying
_copy(x::Tuple) = map(_copy, x)

# For rule types - create new instances with appropriately copied captures/caches
function _copy(x::DerivedRule)
    new_captures = _copy(x.fwds_oc.oc.captures)
    new_fwds_oc = replace_captures(x.fwds_oc, new_captures)
    new_pb_oc_ref = Ref(replace_captures(x.pb_oc_ref[], new_captures))
    return typeof(x)(new_fwds_oc, new_pb_oc_ref, x.nargs)
end

# For misty closure reverse data - deep copy the captures data
_copy(r::MistyClosureRData) = MistyClosureRData(deepcopy(r.captures_rdata))

# For tangent types - copy conditionally based on initialization state
_copy(x::PossiblyUninitTangent) = is_init(x) ? typeof(x)(_copy(x.tangent)) : typeof(x)()

# For forwards/reverse data types - recursively copy wrapped data
_copy(x::FData) = typeof(x)(_copy(x.data))
_copy(x::RData) = typeof(x)(_copy(x.data))
_copy(x::LazyZeroRData) = typeof(x)(_copy(x.data))

# Fallback to Base.copy
_copy(x) = copy(x)
```
"""
function _copy end
# NOTE: the declaration above must stay directly below the closing `\"\"\"` of the
# docstring. A blank line or a comment in between makes Julia treat the string as a
# plain expression, leaving `_copy` undocumented.

# Generic implementations that work with any type.
# `nothing`, symbols and types are returned as-is.
_copy(::Nothing) = nothing
_copy(x::Symbol) = x
_copy(x::Type) = x

# Tuples and NamedTuples are copied element-wise, recursing through `_copy`.
_copy(x::Tuple) = map(_copy, x)
_copy(x::NamedTuple) = map(_copy, x)

# A `Ref` yields a new `Ref{T}`: unassigned if the source is unassigned, otherwise
# holding a `_copy` of the referenced value.
_copy(x::Ref{T}) where {T} = isassigned(x) ? Ref{T}(_copy(x[])) : Ref{T}()

# Fallback to Base.copy for all other types
_copy(x) = copy(x)
Loading