diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index bd8ccf2869..9d4e266024 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -41,6 +41,7 @@ jobs: 'rrules/low_level_maths', 'rrules/memory', 'rrules/misc', + 'rrules/misty_closures', 'rrules/new', 'rrules/random', 'rrules/tasks', diff --git a/HISTORY.md b/HISTORY.md new file mode 100644 index 0000000000..59e82e1a11 --- /dev/null +++ b/HISTORY.md @@ -0,0 +1,11 @@ +# 0.4.143 + +## Public Interface +- Mooncake offers forward mode AD. +- Two new functions added to the public interface: `prepare_derivative_cache` and `value_and_derivative!!`. +- One new type added to the public interface: `Dual`. + +## Internals +- `get_interpreter` was previously a zero-arg function. Is now a unary function, called with a "mode" argument: `get_interpreter(ForwardMode)`, `get_interpreter(ReverseMode)`. +- `@zero_derivative` should now be preferred to `@zero_adjoint`. `@zero_adjoint` will be removed in 0.5. +- `@from_chainrules` should now be preferred to `@from_rrule`. `@from_rrule` will be removed in 0.5. 
diff --git a/Project.toml b/Project.toml index 35fab25768..de792a4e53 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.146" +version = "0.4.147" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -59,8 +59,9 @@ FunctionWrappers = "1.1.3" GPUArraysCore = "0.1, 0.2" Graphs = "1" InteractiveUtils = "1" -JET = "0.9, 0.10" +JET = "0.9" JuliaFormatter = "1.0, 2.1" +JuliaInterpreter = "0.9" LinearAlgebra = "1" LuxLib = "1" MistyClosures = "2" @@ -81,9 +82,10 @@ DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AllocCheck", "Aqua", "BenchmarkTools", "DiffTests", "JET", "JuliaFormatter", "Pkg", "StableRNGs", "Test"] +test = ["AllocCheck", "Aqua", "BenchmarkTools", "DiffTests", "JET", "JuliaFormatter", "JuliaInterpreter", "Pkg", "StableRNGs", "Test"] diff --git a/bench/run_benchmarks.jl b/bench/run_benchmarks.jl index 2f17b52583..01b492f14a 100644 --- a/bench/run_benchmarks.jl +++ b/bench/run_benchmarks.jl @@ -19,6 +19,7 @@ using AbstractGPs, Zygote using Mooncake: + Dual, CoDual, generate_hand_written_rrule!!_test_cases, generate_derived_rrule!!_test_cases, @@ -26,10 +27,13 @@ using Mooncake: _typeof, primal, tangent, + zero_dual, zero_codual using Mooncake.TestUtils: _deepcopy +to_benchmark(__frule!!::R, dx::Vararg{Dual,N}) where {R,N} = __frule!!(dx...) + function to_benchmark(__rrule!!::R, dx::Vararg{CoDual,N}) where {R,N} dx_f = Mooncake.tuple_map(x -> CoDual(primal(x), Mooncake.fdata(tangent(x))), dx) out, pb!! = __rrule!!(dx_f...) 
@@ -206,6 +210,20 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo evals=1, ) + # Benchmark AD via Mooncake. + @info "Mooncake (Forward)" + rule = Mooncake.build_frule(args...) + duals = map(x -> x isa CoDual ? Dual(x.x, x.dx) : zero_dual(x), args) + to_benchmark(rule, duals...) + include_other_frameworks && GC.gc(true) + suite["mooncake_fwd"] = Chairmarks.benchmark( + () -> (rule, duals), + identity, + a -> to_benchmark(a[1], a[2]...), + _ -> true; + evals=1, + ) + if include_other_frameworks if should_run_benchmark(Val(:zygote), args...) @info "Zygote" @@ -258,6 +276,7 @@ function combine_results(result, tag, _range, default_range) d = result[2] primal_time = minimum(d["primal"]).time mooncake_time = minimum(d["mooncake"]).time + mooncake_fwd_time = minimum(d["mooncake_fwd"]).time zygote_time = in("zygote", keys(d)) ? minimum(d["zygote"]).time : missing rd_time = in("rd", keys(d)) ? minimum(d["rd"]).time : missing ez_time = in("enzyme", keys(d)) ? minimum(d["enzyme"]).time : missing @@ -267,6 +286,8 @@ function combine_results(result, tag, _range, default_range) primal_time=primal_time, mooncake_time=mooncake_time, Mooncake=mooncake_time / primal_time, + mooncake_fwd_time=mooncake_fwd_time, + MooncakeFwd=mooncake_fwd_time / primal_time, zygote_time=zygote_time, Zygote=zygote_time / primal_time, rd_time=rd_time, @@ -348,7 +369,7 @@ end function create_inter_ad_benchmarks() results = benchmark_inter_framework_rules() - tools = [:Mooncake, :Zygote, :ReverseDiff, :Enzyme] + tools = [:Mooncake, :MooncakeFwd, :Zygote, :ReverseDiff, :Enzyme] df = DataFrame(results)[:, [:tag, :primal_time, tools...]] # Plot graph of results. 
diff --git a/ext/MooncakeSpecialFunctionsExt.jl b/ext/MooncakeSpecialFunctionsExt.jl index 65806ae833..0f28c373c8 100644 --- a/ext/MooncakeSpecialFunctionsExt.jl +++ b/ext/MooncakeSpecialFunctionsExt.jl @@ -3,44 +3,44 @@ module MooncakeSpecialFunctionsExt using SpecialFunctions, Mooncake using Base: IEEEFloat -import Mooncake: @from_rrule, DefaultCtx, @zero_adjoint +import Mooncake: DefaultCtx, @from_chainrules, @zero_derivative -@from_rrule DefaultCtx Tuple{typeof(airyai),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(airyaix),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(airyaiprime),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(airybi),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(airybiprime),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(besselj0),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(besselj1),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(bessely0),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(bessely1),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(dawson),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(digamma),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erf),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erf),IEEEFloat,IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erfc),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(logerfc),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erfcinv),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erfcx),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(logerfcx),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erfi),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erfinv),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(gamma),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(invdigamma),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(trigamma),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(polygamma),Integer,IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(beta),IEEEFloat,IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(logbeta),IEEEFloat,IEEEFloat} -@from_rrule DefaultCtx 
Tuple{typeof(logabsgamma),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(loggamma),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(expint),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(expintx),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(expinti),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(sinint),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(cosint),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(ellipk),IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(ellipe),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(airyai),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(airyaix),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(airyaiprime),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(airybi),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(airybiprime),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(besselj0),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(besselj1),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(bessely0),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(bessely1),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(dawson),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(digamma),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(erf),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(erf),IEEEFloat,IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(erfc),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(logerfc),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(erfcinv),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(erfcx),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(logerfcx),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(erfi),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(erfinv),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(gamma),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(invdigamma),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(trigamma),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(polygamma),Integer,IEEEFloat} 
+@from_chainrules DefaultCtx Tuple{typeof(beta),IEEEFloat,IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(logbeta),IEEEFloat,IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(logabsgamma),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(loggamma),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(expint),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(expintx),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(expinti),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(sinint),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(cosint),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(ellipk),IEEEFloat} +@from_chainrules DefaultCtx Tuple{typeof(ellipe),IEEEFloat} -@zero_adjoint DefaultCtx Tuple{typeof(logfactorial),Integer} +@zero_derivative DefaultCtx Tuple{typeof(logfactorial),Integer} end diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 586ae1ed8c..afed343bb1 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -19,6 +19,7 @@ using Base: twiceprecision using Base.Experimental: @opaque using Base.Iterators: product +using Base.Meta: isexpr using Core: Intrinsics, bitcast, @@ -29,6 +30,8 @@ using Core: GotoIfNot, PhiNode, PiNode, + PhiCNode, + UpsilonNode, SSAValue, Argument, OpaqueClosure, @@ -43,6 +46,13 @@ using DispatchDoctor: @stable, @unstable # Needs to be defined before various other things. function _foreigncall_ end +""" + frule!!(f::Dual, x::Dual...) + +Performs AD in forward mode, possibly modifying the inputs, and returns a `Dual`. +""" +function frule!! end + """ rrule!!(f::CoDual, x::CoDual...) @@ -72,10 +82,11 @@ function rrule!! end build_primitive_rrule(sig::Type{<:Tuple}) Construct an rrule for signature `sig`. For this function to be called in `build_rrule`, you -must also ensure that `is_primitive(context_type, sig)` is `true`. The callable returned by -this must obey the rrule interface, but there are no restrictions on the type of callable -itself. For example, you might return a callable `struct`. 
By default, this function returns -`rrule!!` so, most of the time, you should just implement a method of `rrule!!`. +must also ensure that `is_primitive(context_type, ReverseMode, sig)` is `true`. The callable +returned by this must obey the rrule interface, but there are no restrictions on the type of +callable itself. For example, you might return a callable `struct`. By default, this +function returns `rrule!!` so, most of the time, you should just implement a method of +`rrule!!`. # Extended Help @@ -95,6 +106,7 @@ build_primitive_rrule(::Type{<:Tuple}) = rrule!! @stable default_mode = "disable" default_union_limit = 2 begin include("utils.jl") include("tangents.jl") +include("dual.jl") include("fwds_rvs_data.jl") include("codual.jl") include("debug_mode.jl") @@ -110,7 +122,8 @@ include(joinpath("interpreter", "patch_for_319.jl")) include(joinpath("interpreter", "ir_utils.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) include(joinpath("interpreter", "zero_like_rdata.jl")) -include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) +include(joinpath("interpreter", "forward_mode.jl")) +include(joinpath("interpreter", "reverse_mode.jl")) end include("tools_for_rules.jl") @@ -129,6 +142,7 @@ include(joinpath("rrules", "lapack.jl")) include(joinpath("rrules", "linear_algebra.jl")) include(joinpath("rrules", "low_level_maths.jl")) include(joinpath("rrules", "misc.jl")) +include(joinpath("rrules", "misty_closures.jl")) include(joinpath("rrules", "new.jl")) include(joinpath("rrules", "random.jl")) include(joinpath("rrules", "tasks.jl")) @@ -146,12 +160,15 @@ include("developer_tools.jl") # Public, not exported include("public.jl") + end #! format: on -@public Config, value_and_pullback!!, prepare_pullback_cache +@public Config, value_and_pullback!!, prepare_pullback_cache, value_and_derivative!! 
+@public prepare_derivative_cache, Dual # Public, exported -export value_and_gradient!!, prepare_gradient_cache +export value_and_gradient!!, prepare_gradient_cache, value_and_derivative!! +export prepare_derivative_cache end diff --git a/src/codual.jl b/src/codual.jl index 994ac81d9a..adbc4d0881 100644 --- a/src/codual.jl +++ b/src/codual.jl @@ -17,6 +17,13 @@ tangent(x::CoDual) = x.dx Base.copy(x::CoDual) = CoDual(copy(primal(x)), copy(tangent(x))) _copy(x::P) where {P<:CoDual} = x +""" + extract(x::CoDual) + +Helper function. Returns the 2-tuple `x.x, x.dx`. +""" +extract(x::CoDual) = primal(x), tangent(x) + """ zero_codual(x) diff --git a/src/debug_mode.jl b/src/debug_mode.jl index 828a1c8bf4..338d0f9c4c 100644 --- a/src/debug_mode.jl +++ b/src/debug_mode.jl @@ -1,3 +1,5 @@ +# TODO: make it non-trivial. See https://github.com/chalk-lab/Mooncake.jl/issues/672 +DebugFRule(rule) = rule """ DebugPullback(pb, y, x) diff --git a/src/developer_tools.jl b/src/developer_tools.jl index e2ed013db0..8543b4a6ec 100644 --- a/src/developer_tools.jl +++ b/src/developer_tools.jl @@ -1,5 +1,5 @@ """ - primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode + primal_ir(interp::MooncakeInterpreter, sig::Type{<:Tuple})::IRCode !!! warning This is not part of the public interface of Mooncake. As such, it may change as @@ -11,16 +11,62 @@ Roughly equivalent to `Base.code_ircode_by_type(sig; interp)`. 
For example, if you wanted to get the IR associated to the call `map(sin, randn(10))`, you could do one of the following calls: ```jldoctest -julia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode +julia> using Mooncake: primal_ir, get_interpreter, ReverseMode + +julia> primal_ir(get_interpreter(ReverseMode), Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode true -julia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode +julia> primal_ir(get_interpreter(ReverseMode), typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode true ``` """ -function primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode +function primal_ir(interp::MooncakeInterpreter, sig::Type{<:Tuple})::IRCode return generate_ir(interp, sig).primal_ir end +""" + dual_ir( + sig::Type{<:Tuple}; + interp=get_interpreter(ForwardMode), debug_mode::Bool=false, do_inline::Bool=true, + ) + +!!! warning + This is not part of the public interface of Mooncake. As such, it may change as + part of a non-breaking release of the package. + + +Generate the `Core.Compiler.IRCode` used to perform forwards-mode AD. Take a look +at how `build_frule` makes use of `generate_dual_ir` to see exactly how this is used in +practice. + +For example, if you wanted to get the IR associated to forwards-mode AD for the call +`map(sin, randn(10))`, you could do either of the following: +```jldoctest +julia> Mooncake.dual_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode +true +julia> Mooncake.dual_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode +true +``` + +# Arguments +- `sig::Type{<:Tuple}`: the signature of the call to be differentiated. + +# Keyword Arguments +- `interp`: the interpreter to use to obtain the primal IR. +- `debug_mode::Bool`: whether the generated IR should make use of Mooncake's debug mode. 
+- `do_inline::Bool`: whether to apply an inlining pass prior to returning the ir generated + by this function. This is `true` by default, but setting it to `false` can sometimes be + helpful if you need to understand what function calls are generated in order to perform + AD, before lots of it gets inlined away. +""" +function dual_ir( + sig::Type{<:Tuple}; + interp=get_interpreter(ForwardMode), + debug_mode::Bool=false, + do_inline::Bool=true, +) + return generate_dual_ir(interp, sig; debug_mode, do_inline)[1] +end + """ fwd_ir( sig::Type{<:Tuple}; @@ -31,8 +77,9 @@ end This is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package. -Generate the `Core.Compiler.IRCode` used to construct the forwards-pass of AD. Take a look -at how `build_rrule` makes use of `generate_ir` to see exactly how this is used in practice. +Generate the `Core.Compiler.IRCode` used to construct the forwards-pass of reverse-mode AD. +Take a look at how `build_rrule` makes use of `generate_ir` to see exactly how this is used +in practice. For example, if you wanted to get the IR associated to the forwards pass for the call `map(sin, randn(10))`, you could do either of the following: @@ -56,7 +103,7 @@ true """ function fwd_ir( sig::Type{<:Tuple}; - interp=get_interpreter(), + interp=get_interpreter(ReverseMode), debug_mode::Bool=false, do_inline::Bool=true, )::IRCode @@ -73,8 +120,9 @@ end This is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package. -Generate the `Core.Compiler.IRCode` used to construct the reverse-pass of AD. Take a look -at how `build_rrule` makes use of `generate_ir` to see exactly how this is used in practice. +Generate the `Core.Compiler.IRCode` used to construct the reverse-pass of reverse-mode AD. +Take a look at how `build_rrule` makes use of `generate_ir` to see exactly how this is used +in practice. 
For example, if you wanted to get the IR associated to the reverse pass for the call `map(sin, randn(10))`, you could do either of the following: @@ -98,7 +146,7 @@ true """ function rvs_ir( sig::Type{<:Tuple}; - interp=get_interpreter(), + interp=get_interpreter(ReverseMode), debug_mode::Bool=false, do_inline::Bool=true, )::IRCode diff --git a/src/dual.jl b/src/dual.jl new file mode 100644 index 0000000000..400ae347ac --- /dev/null +++ b/src/dual.jl @@ -0,0 +1,56 @@ +""" + Dual(primal::P, tangent::T) + +Used to pair together a `primal` value and a `tangent` to it. In the context of forward mode +AD (aka computing Frechet derivatives), `primal` governs the point at which the derivative +is computed, and `tangent` the direction in which it is computed. + +Must satisfy `tangent_type(P) == T`. +""" +struct Dual{P,T} + primal::P + tangent::T +end + +primal(x::Dual) = x.primal +tangent(x::Dual) = x.tangent +Base.copy(x::Dual) = Dual(copy(primal(x)), copy(tangent(x))) +_copy(x::P) where {P<:Dual} = x + +""" + extract(x::Dual) + +Helper function. Returns the 2-tuple `x.primal, x.tangent`. +""" +extract(x::Dual) = primal(x), tangent(x) + +zero_dual(x) = Dual(x, zero_tangent(x)) +randn_dual(rng::AbstractRNG, x) = Dual(x, randn_tangent(rng, x)) + +function dual_type(::Type{P}) where {P} + P == DataType && return Dual + P isa Union && return Union{dual_type(P.a),dual_type(P.b)} + P <: UnionAll && return Dual # P is abstract, so we don't know its tangent type. + return isconcretetype(P) ? Dual{P,tangent_type(P)} : Dual +end + +function dual_type(p::Type{Type{P}}) where {P} + return @isdefined(P) ? Dual{Type{P},NoTangent} : Dual{_typeof(p),NoTangent} +end + +_primal(x) = x +_primal(x::Dual) = primal(x) + +""" + verify_dual_type(x::Dual) + +Check that the type of `tangent(x)` is the tangent type of the type of `primal(x)`. 
+""" +verify_dual_type(x::Dual) = tangent_type(typeof(primal(x))) == typeof(tangent(x)) + +@inline uninit_dual(x::P) where {P} = Dual(x, uninit_tangent(x)) + +# Always sharpen the first thing if it's a type so static dispatch remains possible. +function Dual(x::Type{P}, dx::NoTangent) where {P} + return Dual{@isdefined(P) ? Type{P} : typeof(x),NoTangent}(x, dx) +end diff --git a/src/interface.jl b/src/interface.jl index 0bdedb72d9..d2875148d6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -198,7 +198,7 @@ end __exclude_unsupported_output(y) __exclude_func_with_unsupported_output(fx) -Required for the robust design of [`value_and_pullback!!`](@ref), [`prepare_pullback_cache`](@ref). +Required for the robust design of [`value_and_pullback!!`](@ref), [`prepare_pullback_cache`](@ref). Ensures that `y` or returned value of `fx::Tuple{Tf, Targs...}` contains no aliasing, circular references, `Ptr`s or non differentiable datatypes. In the forward pass f(args...) output can only return a "Tree" like datastructure with leaf nodes as primitive types. Refer https://github.com/chalk-lab/Mooncake.jl/issues/517#issuecomment-2715202789 and related issue for details. @@ -434,7 +434,8 @@ Returns a cache used with [`value_and_pullback!!`](@ref). See that function for __exclude_func_with_unsupported_output(fx) # Construct rule and tangents. - rule = build_rrule(get_interpreter(), Tuple{map(_typeof, fx)...}; kwargs...) + interp = get_interpreter(ReverseMode) + rule = build_rrule(interp, Tuple{map(_typeof, fx)...}; kwargs...) tangents = map(zero_tangent, fx) # Run the rule forwards -- this should do a decent chunk of pre-allocation. @@ -554,3 +555,19 @@ function value_and_gradient!!(cache::Cache, f::F, x::Vararg{Any,N}) where {F,N} coduals = tuple_map(CoDual, (f, x...), tuple_map(set_to_zero!!, cache.tangents)) return __value_and_gradient!!(cache.rule, coduals...) end + +""" + prepare_derivative_cache(f, x...) + +Returns a cache used with [`value_and_derivative!!`](@ref). 
See that function for more info. +""" +@unstable prepare_derivative_cache(fx...; kwargs...) = build_frule(fx...; kwargs...) + +""" + value_and_derivative!!(rule::R, f::Dual, x::Vararg{Dual,N}) + +Returns a `Dual` containing the result of applying forward-mode AD to compute the (Frechet) +derivative of `primal(f)` at the primal values in `x` in the direction of the tangent values +in `f` and `x`. +""" +value_and_derivative!!(rule::R, fx::Vararg{Dual,N}) where {R,N} = rule(fx...) diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index fc70125c4a..4e3071c54f 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -24,7 +24,7 @@ MooncakeCache() = MooncakeCache(IdDict{Core.MethodInstance,Core.CodeInstance}()) # The method table used by `Mooncake.@mooncake_overlay`. Base.Experimental.@MethodTable mooncake_method_table -struct MooncakeInterpreter{C} <: CC.AbstractInterpreter +struct MooncakeInterpreter{C,M<:Mode} <: CC.AbstractInterpreter meta # additional information world::UInt inf_params::CC.InferenceParams @@ -33,7 +33,8 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter code_cache::MooncakeCache oc_cache::Dict{ClosureCacheKey,Any} function MooncakeInterpreter( - ::Type{C}; + ::Type{C}, + ::Type{M}; meta=nothing, world::UInt=Base.get_world_counter(), inf_params::CC.InferenceParams=CC.InferenceParams(), @@ -41,8 +42,8 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[], code_cache::MooncakeCache=MooncakeCache(), oc_cache::Dict{ClosureCacheKey,Any}=Dict{ClosureCacheKey,Any}(), - ) where {C} - ip = new{C}(meta, world, inf_params, opt_params, inf_cache, code_cache, oc_cache) + ) where {C,M<:Mode} + ip = new{C,M}(meta, world, inf_params, opt_params, inf_cache, code_cache, oc_cache) tts = Any[ Tuple{typeof(sum),Tuple{Int}}, Tuple{typeof(sum),Tuple{Int,Int}}, @@ -75,7 +76,7 @@ function 
_show_interp(io::IO, ::MIME"text/plain", ::MooncakeInterpreter) return print(io, "MooncakeInterpreter()") end -MooncakeInterpreter() = MooncakeInterpreter(DefaultCtx) +MooncakeInterpreter(M::Type{<:Mode}) = MooncakeInterpreter(DefaultCtx, M) context_type(::MooncakeInterpreter{C}) where {C} = C @@ -122,14 +123,14 @@ CC.getsplit_impl(info::NoInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx) CC.getresult_impl(info::NoInlineCallInfo, idx::Int) = CC.getresult(info.info, idx) function Core.Compiler.abstract_call_gf_by_type( - interp::MooncakeInterpreter{C}, + interp::MooncakeInterpreter{C,M}, @nospecialize(f), arginfo::CC.ArgInfo, si::CC.StmtInfo, @nospecialize(atype), sv::CC.AbsIntState, max_methods::Int, -) where {C} +) where {C,M} # invoke the default abstract call to get the default CC.CallMeta. cm = @invoke CC.abstract_call_gf_by_type( @@ -144,7 +145,7 @@ function Core.Compiler.abstract_call_gf_by_type( # Check to see whether the call in question is a Mooncake primitive. If it is, set its # call info such that in the `CC.inlining_policy` it is not inlined away. - callinfo = is_primitive(C, atype) ? NoInlineCallInfo(cm.info, atype) : cm.info + callinfo = is_primitive(C, M, atype) ? NoInlineCallInfo(cm.info, atype) : cm.info # Construct a CallMeta correctly depending on the version of Julia. @static if VERSION ≥ v"1.11-" @@ -196,23 +197,26 @@ else # 1.11 and up. end """ - const GLOBAL_INTERPRETER + const GLOBAL_INTERPRETERS -Globally cached interpreter. Should only be accessed via `get_interpreter`. +Cached interpreters. Should only be accessed via `get_interpreter`. """ -const GLOBAL_INTERPRETER = Ref(MooncakeInterpreter()) +const GLOBAL_INTERPRETERS = Dict( + ForwardMode => MooncakeInterpreter(DefaultCtx, ForwardMode), + ReverseMode => MooncakeInterpreter(DefaultCtx, ReverseMode), +) """ - get_interpreter() + get_interpreter(mode::Type{<:Mode}) Returns a `MooncakeInterpreter` appropriate for the current world age. 
Will use a cached interpreter if one already exists for the current world age, otherwise creates a new one. This should be prefered over constructing a `MooncakeInterpreter` directly. """ -function get_interpreter() - if GLOBAL_INTERPRETER[].world != Base.get_world_counter() - GLOBAL_INTERPRETER[] = MooncakeInterpreter() +function get_interpreter(mode::Type{<:Mode}) + if GLOBAL_INTERPRETERS[mode].world != Base.get_world_counter() + GLOBAL_INTERPRETERS[mode] = MooncakeInterpreter(DefaultCtx, mode) end - return GLOBAL_INTERPRETER[] + return GLOBAL_INTERPRETERS[mode] end diff --git a/src/interpreter/contexts.jl b/src/interpreter/contexts.jl index c150cc8973..3e839cbac0 100644 --- a/src/interpreter/contexts.jl +++ b/src/interpreter/contexts.jl @@ -17,13 +17,36 @@ performance, it should be a primitive in the DefaultCtx, but not the MinimalCtx. struct DefaultCtx end """ - is_primitive(::Type{Ctx}, sig) where {Ctx} + abstract type Mode end + +Subtypes of this signify which mode of AD is being considered. +""" +abstract type Mode end + +""" + struct ForwardMode end + +Used primarily as the second argument to [`is_primitive`](@ref) to determine whether a +function is a primitive in forwards-mode AD. +""" +struct ForwardMode <: Mode end + +""" + struct ReverseMode end + +Used primarily as the second argument to [`is_primitive`](@ref) to determine whether a +function is a primitive in reverse-mode AD. +""" +struct ReverseMode <: Mode end + +""" + is_primitive(::Type{Ctx}, ::Type{M}, sig) where {Ctx,M} Returns a `Bool` specifying whether the methods specified by `sig` are considered primitives -in the context of contexts of type `Ctx`. +in the context of contexts of type `Ctx` in mode `M`. ```julia -is_primitive(DefaultCtx, Tuple{typeof(sin), Float64}) +is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(sin), Float64}) ``` will return if calling `sin(5.0)` should be treated as primitive when the context is a `DefaultCtx`. 
@@ -32,24 +55,64 @@ Observe that this information means that whether or not something is a primitive particular context depends only on static information, not any run-time information that might live in a particular instance of `Ctx`. """ -is_primitive(::Type{MinimalCtx}, sig::Type{<:Tuple}) = false -is_primitive(::Type{DefaultCtx}, sig) = is_primitive(MinimalCtx, sig) +is_primitive(::Type{MinimalCtx}, ::Type{<:Mode}, sig::Type{<:Tuple}) = false +function is_primitive(::Type{DefaultCtx}, ::Type{M}, sig) where {M<:Mode} + return is_primitive(MinimalCtx, M, sig) +end """ - @is_primitive context_type signature + @is_primitive context_type [mode_type] signature -Creates a method of `is_primitive` which always returns `true` for the context_type and -`signature` provided. For example -```julia -@is_primitive MinimalCtx Tuple{typeof(foo), Float64} -``` -is equivalent to -```julia -is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true +Creates a method of [`is_primitive`](@ref) which always returns `true` for the +`context_type`, and `signature` provided. For example +```jldoctest +julia> using Mooncake: DefaultCtx, @is_primitive, is_primitive, ForwardMode, ReverseMode + +julia> foo(x::Float64) = 2x +foo (generic function with 1 method) + +julia> @is_primitive DefaultCtx Tuple{typeof(foo),Float64} + +julia> is_primitive(DefaultCtx, ForwardMode, Tuple{typeof(foo),Float64}) +true + +julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo),Float64}) +true ``` +Observe that this means that a rule is a primitive in all AD modes. + +You should implement more complicated methods of [`is_primitive`](@ref) in the usual way. + +Optionally, you can specify that a rule is only a primitive in a particular mode, eg. 
+```jldoctest +julia> using Mooncake: DefaultCtx, @is_primitive, is_primitive, ForwardMode, ReverseMode + +julia> bar(x::Float64) = 2x +bar (generic function with 1 method) -You should implemented more complicated method of `is_primitive` in the usual way. +julia> @is_primitive DefaultCtx ForwardMode Tuple{typeof(bar),Float64} + +julia> is_primitive(DefaultCtx, ForwardMode, Tuple{typeof(bar),Float64}) +true + +julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(bar),Float64}) +false +``` """ macro is_primitive(Tctx, sig) - return :(Mooncake.is_primitive(::Type{$(esc(Tctx))}, ::Type{<:$(esc(sig))}) = true) + return _is_primitive_expression(Tctx, :(Mooncake.Mode), sig) +end + +macro is_primitive(Tctx, Tmode, sig) + return _is_primitive_expression(Tctx, esc(Tmode), sig) +end + +function _is_primitive_expression(Tctx, Tmode, sig) + return quote + function Mooncake.is_primitive( + ::Type{$(esc(Tctx))}, ::Type{<:$(Tmode)}, ::Type{<:$(esc(sig))} + ) + return true + end + end end diff --git a/src/interpreter/forward_mode.jl b/src/interpreter/forward_mode.jl new file mode 100644 index 0000000000..37774aa588 --- /dev/null +++ b/src/interpreter/forward_mode.jl @@ -0,0 +1,452 @@ +function build_frule(args...; debug_mode=false, silence_debug_messages=true) + sig = _typeof(TestUtils.__get_primals(args)) + interp = get_interpreter(ForwardMode) + return build_frule(interp, sig; debug_mode, silence_debug_messages) +end + +struct DualRuleInfo + isva::Bool + nargs::Int + dual_ret_type::Type +end + +function build_frule( + interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true +) where {C} + @nospecialize sig_or_mi + + # To avoid segfaults, ensure that we bail out if the interpreter's world age is greater + # than the current world age. + if Base.get_world_counter() > interp.world + throw( + ArgumentError( + "World age associated to interp is behind current world age. 
Please " *
+                "create a new interpreter for the current world age.",
+            ),
+        )
+    end
+
+    # If we're compiling in debug mode, let the user know by default.
+    if !silence_debug_messages && debug_mode
+        @info "Compiling frule for $sig_or_mi in debug mode. Disable for best performance."
+    end
+
+    # If we have a hand-coded rule, just use that.
+    sig = _get_sig(sig_or_mi)
+    is_primitive(C, ForwardMode, sig) && return (debug_mode ? DebugFRule(frule!!) : frule!!)
+
+    # We don't have a hand-coded rule, so derive one.
+    lock(MOONCAKE_INFERENCE_LOCK)
+    try
+        # If we've already derived the OpaqueClosures and info, do not re-derive, just
+        # create a copy and pass in new shared data.
+        oc_cache_key = ClosureCacheKey(interp.world, (sig_or_mi, debug_mode, :forward))
+        if haskey(interp.oc_cache, oc_cache_key)
+            return interp.oc_cache[oc_cache_key]
+        else
+            # Derive forward-pass IR, and shove in a `MistyClosure`.
+            dual_ir, captures, info = generate_dual_ir(interp, sig_or_mi; debug_mode)
+            dual_oc = misty_closure(
+                info.dual_ret_type, dual_ir, captures...; do_compile=true
+            )
+            raw_rule = DerivedFRule{sig,typeof(dual_oc),info.isva,info.nargs}(dual_oc)
+            rule = debug_mode ? DebugFRule(raw_rule) : raw_rule
+            interp.oc_cache[oc_cache_key] = rule
+            return rule
+        end
+    catch e
+        rethrow(e)
+    finally
+        unlock(MOONCAKE_INFERENCE_LOCK)
+    end
+end
+
+struct DerivedFRule{primal_sig,Tfwd_oc,isva,nargs}
+    fwd_oc::Tfwd_oc
+end
+
+@inline function (fwd::DerivedFRule{P,sig,isva,nargs})(
+    args::Vararg{Dual,N}
+) where {P,sig,N,isva,nargs}
+    return fwd.fwd_oc(__unflatten_dual_varargs(isva, args, Val(nargs))...)
+end
+
+function _copy(x::P) where {P<:DerivedFRule}
+    return P(replace_captures(x.fwd_oc, _copy(x.fwd_oc.oc.captures)))
+end
+
+"""
+    __unflatten_dual_varargs(isva::Bool, args, ::Val{nargs}) where {nargs}
+
+If isva and nargs=2, then inputs `(Dual(5.0, 0.0), Dual(4.0, 0.0), Dual(3.0, 0.0))`
+are transformed into `(Dual(5.0, 0.0), Dual((4.0, 3.0), (0.0, 0.0)))`.
+""" +function __unflatten_dual_varargs(isva::Bool, args, ::Val{nargs}) where {nargs} + isva || return args + group_primal = map(primal, args[nargs:end]) + if tangent_type(_typeof(group_primal)) == NoTangent + grouped_args = zero_dual(group_primal) + else + grouped_args = Dual(group_primal, map(tangent, args[nargs:end])) + end + return (args[1:(nargs - 1)]..., grouped_args) +end + +struct DualInfo + primal_ir::IRCode + interp::MooncakeInterpreter + is_used::Vector{Bool} + debug_mode::Bool +end + +function generate_dual_ir( + interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true +) + # Reset id count. This ensures that the IDs generated are the same each time this + # function runs. + seed_id!() + + # Grab code associated to the primal. + primal_ir, _ = lookup_ir(interp, sig_or_mi) + nargs = length(primal_ir.argtypes) + + # Normalise the IR. + isva, spnames = is_vararg_and_sparam_names(sig_or_mi) + primal_ir = normalise!(primal_ir, spnames) + + # Keep a copy of the primal IR with the insertions + dual_ir = CC.copy(primal_ir) + + # Modify dual argument types: + # - add one for the captures in the first position, with placeholder type for now + # - convert the rest to dual types + for (a, P) in enumerate(primal_ir.argtypes) + dual_ir.argtypes[a] = dual_type(CC.widenconst(P)) + end + pushfirst!(dual_ir.argtypes, Any) + + # Data structure into which we can push any data which is to live in the captures field + # of the OpaqueClosure used to implement this rule. The index at which a piece of data + # lives in this data structure is equal to the index of the captures field of the + # OpaqueClosure in which it will live. To write code which retrieves items from the + # captures data structure, make use of `get_capture`. 
+ captures = Any[] + + is_used = characterised_used_ssas(stmt(primal_ir.stmts)) + info = DualInfo(primal_ir, interp, is_used, debug_mode) + for (n, inst) in enumerate(dual_ir.stmts) + ssa = SSAValue(n) + modify_fwd_ad_stmts!(stmt(inst), dual_ir, ssa, captures, info) + end + + # Process new nodes etc. + dual_ir = CC.compact!(dual_ir) + + CC.verify_ir(dual_ir) + + # Now that the captured values are known, replace the placeholder value given for the + # first argument type with the actual type. + captures_tuple = (captures...,) + dual_ir.argtypes[1] = _typeof(captures_tuple) + + # Optimize dual IR + dual_ir_opt = optimise_ir!(dual_ir; do_inline) + return dual_ir_opt, captures_tuple, DualRuleInfo(isva, nargs, dual_ret_type(primal_ir)) +end + +@inline get_capture(captures::T, n::Int) where {T} = captures[n] + +""" + const_dual!(captures::Vector{Any}, stmt)::Union{Dual,Int} + +Build a `Dual` from `stmt`, with zero / uninitialised tangent. If the resulting `Dual` is +a bits type, then it is returned. If it is not, then the `Dual` is put into captures, +and its location in `captures` returned. + +Whether or not the value is a literal, or an index into the captures, can be determined from +the return type. 
+""" +function const_dual!(captures::Vector{Any}, stmt)::Union{Dual,Int} + v = get_const_primal_value(stmt) + x = uninit_dual(v) + if safe_for_literal(v) + return x + else + push!(captures, x) + return length(captures) + end +end + +## Modification of IR nodes + +const ATTACH_AFTER = true +const ATTACH_BEFORE = false + +modify_fwd_ad_stmts!(::Nothing, ::IRCode, ::SSAValue, ::Vector{Any}, ::DualInfo) = nothing + +modify_fwd_ad_stmts!(::GotoNode, ::IRCode, ::SSAValue, ::Vector{Any}, ::DualInfo) = nothing + +function modify_fwd_ad_stmts!( + stmt::GotoIfNot, dual_ir::IRCode, ssa::SSAValue, captures::Vector{Any}, info::DualInfo +) + # replace GotoIfNot with the call to primal + Mooncake.replace_call!(dual_ir, ssa, Expr(:call, _primal, inc_args(stmt).cond)) + + # reinsert the GotoIfNot right after the call to primal + new_gotoifnot_inst = new_inst(Core.GotoIfNot(ssa, stmt.dest)) + CC.insert_node!(dual_ir, ssa, new_gotoifnot_inst, ATTACH_AFTER) + return nothing +end + +function modify_fwd_ad_stmts!( + stmt::GlobalRef, dual_ir::IRCode, ssa::SSAValue, captures::Vector{Any}, ::DualInfo +) + if isconst(stmt) + d = const_dual!(captures, stmt) + if d isa Int + Mooncake.replace_call!(dual_ir, ssa, Expr(:call, get_capture, Argument(1), d)) + else + Mooncake.replace_call!(dual_ir, ssa, Expr(:call, identity, d)) + end + else + new_ssa = CC.insert_node!(dual_ir, ssa, new_inst(stmt), ATTACH_BEFORE) + zero_dual_call = Expr(:call, Mooncake.zero_dual, new_ssa) + Mooncake.replace_call!(dual_ir, ssa, zero_dual_call) + end + + return nothing +end + +function modify_fwd_ad_stmts!( + stmt::ReturnNode, dual_ir::IRCode, ssa::SSAValue, captures::Vector{Any}, ::DualInfo +) + # undefined `val` field means that stmt is unreachable. + isdefined(stmt, :val) || return nothing + + # stmt is an Argument, then already a dual, and must just be incremented. 
+ if stmt.val isa Union{Argument,SSAValue} + Mooncake.replace_call!(dual_ir, ssa, ReturnNode(__inc(stmt.val))) + return nothing + end + + # stmt is a const, so we have to turn it into a dual. + dual_stmt = ReturnNode(const_dual!(captures, stmt.val)) + Mooncake.replace_call!(dual_ir, ssa, dual_stmt) + return nothing +end + +function modify_fwd_ad_stmts!( + stmt::PhiNode, dual_ir::IRCode, ssa::SSAValue, captures::Vector{Any}, ::DualInfo +) + for n in eachindex(stmt.values) + isassigned(stmt.values, n) || continue + stmt.values[n] isa Union{Argument,SSAValue} && continue + stmt.values[n] = uninit_dual(get_const_primal_value(stmt.values[n])) + end + set_stmt!(dual_ir, ssa, inc_args(stmt)) + set_ir!(dual_ir, ssa, :type, dual_type(CC.widenconst(get_ir(dual_ir, ssa, :type)))) + return nothing +end + +function modify_fwd_ad_stmts!( + stmt::PiNode, dual_ir::IRCode, ssa::SSAValue, ::Vector{Any}, ::DualInfo +) + if stmt.val isa Union{Argument,SSAValue} + v = __inc(stmt.val) + else + v = uninit_dual(get_const_primal_value(stmt.val)) + end + replace_call!(dual_ir, ssa, PiNode(v, dual_type(CC.widenconst(stmt.typ)))) + return nothing +end + +function modify_fwd_ad_stmts!( + stmt::UpsilonNode, dual_ir::IRCode, ssa::SSAValue, captures::Vector{Any}, ::DualInfo +) + if !(stmt.val isa Union{Argument,SSAValue}) + stmt = UpsilonNode(uninit_dual(get_const_primal_value(stmt.val))) + end + set_stmt!(dual_ir, ssa, inc_args(stmt)) + set_ir!(dual_ir, ssa, :type, dual_type(CC.widenconst(get_ir(dual_ir, ssa, :type)))) + return nothing +end + +function modify_fwd_ad_stmts!( + stmt::PhiCNode, dual_ir::IRCode, ssa::SSAValue, captures::Vector{Any}, ::DualInfo +) + for n in eachindex(stmt.values) + isassigned(stmt.values, n) || continue + stmt.values[n] isa Union{Argument,SSAValue} && continue + stmt.values[n] = uninit_dual(get_const_primal_value(stmt.values[n])) + end + set_stmt!(dual_ir, ssa, inc_args(stmt)) + set_ir!(dual_ir, ssa, :type, dual_type(CC.widenconst(get_ir(dual_ir, ssa, :type)))) + 
return nothing +end + +@static if isdefined(Core, :EnterNode) + function modify_fwd_ad_stmts!( + ::Core.EnterNode, ::IRCode, ::SSAValue, ::Vector{Any}, ::DualInfo + ) + return nothing + end +end + +## Modification of IR nodes - expressions + +__get_primal(x::Dual) = primal(x) + +function modify_fwd_ad_stmts!( + stmt::Expr, dual_ir::IRCode, ssa::SSAValue, captures::Vector{Any}, info::DualInfo +) + if isexpr(stmt, :invoke) || isexpr(stmt, :call) + raw_args = isexpr(stmt, :invoke) ? stmt.args[2:end] : stmt.args + sig_types = map(raw_args) do x + return CC.widenconst(get_forward_primal_type(info.primal_ir, x)) + end + sig = Tuple{sig_types...} + mi = isexpr(stmt, :invoke) ? stmt.args[1] : missing + args = map(__inc, raw_args) + + # Special case: if the result of a call to getfield is un-used, then leave the + # primal statment alone (just increment arguments as usual). This was causing + # performance problems in a couple of situations where the field being requested is + # not known at compile time. `getfield` cannot be dead-code eliminated, because it + # can throw an error if the requested field does not exist. Everything _other_ than + # the boundscheck is eliminated in LLVM codegen, so it's important that AD doesn't + # get in the way of this. + # + # This might need to be generalised to more things than just `getfield`, but at the + # time of writing this comment, it's unclear whether or not this is the case. + if !info.is_used[ssa.id] && get_const_primal_value(args[1]) == getfield + fwds = new_inst(Expr(:call, __fwds_pass_no_ad!, args...)) + replace_call!(dual_ir, ssa, fwds) + return nothing + end + + # Dual-ise arguments. + dual_args = map(args) do arg + arg isa Union{Argument,SSAValue} && return arg + return uninit_dual(get_const_primal_value(arg)) + end + + if is_primitive(context_type(info.interp), ForwardMode, sig) + replace_call!(dual_ir, ssa, Expr(:call, frule!!, dual_args...)) + else + dm = info.debug_mode + push!(captures, isexpr(stmt, :invoke) ? 
LazyFRule(mi, dm) : DynamicFRule(dm)) + get_rule = Expr(:call, get_capture, Argument(1), length(captures)) + rule_ssa = CC.insert_node!(dual_ir, ssa, new_inst(get_rule), ATTACH_BEFORE) + replace_call!(dual_ir, ssa, Expr(:call, rule_ssa, dual_args...)) + end + elseif isexpr(stmt, :boundscheck) + # Keep the boundscheck, but put it in a Dual. + inst = CC.NewInstruction(get_ir(info.primal_ir, ssa)) + bc_ssa = CC.insert_node!(dual_ir, ssa, inst, ATTACH_BEFORE) + replace_call!(dual_ir, ssa, Expr(:call, zero_dual, bc_ssa)) + elseif isexpr(stmt, :code_coverage_effect) + replace_call!(dual_ir, ssa, nothing) + elseif Meta.isexpr(stmt, :copyast) + new_copyast_inst = CC.NewInstruction(get_ir(info.primal_ir, ssa)) + new_copyast_ssa = CC.insert_node!(dual_ir, ssa, new_copyast_inst, ATTACH_BEFORE) + replace_call!(dual_ir, ssa, Expr(:call, zero_dual, new_copyast_ssa)) + elseif Meta.isexpr(stmt, :loopinfo) + # Leave this node alone. + elseif isexpr(stmt, :throw_undef_if_not) + # args[1] is a Symbol, args[2] is the condition which must be primalized + primal_cond = Expr(:call, _primal, inc_args(stmt).args[2]) + replace_call!(dual_ir, ssa, primal_cond) + new_undef_inst = new_inst(Expr(:throw_undef_if_not, stmt.args[1], ssa)) + CC.insert_node!(dual_ir, ssa, new_undef_inst, ATTACH_AFTER) + elseif isexpr(stmt, :enter) + # Leave this node alone + elseif isexpr(stmt, :leave) + # Leave this node alone + elseif isexpr(stmt, :pop_exception) + # Leave this node alone + else + msg = "Expressions of type `:$(stmt.head)` are not yet supported in forward mode" + throw(ArgumentError(msg)) + end + return nothing +end + +get_forward_primal_type(ir::CC.IRCode, a::Argument) = ir.argtypes[a.n] +get_forward_primal_type(ir::CC.IRCode, ssa::SSAValue) = get_ir(ir, ssa, :type) +get_forward_primal_type(::CC.IRCode, x::QuoteNode) = _typeof(x.value) +get_forward_primal_type(::CC.IRCode, x) = _typeof(x) +function get_forward_primal_type(::CC.IRCode, x::GlobalRef) + return isconst(x) ? 
_typeof(getglobal(x.mod, x.name)) : x.binding.ty +end +function get_forward_primal_type(::CC.IRCode, x::Expr) + x.head === :boundscheck && return Bool + return error("Unrecognised expression $x found in argument slot.") +end + +mutable struct LazyFRule{primal_sig,Trule} + debug_mode::Bool + mi::Core.MethodInstance + rule::Trule + function LazyFRule(mi::Core.MethodInstance, debug_mode::Bool) + interp = get_interpreter(ForwardMode) + return new{mi.specTypes,frule_type(interp, mi;debug_mode)}(debug_mode, mi) + end + function LazyFRule{Tprimal_sig,Trule}( + mi::Core.MethodInstance, debug_mode::Bool + ) where {Tprimal_sig,Trule} + return new{Tprimal_sig,Trule}(debug_mode, mi) + end +end + +_copy(x::P) where {P<:LazyFRule} = P(x.mi, x.debug_mode) + +@inline function (rule::LazyFRule)(args::Vararg{Any,N}) where {N} + return isdefined(rule, :rule) ? rule.rule(args...) : _build_rule!(rule, args) +end + +@noinline function _build_rule!(rule::LazyFRule{sig,Trule}, args) where {sig,Trule} + interp = get_interpreter(ForwardMode) + rule.rule = build_frule(interp, rule.mi; debug_mode=rule.debug_mode) + return rule.rule(args...) +end + +function dual_ret_type(primal_ir::IRCode) + return dual_type(Base.Experimental.compute_ir_rettype(primal_ir)) +end + +function frule_type( + interp::MooncakeInterpreter{C}, mi::CC.MethodInstance; debug_mode +) where {C} + primal_sig = _get_sig(mi) + if is_primitive(C, ForwardMode, primal_sig) + return debug_mode ? DebugFRule{typeof(frule!!)} : typeof(frule!!) + end + ir, _ = lookup_ir(interp, mi) + nargs = length(ir.argtypes) + isva, _ = is_vararg_and_sparam_names(mi) + arg_types = map(CC.widenconst, ir.argtypes) + dual_args_type = Tuple{map(dual_type, arg_types)...} + closure_type = RuleMC{dual_args_type,dual_ret_type(ir)} + Tderived_rule = DerivedFRule{primal_sig,closure_type,isva,nargs} + return debug_mode ? 
DebugFRule{Tderived_rule} : Tderived_rule +end + +struct DynamicFRule{V} + cache::V + debug_mode::Bool +end + +DynamicFRule(debug_mode::Bool) = DynamicFRule(Dict{Any,Any}(), debug_mode) + +_copy(x::P) where {P<:DynamicFRule} = P(Dict{Any,Any}(), x.debug_mode) + +function (dynamic_rule::DynamicFRule)(args::Vararg{Dual,N}) where {N} + sig = Tuple{map(_typeof ∘ primal, args)...} + rule = get(dynamic_rule.cache, sig, nothing) + if rule === nothing + interp = get_interpreter(ForwardMode) + rule = build_frule(interp, sig; debug_mode=dynamic_rule.debug_mode) + dynamic_rule.cache[sig] = rule + end + return rule(args...) +end diff --git a/src/interpreter/ir_normalisation.jl b/src/interpreter/ir_normalisation.jl index 8321998687..66794895c0 100644 --- a/src/interpreter/ir_normalisation.jl +++ b/src/interpreter/ir_normalisation.jl @@ -349,6 +349,7 @@ until the pullback that it returns is run. @is_primitive MinimalCtx Tuple{typeof(gc_preserve),Vararg{Any,N}} where {N} +frule!!(::Dual{typeof(gc_preserve)}, ::Vararg{Dual,N}) where {N} = zero_dual(nothing) function rrule!!(f::CoDual{typeof(gc_preserve)}, xs::CoDual...) pb = NoPullback(f, xs...) gc_preserve_pb!!(::NoRData) = GC.@preserve xs pb(NoRData()) diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index aebfc52fde..385c203707 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -1,10 +1,39 @@ +stmt_field_name() = @static VERSION < v"1.11" ? :inst : :stmt + """ stmt(ir::CC.InstructionStream) Get the field containing the instructions in `ir`. This changed name in 1.11 from `inst` to `stmt`. """ -stmt(ir::CC.InstructionStream) = @static VERSION < v"1.11.0-rc4" ? ir.inst : ir.stmt +stmt(ir::CC.InstructionStream) = CC.getfield(ir, stmt_field_name()) + +""" + stmt(x::CC.Instruction) + +Get the statement from `x`. This field changed name in 1.11 from `inst` to `stmt`. 
+"""
+stmt(x::CC.Instruction) = CC.getindex(x, stmt_field_name())
+
+set_stmt!(ir::IRCode, ssa::SSAValue, a) = set_ir!(ir, ssa, stmt_field_name(), a)
+
+get_ir(ir::IRCode, idx::SSAValue) = CC.getindex(ir, idx)
+get_ir(ir::IRCode, idx::SSAValue, name::Symbol) = CC.getindex(get_ir(ir, idx), name)
+
+"""
+Set field `name` of the instruction at position `idx` in `ir` to `value`.
+"""
+function set_ir!(ir::IRCode, idx::SSAValue, name::Symbol, value)
+    return CC.setindex!(CC.getindex(ir, idx), value, name)
+end
+
+function replace_call!(ir, idx::SSAValue, new_call)
+    set_ir!(ir, idx, :inst, new_call)
+    set_ir!(ir, idx, :type, Any)
+    set_ir!(ir, idx, :info, CC.NoCallInfo())
+    set_ir!(ir, idx, :flag, CC.IR_FLAG_REFINED)
+    return nothing
+end
 
 """
     ircode(
@@ -220,6 +249,12 @@ function lookup_ir(
     return CC.typeinf_ircode(interp, mi.def, mi.specTypes, mi.sparam_vals, optimize_until)
 end
 
+function lookup_ir(::CC.AbstractInterpreter, mc::MistyClosure; optimize_until=nothing)
+    return mc.ir[], return_type(mc.oc)
+end
+
+return_type(::Core.OpaqueClosure{A,B}) where {A,B} = B
+
 """
     is_unreachable_return_node(x::ReturnNode)
@@ -268,3 +303,31 @@ function replace_uses_with!(stmt, def::Union{Argument,SSAValue}, val)
         return stmt
     end
 end
+
+"""
+    characterised_used_ssas(stmts::Vector{Any})::Vector{Bool}
+
+For each statement in `stmts`, determine whether the SSAValue that it corresponds to has
+any uses in the other statements. In particular, if `SSAValue(n)` has any uses in `stmts`,
+the `n`th element of the returned `Vector{Bool}` will be `true`, and `false` otherwise.
+
+This function will usually be applied to the `stmts` field of a `CC.InstructionStream`.
+"""
+function characterised_used_ssas(stmts::Vector{Any})::Vector{Bool}
+    is_used = fill(false, length(stmts))
+    for stmt in stmts
+
+        # Manually written-out iteration to avoid Core.Compiler type piracy.
+ urs = CC.userefs(stmt) + v = CC.iterate(urs) + while v !== nothing + (use_ref, state) = v + use = CC.getindex(use_ref) + if use isa SSAValue + is_used[use.id] = true + end + v = CC.iterate(urs, state) + end + end + return is_used +end diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/reverse_mode.jl similarity index 97% rename from src/interpreter/s2s_reverse_mode_ad.jl rename to src/interpreter/reverse_mode.jl index 8e5d9386f8..af989bc248 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/reverse_mode.jl @@ -22,7 +22,7 @@ Puts `data` into `p`, and returns the `id` associated to it. This `id` should be be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this `id` is always `data`. """ -function add_data!(p::SharedDataPairs, data)::ID +function add_data!(p::SharedDataPairs, data::Any)::ID id = ID() push!(p.pairs, (id, data)) return id @@ -68,7 +68,7 @@ function shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair} return (p[1], new_inst(Expr(:call, get_shared_data_field, Argument(1), n))) end end - +# maybe manually inline this @inline get_shared_data_field(shared_data, n) = getfield(shared_data, n) """ @@ -203,7 +203,7 @@ end Equivalent to `add_data!(info.shared_data_pairs, data)`. """ -add_data!(info::ADInfo, data)::ID = add_data!(info.shared_data_pairs, data) +add_data!(info::ADInfo, @nospecialize(data))::ID = add_data!(info.shared_data_pairs, data) """ add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x) @@ -212,7 +212,7 @@ Returns `x` if it is a singleton, or the `ID` of the ssa which will contain it o forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR. """ -function add_data_if_not_singleton!(p::Union{ADInfo,SharedDataPairs}, x) +function add_data_if_not_singleton!(p::Union{ADInfo,SharedDataPairs}, @nospecialize(x)) return Base.issingletontype(_typeof(x)) ? 
x : add_data!(p, x) end @@ -231,7 +231,7 @@ Returns the static / inferred type associated to `x`. get_primal_type(info::ADInfo, x::Argument) = info.arg_types[x] get_primal_type(info::ADInfo, x::ID) = CC.widenconst(info.ssa_insts[x].type) get_primal_type(::ADInfo, x::QuoteNode) = _typeof(x.value) -get_primal_type(::ADInfo, x) = _typeof(x) +get_primal_type(::ADInfo, @nospecialize(x)) = _typeof(x) function get_primal_type(::ADInfo, x::GlobalRef) return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty end @@ -257,7 +257,7 @@ Create the `:new` statements which initialise the reverse-data `Ref`s. Interpola initial rdata directly into the statement, which is safe because it is always a bits type. """ function reverse_data_ref_stmts(info::ADInfo) - function make_ref_stmt(id, P) + function make_ref_stmt(id::ID, P::Type) ref_type = Base.RefValue{P<:Type ? NoRData : zero_like_rdata_type(P)} init_ref_val = P <: Type ? NoRData() : Mooncake.zero_like_rdata_from_type(P) return (id, new_inst(Expr(:new, ref_type, QuoteNode(init_ref_val)))) @@ -361,19 +361,30 @@ Used in `make_ad_stmts!`. """ inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...) inc_args(x::ReturnNode) = isdefined(x, :val) ? 
ReturnNode(__inc(x.val)) : x -inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest) -inc_args(x::IDGotoNode) = x -function inc_args(x::IDPhiNode) +inc_args(x::Union{GotoIfNot,IDGotoIfNot}) = typeof(x)(__inc(x.cond), x.dest) +inc_args(x::Union{GotoNode,IDGotoNode}) = x +function inc_args(x::T) where {T<:Union{IDPhiNode,PhiNode}} new_values = Vector{Any}(undef, length(x.values)) for n in eachindex(x.values) if isassigned(x.values, n) new_values[n] = __inc(x.values[n]) end end - return IDPhiNode(x.edges, new_values) + return T(x.edges, new_values) end inc_args(::Nothing) = nothing inc_args(x::GlobalRef) = x +inc_args(x::PiNode) = PiNode(__inc(x.val), x.typ) +function inc_args(x::PhiCNode) + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = __inc(x.values[n]) + end + end + return PhiCNode(new_values) +end +inc_args(x::UpsilonNode) = UpsilonNode(__inc(x.val)) __inc(x::Argument) = Argument(x.n + 1) __inc(x) = x @@ -503,7 +514,7 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo) P = get_primal_type(info, line) val_rdata_ref_id = get_rev_data_id(info, stmt.val) output_rdata_ref_id = get_rev_data_id(info, line) - fwds = PiNode(__inc(stmt.val), fcodual_type(CC.widenconst(stmt.typ))) + fwds = inc_args(PiNode(stmt.val, fcodual_type(CC.widenconst(stmt.typ)))) # Get the rdata from the output_rdata_ref, and set its new value to zero, and # increment the output ref. @@ -618,6 +629,7 @@ function get_const_primal_value(x::GlobalRef) return getglobal(x.mod, x.name) end get_const_primal_value(x::QuoteNode) = x.value +get_const_primal_value(x::Expr) = eval(x) get_const_primal_value(x) = x # Mooncake does not yet handle `PhiCNode`s. Throw an error if one is encountered. @@ -665,7 +677,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) # Construct signature, and determine how the rrule is to be computed. 
sig = Tuple{arg_types...} - raw_rule = if is_primitive(context_type(info.interp), sig) + raw_rule = if is_primitive(context_type(info.interp), ReverseMode, sig) rrule!! # intrinsic / builtin / thing we provably have rule for elseif is_invoke mi = stmt.args[1]::Core.MethodInstance @@ -916,7 +928,7 @@ function nvargs(pb::Pullback{sig}) where {sig} return Val{_isva(pb) ? _nargs(pb) - length(sig.parameters) + 1 : 0} end -@inline (pb::Pullback)(dy) = __flatten_varargs(_isva(pb), pb.pb_oc[].oc(dy), nvargs(pb)()) +@inline (pb::Pullback)(dy) = __flatten_varargs(_isva(pb), pb.pb_oc[](dy), nvargs(pb)()) struct DerivedRule{Tprimal,Tfwd_args,Tfwd_ret,Tpb_args,Tpb_ret,isva,Tnargs<:Val} fwds_oc::RuleMC{Tfwd_args,Tfwd_ret} @@ -963,7 +975,7 @@ _copy(x) = copy(x) @inline function (fwds::DerivedRule{sig})(args::Vararg{CoDual,N}) where {sig,N} uf_args = __unflatten_codual_varargs(_isva(fwds), args, fwds.nargs) pb = Pullback(sig, fwds.pb_oc_ref, _isva(fwds), N) - return fwds.fwds_oc.oc(uf_args...)::CoDual, pb + return fwds.fwds_oc(uf_args...)::CoDual, pb end """ @@ -1000,6 +1012,7 @@ end _get_sig(sig::Type) = sig _get_sig(mi::Core.MethodInstance) = mi.specTypes +_get_sig(mc::MistyClosure) = Tuple{map(CC.widenconst, mc.ir[].argtypes)...} function forwards_ret_type(primal_ir::IRCode) return fcodual_type(Base.Experimental.compute_ir_rettype(primal_ir)) @@ -1042,16 +1055,19 @@ Helper method: equivalent to extracting the signature from `args` and calling `build_rrule(sig; kwargs...)`. """ function build_rrule(args...; kwargs...) - interp = get_interpreter() + interp = get_interpreter(ReverseMode) return build_rrule(interp, _typeof(TestUtils.__get_primals(args)); kwargs...) end """ build_rrule(sig::Type{<:Tuple}; kwargs...) -Helper method: Equivalent to `build_rrule(Mooncake.get_interpreter(), sig; kwargs...)`. +Helper method: Equivalent to +`build_rrule(Mooncake.get_interpreter(ReverseMode), sig; kwargs...)`. """ -build_rrule(sig::Type{<:Tuple}; kwargs...) 
= build_rrule(get_interpreter(), sig; kwargs...) +function build_rrule(sig::Type{<:Tuple}; kwargs...) + return build_rrule(get_interpreter(ReverseMode), sig; kwargs...) +end const MOONCAKE_INFERENCE_LOCK = ReentrantLock() @@ -1096,7 +1112,7 @@ function build_rrule( # If we have a hand-coded rule, just use that. sig = _get_sig(sig_or_mi) - if is_primitive(C, sig) + if is_primitive(C, ReverseMode, sig) rule = build_primitive_rrule(sig) return (debug_mode ? DebugRRule(rule) : rule) end @@ -1106,7 +1122,7 @@ function build_rrule( try # If we've already derived the OpaqueClosures and info, do not re-derive, just # create a copy and pass in new shared data. - oc_cache_key = ClosureCacheKey(interp.world, (sig_or_mi, debug_mode)) + oc_cache_key = ClosureCacheKey(interp.world, (sig_or_mi, debug_mode, :reverse)) if haskey(interp.oc_cache, oc_cache_key) return _copy(interp.oc_cache[oc_cache_key]) else @@ -1733,7 +1749,8 @@ function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any,N}) where {N} sig = Tuple{map(_typeof ∘ primal, args)...} rule = get(dynamic_rule.cache, sig, nothing) if rule === nothing - rule = build_rrule(get_interpreter(), sig; debug_mode=dynamic_rule.debug_mode) + interp = get_interpreter(ReverseMode) + rule = build_rrule(interp, sig; debug_mode=dynamic_rule.debug_mode) dynamic_rule.cache[sig] = rule end return rule(args...) 
@@ -1806,7 +1823,7 @@ mutable struct LazyDerivedRule{primal_sig,Trule} mi::Core.MethodInstance rule::Trule function LazyDerivedRule(mi::Core.MethodInstance, debug_mode::Bool) - interp = get_interpreter() + interp = get_interpreter(ReverseMode) return new{mi.specTypes,rule_type(interp, mi;debug_mode)}(debug_mode, mi) end function LazyDerivedRule{Tprimal_sig,Trule}( @@ -1823,7 +1840,8 @@ _copy(x::P) where {P<:LazyDerivedRule} = P(x.mi, x.debug_mode) end @noinline function _build_rule!(rule::LazyDerivedRule{sig,Trule}, args) where {sig,Trule} - rule.rule = build_rrule(get_interpreter(), rule.mi; debug_mode=rule.debug_mode) + interp = get_interpreter(ReverseMode) + rule.rule = build_rrule(interp, rule.mi; debug_mode=rule.debug_mode) return rule.rule(args...) end @@ -1835,7 +1853,7 @@ important for performance in dynamic dispatch, and to ensure that recursion work properly. """ function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C} - if is_primitive(C, _get_sig(sig_or_mi)) + if is_primitive(C, ReverseMode, _get_sig(sig_or_mi)) rule = build_primitive_rrule(_get_sig(sig_or_mi)) return debug_mode ? 
DebugRRule{typeof(rule)} : typeof(rule) end diff --git a/src/rrules/array_legacy.jl b/src/rrules/array_legacy.jl index 39e1e81ec0..f1e27e431c 100644 --- a/src/rrules/array_legacy.jl +++ b/src/rrules/array_legacy.jl @@ -68,11 +68,16 @@ function _diff_internal(c::MaybeCache, p::P, q::P) where {V,N,P<:Array{V,N}} return _map_if_assigned!((p, q) -> _diff_internal(c, p, q), t, p, q) end -@zero_adjoint MinimalCtx Tuple{Type{<:Array{T,N}},typeof(undef),Vararg} where {T,N} -@zero_adjoint MinimalCtx Tuple{Type{<:Array{T,N}},typeof(undef),Tuple{}} where {T,N} -@zero_adjoint MinimalCtx Tuple{Type{<:Array{T,N}},typeof(undef),NTuple{N}} where {T,N} +@zero_derivative MinimalCtx Tuple{Type{<:Array{T,N}},typeof(undef),Vararg} where {T,N} +@zero_derivative MinimalCtx Tuple{Type{<:Array{T,N}},typeof(undef),Tuple{}} where {T,N} +@zero_derivative MinimalCtx Tuple{Type{<:Array{T,N}},typeof(undef),NTuple{N}} where {T,N} @is_primitive MinimalCtx Tuple{typeof(Base._deletebeg!),Vector,Integer} +function frule!!(::Dual{typeof(Base._deletebeg!)}, a::Dual{<:Vector}, d::Dual{<:Integer}) + Base._deletebeg!(primal(a), primal(d)) + Base._deletebeg!(tangent(a), primal(d)) + return zero_dual(nothing) +end function rrule!!( ::CoDual{typeof(Base._deletebeg!)}, _a::CoDual{<:Vector}, _delta::CoDual{<:Integer} ) @@ -95,6 +100,11 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(Base._deleteend!),Vector,Integer} +function frule!!(::Dual{typeof(Base._deleteend!)}, a::Dual{<:Vector}, d::Dual{<:Integer}) + Base._deleteend!(primal(a), primal(d)) + Base._deleteend!(tangent(a), primal(d)) + return zero_dual(nothing) +end function rrule!!( ::CoDual{typeof(Base._deleteend!)}, _a::CoDual{<:Vector}, _delta::CoDual{<:Integer} ) @@ -124,6 +134,16 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(Base._deleteat!),Vector,Integer,Integer} +function frule!!( + ::Dual{typeof(Base._deleteat!)}, + a::Dual{<:Vector}, + i::Dual{<:Integer}, + delta::Dual{<:Integer}, +) + Base._deleteat!(primal(a), 
primal(i), primal(delta)) + Base._deleteat!(tangent(a), primal(i), primal(delta)) + return zero_dual(nothing) +end function rrule!!( ::CoDual{typeof(Base._deleteat!)}, _a::CoDual{<:Vector}, @@ -152,6 +172,13 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(Base._growbeg!),Vector,Integer} +function frule!!( + ::Dual{typeof(Base._growbeg!)}, a::Dual{<:Vector{T}}, d::Dual{<:Integer} +) where {T} + Base._growbeg!(primal(a), primal(d)) + Base._growbeg!(tangent(a), primal(d)) + return zero_dual(nothing) +end function rrule!!( ::CoDual{typeof(Base._growbeg!)}, _a::CoDual{<:Vector{T}}, _delta::CoDual{<:Integer} ) where {T} @@ -169,6 +196,11 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(Base._growend!),Vector,Integer} +function frule!!(::Dual{typeof(Base._growend!)}, a::Dual{<:Vector}, d::Dual{<:Integer}) + Base._growend!(primal(a), primal(d)) + Base._growend!(tangent(a), primal(d)) + return zero_dual(nothing) +end function rrule!!( ::CoDual{typeof(Base._growend!)}, _a::CoDual{<:Vector}, _delta::CoDual{<:Integer} ) @@ -186,6 +218,13 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(Base._growat!),Vector,Integer,Integer} +function frule!!( + ::Dual{typeof(Base._growat!)}, a::Dual{<:Vector}, i::Dual{<:Integer}, d::Dual{<:Integer} +) + Base._growat!(primal(a), primal(i), primal(d)) + Base._growat!(tangent(a), primal(i), primal(d)) + return zero_dual(nothing) +end function rrule!!( ::CoDual{typeof(Base._growat!)}, _a::CoDual{<:Vector}, @@ -209,12 +248,30 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(sizehint!),Vector,Integer} +function frule!!(::Dual{typeof(sizehint!)}, x::Dual{<:Vector}, sz::Dual{<:Integer}) + sizehint!(primal(x), primal(sz)) + sizehint!(tangent(x), primal(sz)) + return x +end function rrule!!(f::CoDual{typeof(sizehint!)}, x::CoDual{<:Vector}, sz::CoDual{<:Integer}) sizehint!(primal(x), primal(sz)) sizehint!(tangent(x), primal(sz)) return x, NoPullback(f, x, sz) end +function frule!!( + 
::Dual{typeof(_foreigncall_)}, + ::Dual{Val{:jl_array_ptr}}, + ::Dual{Val{Ptr{T}}}, + ::Dual{Tuple{Val{Any}}}, + ::Dual, # nreq + ::Dual, # calling convention + a::Dual{<:Array{T},<:Array{V}}, +) where {T,V} + y = ccall(:jl_array_ptr, Ptr{T}, (Any,), primal(a)) + dy = ccall(:jl_array_ptr, Ptr{V}, (Any,), tangent(a)) + return Dual(y, dy) +end function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_array_ptr}}, @@ -234,6 +291,19 @@ end @is_primitive MinimalCtx Tuple{ typeof(unsafe_copyto!),Array{T},Any,Array{T},Any,Any } where {T} +function frule!!( + ::Dual{typeof(unsafe_copyto!)}, + dest::Dual{<:Array{T}}, + doffs::Dual, + src::Dual{<:Array{T}}, + soffs::Dual, + n::Dual, +) where {T} + _n = primal(n) + Base.unsafe_copyto!(primal(dest), primal(doffs), primal(src), primal(soffs), _n) + Base.unsafe_copyto!(tangent(dest), primal(doffs), tangent(src), primal(soffs), _n) + return dest +end function rrule!!( ::CoDual{typeof(unsafe_copyto!)}, dest::CoDual{<:Array{T}}, @@ -281,6 +351,17 @@ function rrule!!( return dest, unsafe_copyto_pb!! end +Base.@propagate_inbounds function frule!!( + ::Dual{typeof(Core.arrayref)}, + inbounds::Dual{Bool}, + x::Dual{<:Array}, + inds::Vararg{Dual{Int},N}, +) where {N} + _inds = tuple_map(primal, inds) + y = arrayref(primal(inbounds), primal(x), _inds...) + dy = arrayref(primal(inbounds), tangent(x), _inds...) + return Dual(y, dy) +end Base.@propagate_inbounds function rrule!!( ::CoDual{typeof(Core.arrayref)}, checkbounds::CoDual{Bool}, @@ -304,6 +385,18 @@ Base.@propagate_inbounds function rrule!!( return CoDual(_y, dy), arrayref_pullback!! end +function frule!!( + ::Dual{typeof(Core.arrayset)}, + inbounds::Dual{Bool}, + A::Dual{<:Array}, + v::Dual, + inds::Dual{Int}..., +) + _inds = tuple_map(primal, inds) + Core.arrayset(primal(inbounds), primal(A), primal(v), _inds...) + Core.arrayset(primal(inbounds), tangent(A), tangent(v), _inds...) 
+ return A +end function rrule!!( ::CoDual{typeof(Core.arrayset)}, inbounds::CoDual{Bool}, @@ -364,11 +457,15 @@ function isbits_arrayset_rrule( return A, isbits_arrayset_pullback!! end +function frule!!(::Dual{typeof(Core.arraysize)}, X, dim) + return zero_dual(Core.arraysize(primal(X), primal(dim))) +end function rrule!!(f::CoDual{typeof(Core.arraysize)}, X, dim) return zero_fcodual(Core.arraysize(primal(X), primal(dim))), NoPullback(f, X, dim) end @is_primitive MinimalCtx Tuple{typeof(copy),Array} +frule!!(::Dual{typeof(copy)}, a::Dual{<:Array}) = Dual(copy(primal(a)), copy(tangent(a))) function rrule!!(::CoDual{typeof(copy)}, a::CoDual{<:Array}) dx = tangent(a) dy = copy(dx) @@ -380,7 +477,21 @@ function rrule!!(::CoDual{typeof(copy)}, a::CoDual{<:Array}) return y, copy_pullback!! end +function _copy_dict_tangent(mt::MutableTangent) + t = mt.fields + new_fields = typeof(t)(( + copy(t.slots), copy(t.keys), copy(t.vals), tuple_fill(NoTangent(), Val(5))... + )) + return MutableTangent(new_fields) +end + @is_primitive MinimalCtx Tuple{typeof(fill!),Array{<:Union{UInt8,Int8}},Integer} +function frule!!( + ::Dual{typeof(fill!)}, a::Dual{<:Array{<:Union{UInt8,Int8}}}, x::Dual{<:Integer} +) + fill!(primal(a), primal(x)) + return a +end function rrule!!( ::CoDual{typeof(fill!)}, a::CoDual{T}, x::CoDual{<:Integer} ) where {V<:Union{UInt8,Int8},T<:Array{V}} diff --git a/src/rrules/avoiding_non_differentiable_code.jl b/src/rrules/avoiding_non_differentiable_code.jl index 78161550e3..11a31e0820 100644 --- a/src/rrules/avoiding_non_differentiable_code.jl +++ b/src/rrules/avoiding_non_differentiable_code.jl @@ -2,17 +2,20 @@ # because we drop the gradient, because the tangent type of integers is NoTangent. 
# https://github.com/JuliaLang/julia/blob/9f9e989f241fad1ae03c3920c20a93d8017a5b8f/base/pointer.jl#L282 @is_primitive MinimalCtx Tuple{typeof(Base.:(+)),Ptr,Integer} +function frule!!(::Dual{typeof(Base.:(+))}, x::Dual{<:Ptr}, y::Dual{<:Integer}) + return Dual(primal(x) + primal(y), tangent(x) + primal(y)) +end function rrule!!(f::CoDual{typeof(Base.:(+))}, x::CoDual{<:Ptr}, y::CoDual{<:Integer}) return CoDual(primal(x) + primal(y), tangent(x) + primal(y)), NoPullback(f, x, y) end -@zero_adjoint MinimalCtx Tuple{typeof(randn),AbstractRNG,Vararg} -@zero_adjoint MinimalCtx Tuple{typeof(string),Vararg} -@zero_adjoint MinimalCtx Tuple{Type{Symbol},Vararg} -@zero_adjoint MinimalCtx Tuple{Type{Float64},Any,RoundingMode} -@zero_adjoint MinimalCtx Tuple{Type{Float32},Any,RoundingMode} -@zero_adjoint MinimalCtx Tuple{Type{Float16},Any,RoundingMode} -@zero_adjoint MinimalCtx Tuple{typeof(==),Type,Type} +@zero_derivative MinimalCtx Tuple{typeof(randn),AbstractRNG,Vararg} +@zero_derivative MinimalCtx Tuple{typeof(string),Vararg} +@zero_derivative MinimalCtx Tuple{Type{Symbol},Vararg} +@zero_derivative MinimalCtx Tuple{Type{Float64},Any,RoundingMode} +@zero_derivative MinimalCtx Tuple{Type{Float32},Any,RoundingMode} +@zero_derivative MinimalCtx Tuple{Type{Float16},Any,RoundingMode} +@zero_derivative MinimalCtx Tuple{typeof(==),Type,Type} # logging, String related primitive rules using Base: getindex, getproperty @@ -21,40 +24,31 @@ using Mooncake: zero_fcodual, MinimalCtx, @is_primitive, NoPullback, CoDual using Base.CoreLogging: LogLevel, handle_message, invokelatest import Base.CoreLogging as CoreLogging -# Rule for accessing an Atomic{T} wrapped Integer with Base.getindex as deriving a rule results -# in encountering a Atomic->Int address bitcast followed by a llvm atomic load call -@zero_adjoint MinimalCtx Tuple{typeof(getindex),Atomic{I}} where {I<:Integer} +# Rule for accessing an Atomic{T} wrapped Integer with Base.getindex as deriving a rule +# results in encountering a 
Atomic->Int address bitcast followed by a llvm atomic load call. +@zero_derivative MinimalCtx Tuple{typeof(getindex),Atomic{I}} where {I<:Integer} # Some Base String related rrules : -@zero_adjoint MinimalCtx Tuple{typeof(print),Vararg} -@zero_adjoint MinimalCtx Tuple{typeof(println),Vararg} -@zero_adjoint MinimalCtx Tuple{typeof(show),Vararg} -@zero_adjoint MinimalCtx Tuple{typeof(normpath),String} - -# seperate kwargs, non-kwargs Base.sprint rules are required. Julia compilation only gives a common lowered IR for any Base.sprint calls. -# refer issue #558 and PR https://github.com/chalk-lab/Mooncake.jl/pull/659 for another sneaky appearance of this problem + fix. -@zero_adjoint MinimalCtx Tuple{typeof(sprint),Vararg} -@is_primitive MinimalCtx Tuple{typeof(Core.kwcall),<:NamedTuple,typeof(sprint),Vararg} -function rrule!!( - ::CoDual{typeof(Core.kwcall)}, - kwargs::CoDual{<:NamedTuple}, - ::CoDual{typeof(sprint)}, - args::Vararg{CoDual}, -) - primal_args = map(x -> x.x, args) - result = Core.kwcall(kwargs.x, sprint, primal_args...) - return zero_fcodual(result), - NoPullback(zero_fcodual(Core.kwcall), kwargs, zero_fcodual(sprint), args...) -end +@zero_derivative MinimalCtx Tuple{typeof(print),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(println),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(show),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(normpath),String} + +# seperate kwargs, non-kwargs Base.sprint rules are required. Julia compilation only gives a +# common lowered IR for any Base.sprint calls. Refer to issue #558 and PR +# https://github.com/chalk-lab/Mooncake.jl/pull/659 for another sneaky appearance of this +# problem + fix. +@zero_derivative MinimalCtx Tuple{typeof(sprint),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(Core.kwcall),NamedTuple,typeof(sprint),Vararg} # Base.CoreLogging @logmsg related primitives. 
-@zero_adjoint MinimalCtx Tuple{ +@zero_derivative MinimalCtx Tuple{ typeof(Base._replace_init),String,Tuple{Pair{String,String}},Int64 } -@zero_adjoint MinimalCtx Tuple{ +@zero_derivative MinimalCtx Tuple{ typeof(CoreLogging.current_logger_for_env),LogLevel,Symbol,Module } -@zero_adjoint MinimalCtx Tuple{ +@zero_derivative MinimalCtx Tuple{ typeof(Core._call_latest), typeof(Base.CoreLogging.shouldlog), Any, @@ -63,23 +57,8 @@ end Symbol, Symbol, } -@zero_adjoint MinimalCtx Tuple{ - typeof(Core._call_latest), - typeof(CoreLogging.handle_message), - Any, - Base.CoreLogging.LogLevel, - String, - Module, - Symbol, - Symbol, - String, - Int64, -} -# specialized case for Builtin primitive Core._call_latest rrule for CoreLogging.handle_message kwargs call. -@is_primitive MinimalCtx Tuple{ +@zero_derivative MinimalCtx Tuple{ typeof(Core._call_latest), - typeof(Core.kwcall), - <:NamedTuple, typeof(CoreLogging.handle_message), Any, Base.CoreLogging.LogLevel, @@ -90,49 +69,24 @@ end String, Int64, } -function rrule!!( - ::CoDual{typeof(Core._call_latest)}, - ::CoDual{typeof(Core.kwcall)}, - kwargs::CoDual{<:NamedTuple}, - ::CoDual{typeof(CoreLogging.handle_message)}, - logger::CoDual, - loglevel::CoDual{Base.CoreLogging.LogLevel}, - message::CoDual{String}, - _module::CoDual{Module}, - group::CoDual{Symbol}, - id::CoDual{Symbol}, - file::CoDual{String}, - line::CoDual{Int64}, +# specialized case for Builtin primitive Core._call_latest rrule for CoreLogging.handle_message kwargs call. 
+@zero_derivative( + MinimalCtx, + Tuple{ + typeof(Core._call_latest), + typeof(Core.kwcall), + NamedTuple, + typeof(CoreLogging.handle_message), + Any, + Base.CoreLogging.LogLevel, + String, + Module, + Symbol, + Symbol, + String, + Int64, + } ) - result = Core._call_latest( - Core.kwcall, - kwargs.x, - CoreLogging.handle_message, - logger.x, - loglevel.x, - message.x, - _module.x, - group.x, - id.x, - file.x, - line.x; - ) - return zero_fcodual(result), - NoPullback( - zero_fcodual(Core._call_latest), - zero_fcodual(Core.kwcall), - kwargs, - zero_fcodual(CoreLogging.handle_message), - logger, - loglevel, - message, - _module, - group, - id, - file, - line, - ) -end function generate_hand_written_rrule!!_test_cases( rng_ctor, ::Val{:avoiding_non_differentiable_code} diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 02b4669d59..aa63cad7c5 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -2,14 +2,6 @@ function blas_name(name::Symbol) return (BLAS.USE_BLAS64 ? Symbol(name, "64_") : name, Symbol(BLAS.libblastrampoline)) end -function wrap_ptr_as_view(ptr::Ptr{T}, N::Int, inc::Int) where {T} - return view(unsafe_wrap(Vector{T}, ptr, N * inc), 1:inc:(N * inc)) -end - -function wrap_ptr_as_view(ptr::Ptr{T}, buffer_nrows::Int, nrows::Int, ncols::Int) where {T} - return view(unsafe_wrap(Matrix{T}, ptr, (buffer_nrows, ncols)), 1:nrows, :) -end - function _trans(flag, mat) flag === 'T' && return transpose(mat) flag === 'C' && return adjoint(mat) @@ -21,11 +13,14 @@ function tri!(A, u::Char, d::Char) return u == 'L' ? tril!(A, d == 'U' ? -1 : 0) : triu!(A, d == 'U' ? 
1 : 0) end -const MatrixOrView{T} = Union{Matrix{T},SubArray{T,2,<:Array{T}}} -const VecOrView{T} = Union{Vector{T},SubArray{T,1,<:Array{T}}} const BlasRealFloat = Union{Float32,Float64} const BlasComplexFloat = Union{ComplexF32,ComplexF64} +_fields(x::Tangent) = x.fields +_fields(x::FData) = x.data + +const TangentOrFData = Union{Tangent,FData} + """ arrayify(x::CoDual{<:AbstractArray{<:BlasFloat}}) @@ -33,23 +28,23 @@ Return the primal field of `x`, and convert its fdata into an array of the same primal. This operation is not guaranteed to be possible for all array types, but seems to be possible for all array types of interest so far. """ -function arrayify(x::CoDual{A}) where {A<:AbstractArray{<:BlasFloat}} +function arrayify(x::Union{Dual{A},CoDual{A}}) where {A<:AbstractArray{<:BlasFloat}} return arrayify(primal(x), tangent(x)) # NOTE: for complex number, the tangent is a reinterpreted version of the primal end arrayify(x::Array{P}, dx::Array{P}) where {P<:BlasRealFloat} = (x, dx) function arrayify(x::Array{P}, dx::Array{<:Tangent}) where {P<:BlasComplexFloat} return x, reinterpret(P, dx) end -function arrayify(x::A, dx::FData) where {A<:SubArray{<:BlasRealFloat}} - _, _dx = arrayify(x.parent, dx.data.parent) +function arrayify(x::A, dx::TangentOrFData) where {A<:SubArray{<:BlasRealFloat}} + _, _dx = arrayify(x.parent, _fields(dx).parent) return x, A(_dx, x.indices, x.offset1, x.stride1) end -function arrayify(x::A, dx::FData) where {A<:Base.ReshapedArray{<:BlasRealFloat}} - _, _dx = arrayify(x.parent, dx.data.parent) +function arrayify(x::A, dx::TangentOrFData) where {A<:Base.ReshapedArray{<:BlasRealFloat}} + _, _dx = arrayify(x.parent, _fields(dx).parent) return x, A(_dx, x.dims, x.mi) end -function arrayify(x::Base.ReinterpretArray{T}, dx::FData) where {T<:BlasFloat} - _, _dx = arrayify(x.parent, dx.data.parent) +function arrayify(x::Base.ReinterpretArray{T}, dx::TangentOrFData) where {T<:BlasFloat} + _, _dx = arrayify(x.parent, _fields(dx).parent) return 
x, reinterpret(T, _dx) end @@ -69,16 +64,44 @@ end # Utility # -@zero_adjoint MinimalCtx Tuple{typeof(BLAS.get_num_threads)} -@zero_adjoint MinimalCtx Tuple{typeof(BLAS.lbt_get_num_threads)} -@zero_adjoint MinimalCtx Tuple{typeof(BLAS.set_num_threads),Union{Integer,Nothing}} -@zero_adjoint MinimalCtx Tuple{typeof(BLAS.lbt_set_num_threads),Any} +@zero_derivative MinimalCtx Tuple{typeof(BLAS.get_num_threads)} +@zero_derivative MinimalCtx Tuple{typeof(BLAS.lbt_get_num_threads)} +@zero_derivative MinimalCtx Tuple{typeof(BLAS.set_num_threads),Union{Integer,Nothing}} +@zero_derivative MinimalCtx Tuple{typeof(BLAS.lbt_set_num_threads),Any} # # LEVEL 1 # for (fname, elty) in ((:cblas_ddot, :Float64), (:cblas_sdot, :Float32)) + @eval @inline function frule!!( + ::Dual{typeof(_foreigncall_)}, + ::Dual{Val{$(blas_name(fname))}}, + ::Dual, # return type + ::Dual, # argument types + ::Dual, # nreq + ::Dual, # calling convention + _n::Dual{BLAS.BlasInt}, + _DX::Dual{Ptr{$elty}}, + _incx::Dual{BLAS.BlasInt}, + _DY::Dual{Ptr{$elty}}, + _incy::Dual{BLAS.BlasInt}, + args::Vararg{Any,N}, + ) where {N} + GC.@preserve args begin + # Load in values from pointers. 
+ n, incx, incy = map(primal, (_n, _incx, _incy)) + xinds = 1:incx:(incx * n) + yinds = 1:incy:(incy * n) + DX = view(unsafe_wrap(Vector{$elty}, primal(_DX), n * incx), xinds) + DY = view(unsafe_wrap(Vector{$elty}, primal(_DY), n * incy), yinds) + + _dDX = view(unsafe_wrap(Vector{$elty}, tangent(_DX), n * incx), xinds) + _dDY = view(unsafe_wrap(Vector{$elty}, tangent(_DY), n * incy), yinds) + + return Dual(dot(DX, DY), dot(DX, _dDY) + dot(_dDX, DY)) + end + end @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(fname))}}, @@ -126,6 +149,20 @@ end typeof(BLAS.nrm2),Int,X,Int } where {T<:BlasFloat,X<:Union{Ptr{T},AbstractArray{T}}}, ) +function frule!!( + ::Dual{typeof(BLAS.nrm2)}, + n::Dual{<:Integer}, + X_dX::Dual{<:Union{Ptr{T},AbstractArray{T}}}, + incx::Dual{<:Integer}, +) where {T<:BlasFloat} + X, dX = arrayify(X_dX) + y = BLAS.nrm2(primal(n), X, primal(incx)) + dy = zero(y) + @inbounds for i in 1:primal(incx):(primal(n) * primal(incx)) + dy = dy + real(X[i] * dX[i]') + real(X[i]' * dX[i]) + end + return Dual(y, dy / 2y) +end function rrule!!( ::CoDual{typeof(BLAS.nrm2)}, n::CoDual{<:Integer}, @@ -146,6 +183,18 @@ end MinimalCtx, Tuple{typeof(BLAS.nrm2),X} where {T<:BlasFloat,X<:Union{Ptr{T},AbstractArray{T}}}, ) +function frule!!( + ::Dual{typeof(BLAS.nrm2)}, + X_dX::Dual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}}, +) + X, dX = arrayify(X_dX) + y = BLAS.nrm2(X) + dy = zero(y) + @inbounds for i in eachindex(X) + dy = dy + real(X[i] * dX[i]') + real(X[i]' * dX[i]) + end + return Dual(y, dy / (2y)) +end function rrule!!( ::CoDual{typeof(BLAS.nrm2)}, X_dX::CoDual{<:Union{Ptr{T},AbstractArray{T}} where {T<:BlasFloat}}, @@ -159,54 +208,68 @@ function rrule!!( return CoDual(y, NoFData()), nrm2_pb!! 
end -for (fname, elty) in ((:dscal_, :Float64), (:sscal_, :Float32)) - @eval @inline function Mooncake.rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(fname))}}, - ::CoDual, # return type - ::CoDual, # argument types - ::CoDual, # nreq - ::CoDual, # calling convention - n::CoDual{Ptr{BLAS.BlasInt}}, - DA::CoDual{Ptr{$elty}}, - DX::CoDual{Ptr{$elty}}, - incx::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any,N}, - ) where {N} - GC.@preserve args begin +@is_primitive( + MinimalCtx, + Tuple{typeof(BLAS.scal!),Integer,P,AbstractArray{P},Integer} where {P<:BlasRealFloat} +) +function frule!!( + ::Dual{typeof(BLAS.scal!)}, + _n::Dual{<:Integer}, + a_da::Dual{P}, + X_dX::Dual{<:AbstractArray{P}}, + _incx::Dual{<:Integer}, +) where {P<:BlasRealFloat} - # Load in values from pointers, and turn pointers to memory buffers into Vectors. - _n = unsafe_load(primal(n)) - _incx = unsafe_load(primal(incx)) - _DA = unsafe_load(primal(DA)) - _DX = unsafe_wrap(Vector{$elty}, primal(DX), _n * _incx) - _DX_s = unsafe_wrap(Vector{$elty}, tangent(DX), _n * _incx) + # Extract params. + n = primal(_n) + incx = primal(_incx) + a, da = extract(a_da) + X, dX = arrayify(X_dX) - inds = 1:_incx:(_incx * _n) - DX_copy = _DX[inds] - BLAS.scal!(_n, _DA, _DX, _incx) + # Compute Frechet derivative. + BLAS.scal!(n, a, dX, incx) + BLAS.axpy!(n, da, X, incx, dX, incx) - dDA = tangent(DA) - dDX = tangent(DX) - end + # Perform primal computation. + BLAS.scal!(n, a, X, incx) + return X_dX +end +function rrule!!( + ::CoDual{typeof(BLAS.scal!)}, + _n::CoDual{<:Integer}, + a_da::CoDual{P}, + X_dX::CoDual{<:AbstractArray{P}}, + _incx::CoDual{<:Integer}, +) where {P<:BlasRealFloat} - function dscal_pullback!!(::NoRData) - GC.@preserve args begin + # Extract params. + n = primal(_n) + incx = primal(_incx) + a = primal(a_da) + X, dX = arrayify(X_dX) + + # Take a copy of previous state in order to recover it on the reverse pass. 
+ X_copy = copy(X) + dX_copy = copy(dX) - # Set primal to previous state. - _DX[inds] .= DX_copy + # Run primal computation. + BLAS.scal!(n, a, X, incx) - # Compute cotangent w.r.t. scaling. - unsafe_store!(dDA, BLAS.dot(_n, _DX, _incx, dDX, _incx) + unsafe_load(dDA)) + function scal_adjoint(::NoRData) - # Compute cotangent w.r.t. DX. - BLAS.scal!(_n, _DA, _DX_s, _incx) - end + # Set primal to previous state. + X .= X_copy - return tuple_fill(NoRData(), Val(10 + N)) - end - return zero_fcodual(Cvoid()), dscal_pullback!! + # Compute gradient w.r.t. scaling. + ∇a = BLAS.dot(n, X, incx, dX, incx) + + # Compute gradient w.r.t. DX. + BLAS.scal!(n, a, dX, incx) + BLAS.axpy!(n, one(P), dX, incx, dX_copy, incx) + + return NoRData(), NoRData(), ∇a, NoRData(), NoRData() end + return X_dX, scal_adjoint end # @@ -220,6 +283,40 @@ end } where {P<:BlasRealFloat}, ) +@inline function frule!!( + ::Dual{typeof(BLAS.gemv!)}, + tA::Dual{Char}, + alpha::Dual{P}, + A_dA::Dual{<:AbstractMatrix{P}}, + x_dx::Dual{<:AbstractVector{P}}, + beta::Dual{P}, + y_dy::Dual{<:AbstractVector{P}}, +) where {P<:BlasRealFloat} + A, dA = arrayify(A_dA) + x, dx = arrayify(x_dx) + y, dy = arrayify(y_dy) + α, dα = extract(alpha) + β, dβ = extract(beta) + + # Derivative computation. + BLAS.gemv!(primal(tA), dα, A, x, β, dy) + BLAS.gemv!(primal(tA), α, dA, x, one(P), dy) + BLAS.gemv!(primal(tA), α, A, dx, one(P), dy) + + # Strong zero is essential here, in case `y` has undefined element values. + if !iszero(dβ) + @inbounds for n in eachindex(y) + tmp = dβ * y[n] + dy[n] = ifelse(isnan(y[n]), dy[n], tmp + dy[n]) + end + end + + # Primal computation. 
+ BLAS.gemv!(primal(tA), α, A, x, β, y) + + return y_dy +end + @inline function rrule!!( ::CoDual{typeof(BLAS.gemv!)}, _tA::CoDual{Char}, @@ -276,6 +373,40 @@ end } where {T<:BlasRealFloat}, ) +function frule!!( + ::Dual{typeof(BLAS.symv!)}, + uplo::Dual{Char}, + alpha::Dual{T}, + A_dA::Dual{<:AbstractMatrix{T}}, + x_dx::Dual{<:AbstractVector{T}}, + beta::Dual{T}, + y_dy::Dual{<:AbstractVector{T}}, +) where {T<:BlasRealFloat} + # Extract primals. + ul = primal(uplo) + α = primal(alpha) + β, dβ = extract(beta) + A, dA = arrayify(A_dA) + x, dx = arrayify(x_dx) + y, dy = arrayify(y_dy) + + # Compute Frechet derivative. + BLAS.symv!(ul, tangent(alpha), A, x, β, dy) + BLAS.symv!(ul, α, dA, x, one(T), dy) + BLAS.symv!(ul, α, A, dx, one(T), dy) + if !iszero(dβ) + @inbounds for n in eachindex(y) + tmp = dβ * y[n] + dy[n] = ifelse(isnan(y[n]), dy[n], tmp + dy[n]) + end + end + + # Run primal computation. + BLAS.symv!(ul, α, A, x, β, y) + + return y_dy +end + function rrule!!( ::CoDual{typeof(BLAS.symv!)}, uplo::CoDual{Char}, @@ -354,6 +485,36 @@ end } where {T<:BlasRealFloat}, ) +function frule!!( + ::Dual{typeof(BLAS.trmv!)}, + _uplo::Dual{Char}, + _trans::Dual{Char}, + _diag::Dual{Char}, + A_dA::Dual{<:AbstractMatrix{T}}, + x_dx::Dual{<:AbstractVector{T}}, +) where {T<:BlasRealFloat} + # Extract primals. + uplo = primal(_uplo) + trans = primal(_trans) + diag = primal(_diag) + A, dA = arrayify(A_dA) + x, dx = arrayify(x_dx) + + # Frechet derivative computation. + BLAS.trmv!(uplo, trans, diag, A, dx) + tmp = copy(x) + BLAS.trmv!(uplo, trans, diag, dA, tmp) + dx .+= tmp + if diag === 'U' + dx .-= x + end + + # Primal computation. 
+ BLAS.trmv!(uplo, trans, diag, A, x) + + return x_dx +end + function rrule!!( ::CoDual{typeof(BLAS.trmv!)}, _uplo::CoDual{Char}, @@ -431,6 +592,46 @@ end } where {T<:BlasRealFloat}, ) +function frule!!( + ::Dual{typeof(BLAS.gemm!)}, + transA::Dual{Char}, + transB::Dual{Char}, + alpha::Dual{T}, + A_dA::Dual{<:AbstractMatrix{T}}, + B_dB::Dual{<:AbstractMatrix{T}}, + beta::Dual{T}, + C_dC::Dual{<:AbstractMatrix{T}}, +) where {T<:BlasRealFloat} + tA = primal(transA) + tB = primal(transB) + α, dα = extract(alpha) + β, dβ = extract(beta) + A, dA = arrayify(A_dA) + B, dB = arrayify(B_dB) + C, dC = arrayify(C_dC) + + # Tangent computation. + BLAS.gemm!(tA, tB, α, dA, B, β, dC) + BLAS.gemm!(tA, tB, α, A, dB, one(T), dC) + if !iszero(dα) + BLAS.gemm!(tA, tB, dα, A, B, one(T), dC) + end + if !iszero(dβ) + @inbounds for n in eachindex(C) + dC[n] = ifelse_nan(C[n], dC[n], dC[n] + dβ * C[n]) + end + end + + # Primal computation. + BLAS.gemm!(tA, tB, α, A, B, β, C) + + return C_dC +end + +function ifelse_nan(cond, left::P, right::P) where {P<:BlasRealFloat} + return isnan(cond) * left + !isnan(cond) * right +end + function rrule!!( ::CoDual{typeof(BLAS.gemm!)}, transA::CoDual{Char}, @@ -504,7 +705,42 @@ end AbstractMatrix{T}, } where {T<:BlasRealFloat}, ) +function frule!!( + ::Dual{typeof(BLAS.symm!)}, + side::Dual{Char}, + uplo::Dual{Char}, + alpha::Dual{T}, + A_dA::Dual{<:AbstractMatrix{T}}, + B_dB::Dual{<:AbstractMatrix{T}}, + beta::Dual{T}, + C_dC::Dual{<:AbstractMatrix{T}}, +) where {T<:BlasRealFloat} + + # Extract primals. + s = primal(side) + ul = primal(uplo) + α, dα = extract(alpha) + β, dβ = extract(beta) + A, dA = arrayify(A_dA) + B, dB = arrayify(B_dB) + C, dC = arrayify(C_dC) + + # Compute Frechet derivative. 
+ BLAS.symm!(s, ul, α, A, dB, β, dC) + BLAS.symm!(s, ul, α, dA, B, one(T), dC) + if !iszero(dα) + BLAS.symm!(s, ul, dα, A, B, one(T), dC) + end + if !iszero(dβ) + @inbounds for n in eachindex(C) + dC[n] = ifelse_nan(C[n], dC[n], dC[n] + dβ * C[n]) + end + end + # Run primal computation. + BLAS.symm!(s, ul, α, A, B, β, C) + return C_dC +end function rrule!!( ::CoDual{typeof(BLAS.symm!)}, side::CoDual{Char}, @@ -578,229 +814,278 @@ function rrule!!( return C_dC, symm!_adjoint end -for (syrk, elty) in ((:dsyrk_, :Float64), (:ssyrk_, :Float32)) - @eval function rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(syrk))}}, - ::CoDual{Val{Cvoid}}, - ::CoDual, # arg types - ::CoDual, # nreq - ::CoDual, # calling convention - uplo::CoDual{Ptr{UInt8}}, - trans::CoDual{Ptr{UInt8}}, - n::CoDual{Ptr{BLAS.BlasInt}}, - k::CoDual{Ptr{BLAS.BlasInt}}, - alpha::CoDual{Ptr{$elty}}, - A::CoDual{Ptr{$elty}}, - LDA::CoDual{Ptr{BLAS.BlasInt}}, - beta::CoDual{Ptr{$elty}}, - C::CoDual{Ptr{$elty}}, - LDC::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any,Nargs}, - ) where {Nargs} - GC.@preserve args begin - _uplo = Char(unsafe_load(primal(uplo))) - _t = Char(unsafe_load(primal(trans))) - _n = unsafe_load(primal(n)) - _k = unsafe_load(primal(k)) - _alpha = unsafe_load(primal(alpha)) - _A = primal(A) - _LDA = unsafe_load(primal(LDA)) - _beta = unsafe_load(primal(beta)) - _C = primal(C) - _LDC = unsafe_load(primal(LDC)) - - A_mat = wrap_ptr_as_view(primal(A), _LDA, (_t == 'N' ? (_n, _k) : (_k, _n))...) 
- C_mat = wrap_ptr_as_view(primal(C), _LDC, _n, _n) - C_copy = collect(C_mat) - - BLAS.syrk!(_uplo, _t, _alpha, A_mat, _beta, C_mat) - - dalpha = tangent(alpha) - dA = tangent(A) - dbeta = tangent(beta) - dC = tangent(C) - end +@is_primitive( + MinimalCtx, + Tuple{ + typeof(BLAS.syrk!),Char,Char,P,AbstractMatrix{P},P,AbstractMatrix{P} + } where {P<:BlasRealFloat} +) +function frule!!( + ::Dual{typeof(BLAS.syrk!)}, + _uplo::Dual{Char}, + _t::Dual{Char}, + α_dα::Dual{P}, + A_dA::Dual{<:AbstractMatrix{P}}, + β_dβ::Dual{P}, + C_dC::Dual{<:AbstractMatrix{P}}, +) where {P<:BlasRealFloat} - function syrk!_pullback!!(::NoRData) - GC.@preserve args begin - # Restore previous state. - C_mat .= C_copy - - # Convert pointers to views. - dA_mat = wrap_ptr_as_view(dA, _LDA, (_t == 'N' ? (_n, _k) : (_k, _n))...) - dC_mat = wrap_ptr_as_view(dC, _LDC, _n, _n) - - # Increment cotangents. - B = _uplo == 'U' ? triu(dC_mat) : tril(dC_mat) - unsafe_store!(dbeta, unsafe_load(dbeta) + sum(B .* C_mat)) - dalpha_inc = tr(B' * _trans(_t, A_mat) * _trans(_t, A_mat)') - unsafe_store!(dalpha, unsafe_load(dalpha) + dalpha_inc) - dA_mat .+= _alpha * (_t == 'N' ? (B + B') * A_mat : A_mat * (B + B')) - dC_mat .= - (_uplo == 'U' ? tril!(dC_mat, -1) : triu!(dC_mat, 1)) .+ _beta .* B - end + # Extract values from pairs. + uplo = primal(_uplo) + t = primal(_t) + α, dα = extract(α_dα) + A, dA = arrayify(A_dA) + β, dβ = extract(β_dβ) + C, dC = arrayify(C_dC) - return tuple_fill(NoRData(), Val(16 + Nargs)) - end - return zero_fcodual(Cvoid()), syrk!_pullback!! + # Compute Frechet derivative. + BLAS.syr2k!(uplo, t, α, A, dA, β, dC) + iszero(dα) || BLAS.syrk!(uplo, t, dα, A, one(P), dC) + if !iszero(dβ) + dC .+= dβ .* (uplo == 'U' ? triu(C) : tril(C)) end + + # Run primal computation. 
+ BLAS.syrk!(uplo, t, α, A, β, C) + + return C_dC end +function rrule!!( + ::CoDual{typeof(BLAS.syrk!)}, + _uplo::CoDual{Char}, + _t::CoDual{Char}, + α_dα::CoDual{P}, + A_dA::CoDual{<:AbstractMatrix{P}}, + β_dβ::CoDual{P}, + C_dC::CoDual{<:AbstractMatrix{P}}, +) where {P<:BlasRealFloat} -for (trmm, elty) in ((:dtrmm_, :Float64), (:strmm_, :Float32)) - @eval function rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(trmm))}}, - ::CoDual, - ::CoDual, # arg types - ::CoDual, # nreq - ::CoDual, # calling convention - _side::CoDual{Ptr{UInt8}}, - _uplo::CoDual{Ptr{UInt8}}, - _trans::CoDual{Ptr{UInt8}}, - _diag::CoDual{Ptr{UInt8}}, - _M::CoDual{Ptr{BLAS.BlasInt}}, - _N::CoDual{Ptr{BLAS.BlasInt}}, - _alpha::CoDual{Ptr{$elty}}, - _A::CoDual{Ptr{$elty}}, - _lda::CoDual{Ptr{BLAS.BlasInt}}, - _B::CoDual{Ptr{$elty}}, - _ldb::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any,Nargs}, - ) where {Nargs} - GC.@preserve args begin + # Extract values from pairs. + uplo = primal(_uplo) + trans = primal(_t) + α = primal(α_dα) + A, dA = arrayify(A_dA) + β = primal(β_dβ) + C, dC = arrayify(C_dC) - # Load in data and store B for the reverse-pass. - side, ul, tA, diag = map( - Char ∘ unsafe_load ∘ primal, (_side, _uplo, _trans, _diag) - ) - M, N, lda, ldb = map(unsafe_load ∘ primal, (_M, _N, _lda, _ldb)) - alpha = unsafe_load(primal(_alpha)) - R = side == 'L' ? M : N - A = wrap_ptr_as_view(primal(_A), lda, R, R) - B = wrap_ptr_as_view(primal(_B), ldb, M, N) - B_copy = copy(B) + # Run forwards pass, and remember previous value of `C` for the reverse-pass. + C_copy = collect(C) + BLAS.syrk!(uplo, trans, α, A, β, C) - # Run primal. - BLAS.trmm!(side, ul, tA, diag, alpha, A, B) + function syrk_adjoint(::NoRData) + # Restore previous state. + C .= C_copy - dalpha = tangent(_alpha) - _dA = tangent(_A) - _dB = tangent(_B) - end + # C_copy no longer required, so its memory can be used to store other intermediate + # results. Renaming for clarity. 
+ tmp = C_copy - function trmm!_pullback!!(::NoRData) - GC.@preserve args begin - # Convert pointers to views. - dA = wrap_ptr_as_view(_dA, lda, R, R) - dB = wrap_ptr_as_view(_dB, ldb, M, N) + # Increment gradients. + B = uplo == 'U' ? triu(dC) : tril(dC) + ∇β = sum(B .* C) + ∇α = tr(B' * _trans(trans, A) * _trans(trans, A)') + # @show _t, size(A), size(B) + dA .+= α * (trans == 'N' ? (B + B') * A : A * (B + B')) + dC .= (uplo == 'U' ? tril!(dC, -1) : triu!(dC, 1)) .+ β .* B - # Increment alpha tangent. - alpha != 0 && unsafe_store!(dalpha, unsafe_load(dalpha) + tr(dB'B) / alpha) + return NoRData(), NoRData(), NoRData(), ∇α, NoRData(), ∇β, NoRData() + end - # Restore initial state. - B .= B_copy + return C_dC, syrk_adjoint +end - # Increment cotangents. - if side == 'L' - dA .+= alpha .* tri!(tA == 'N' ? dB * B' : B * dB', ul, diag) - else - dA .+= alpha .* tri!(tA == 'N' ? B'dB : dB'B, ul, diag) - end +@is_primitive( + MinimalCtx, + Tuple{ + typeof(BLAS.trmm!),Char,Char,Char,Char,P,AbstractMatrix{P},AbstractMatrix{P} + } where {P<:BlasRealFloat} +) +function frule!!( + ::Dual{typeof(BLAS.trmm!)}, + _side::Dual{Char}, + _uplo::Dual{Char}, + _ta::Dual{Char}, + _diag::Dual{Char}, + α_dα::Dual{P}, + A_dA::Dual{<:AbstractMatrix{P}}, + B_dB::Dual{<:AbstractMatrix{P}}, +) where {P<:BlasRealFloat} - # Compute dB tangent. - BLAS.trmm!(side, ul, tA == 'N' ? 'T' : 'N', diag, alpha, A, dB) - end + # Extract data. + side = primal(_side) + uplo = primal(_uplo) + ta = primal(_ta) + diag = primal(_diag) + α, dα = extract(α_dα) + A, dA = arrayify(A_dA) + B, dB = arrayify(B_dB) + + # Compute Frechet derivative. + BLAS.trmm!(side, uplo, ta, diag, α, A, dB) + dB .+= BLAS.trmm!(side, uplo, ta, diag, α, dA, copy(B)) + if diag == 'U' + dB .-= α .* B + end + if !iszero(dα) + dB .+= BLAS.trmm!(side, uplo, ta, diag, dα, A, copy(B)) + end + + # Compute primal. 
+ BLAS.trmm!(side, uplo, ta, diag, α, A, B) + return B_dB +end +function rrule!!( + ::CoDual{typeof(BLAS.trmm!)}, + _side::CoDual{Char}, + _uplo::CoDual{Char}, + _ta::CoDual{Char}, + _diag::CoDual{Char}, + α_dα::CoDual{P}, + A_dA::CoDual{<:AbstractMatrix{P}}, + B_dB::CoDual{<:AbstractMatrix{P}}, +) where {P<:BlasRealFloat} + + # Extract values. + side = primal(_side) + uplo = primal(_uplo) + tA = primal(_ta) + diag = primal(_diag) + α = primal(α_dα) + A, dA = arrayify(A_dA) + B, dB = arrayify(B_dB) + B_copy = copy(B) + + # Run primal. + BLAS.trmm!(side, uplo, tA, diag, α, A, B) - return tuple_fill(NoRData(), Val(17 + Nargs)) + function trmm_adjoint(::NoRData) + + # Compute α gradient. + ∇α = tr(dB'B) / α + + # Restore initial state. + B .= B_copy + + # Increment gradients. + if side == 'L' + dA .+= α .* tri!(tA == 'N' ? dB * B' : B * dB', uplo, diag) + else + dA .+= α .* tri!(tA == 'N' ? B'dB : dB'B, uplo, diag) end - return zero_fcodual(Cvoid()), trmm!_pullback!! + # Compute dB tangent. + BLAS.trmm!(side, uplo, tA == 'N' ? 
'T' : 'N', diag, α, A, dB) + + return tuple_fill(NoRData(), Val(5))..., ∇α, NoRData(), NoRData() end + + return B_dB, trmm_adjoint end -for (trsm, elty) in ((:dtrsm_, :Float64), (:strsm_, :Float32)) - @eval function rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(trsm))}}, - ::CoDual, - ::CoDual, # arg types - ::CoDual, # nreq - ::CoDual, # calling convention - _side::CoDual{Ptr{UInt8}}, - _uplo::CoDual{Ptr{UInt8}}, - _trans::CoDual{Ptr{UInt8}}, - _diag::CoDual{Ptr{UInt8}}, - _M::CoDual{Ptr{BLAS.BlasInt}}, - _N::CoDual{Ptr{BLAS.BlasInt}}, - _alpha::CoDual{Ptr{$elty}}, - _A::CoDual{Ptr{$elty}}, - _lda::CoDual{Ptr{BLAS.BlasInt}}, - _B::CoDual{Ptr{$elty}}, - _ldb::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any,Nargs}, - ) where {Nargs} - GC.@preserve args begin - side = Char(unsafe_load(primal(_side))) - uplo = Char(unsafe_load(primal(_uplo))) - trans = Char(unsafe_load(primal(_trans))) - diag = Char(unsafe_load(primal(_diag))) - M = unsafe_load(primal(_M)) - N = unsafe_load(primal(_N)) - R = side == 'L' ? M : N - alpha = unsafe_load(primal(_alpha)) - lda = unsafe_load(primal(_lda)) - ldb = unsafe_load(primal(_ldb)) - A = wrap_ptr_as_view(primal(_A), lda, R, R) - B = wrap_ptr_as_view(primal(_B), ldb, M, N) - B_copy = copy(B) - - trsm!(side, uplo, trans, diag, alpha, A, B) - - dalpha = tangent(_alpha) - _dA = tangent(_A) - _dB = tangent(_B) - end +@is_primitive( + MinimalCtx, + Tuple{ + typeof(BLAS.trsm!),Char,Char,Char,Char,P,AbstractMatrix{P},AbstractMatrix{P} + } where {P<:BlasRealFloat}, +) - function trsm_pb!!(::NoRData) - GC.@preserve args begin - # Convert pointers to views. - dA = wrap_ptr_as_view(_dA, lda, R, R) - dB = wrap_ptr_as_view(_dB, ldb, M, N) - - # Increment alpha tangent. - alpha != 0 && unsafe_store!(dalpha, unsafe_load(dalpha) + tr(dB'B) / alpha) - - # Increment cotangents. 
- if side == 'L' - if trans == 'N' - tmp = trsm!('L', uplo, 'T', diag, -one($elty), A, dB * B') - dA .+= tri!(tmp, uplo, diag) - else - tmp = trsm!('R', uplo, 'T', diag, -one($elty), A, B * dB') - dA .+= tri!(tmp, uplo, diag) - end - else - if trans == 'N' - tmp = trsm!('R', uplo, 'T', diag, -one($elty), A, B'dB) - dA .+= tri!(tmp, uplo, diag) - else - tmp = trsm!('L', uplo, 'T', diag, -one($elty), A, dB'B) - dA .+= tri!(tmp, uplo, diag) - end - end - - # Restore initial state. - B .= B_copy - - # Compute dB tangent. - BLAS.trsm!(side, uplo, trans == 'N' ? 'T' : 'N', diag, alpha, A, dB) - end +function frule!!( + ::Dual{typeof(BLAS.trsm!)}, + _side::Dual{Char}, + _uplo::Dual{Char}, + _t::Dual{Char}, + _diag::Dual{Char}, + α_dα::Dual{P}, + A_dA::Dual{<:AbstractMatrix{P}}, + B_dB::Dual{<:AbstractMatrix{P}}, +) where {P<:BlasRealFloat} + + # Extract parameters. + side = primal(_side) + uplo = primal(_uplo) + trans = primal(_t) + diag = primal(_diag) + α, dα = extract(α_dα) + A, dA = arrayify(A_dA) + B, dB = arrayify(B_dB) + + # Compute Frechet derivative. + BLAS.trsm!(side, uplo, trans, diag, α, A, dB) + tmp = copy(B) + trsm!(side, uplo, trans, diag, one(P), A, tmp) # tmp now contains inv(A) B. + dB .+= dα .* tmp + + tmp2 = copy(tmp) + BLAS.trmm!(side, uplo, trans, diag, α, dA, tmp) # tmp now contains α dA inv(A) B. + if diag == 'U' + tmp .-= α .* tmp2 + end + BLAS.trsm!(side, uplo, trans, diag, one(P), A, tmp) # tmp is now α inv(A) dA inv(A) B. + dB .-= tmp + + # Run primal computation. + BLAS.trsm!(side, uplo, trans, diag, α, A, B) + return B_dB +end + +function rrule!!( + ::CoDual{typeof(BLAS.trsm!)}, + _side::CoDual{Char}, + _uplo::CoDual{Char}, + _t::CoDual{Char}, + _diag::CoDual{Char}, + α_dα::CoDual{P}, + A_dA::CoDual{<:AbstractMatrix{P}}, + B_dB::CoDual{<:AbstractMatrix{P}}, +) where {P<:BlasRealFloat} + + # Extract parameters. 
+ side = primal(_side) + uplo = primal(_uplo) + trans = primal(_t) + diag = primal(_diag) + α = primal(α_dα) + A, dA = arrayify(A_dA) + B, dB = arrayify(B_dB) + + # Copy memory which will be overwritten by primal computation. + B_copy = copy(B) + + # Run primal computation. + trsm!(side, uplo, trans, diag, α, A, B) + + function trsm_adjoint(::NoRData) + # Compute α gradient. + ∇α = tr(dB'B) / α - return tuple_fill(NoRData(), Val(17 + Nargs)) + # Increment cotangents. + if side == 'L' + if trans == 'N' + tmp = trsm!('L', uplo, 'T', diag, -one(P), A, dB * B') + dA .+= tri!(tmp, uplo, diag) + else + tmp = trsm!('R', uplo, 'T', diag, -one(P), A, B * dB') + dA .+= tri!(tmp, uplo, diag) + end + else + if trans == 'N' + tmp = trsm!('R', uplo, 'T', diag, -one(P), A, B'dB) + dA .+= tri!(tmp, uplo, diag) + else + tmp = trsm!('L', uplo, 'T', diag, -one(P), A, dB'B) + dA .+= tri!(tmp, uplo, diag) + end end - return zero_fcodual(Cvoid()), trsm_pb!! + + # Restore initial state. + B .= B_copy + + # Compute dB tangent. + BLAS.trsm!(side, uplo, trans == 'N' ? 'T' : 'N', diag, α, A, dB) + return tuple_fill(NoRData(), Val(5))..., ∇α, NoRData(), NoRData() end + + return B_dB, trsm_adjoint end function blas_matrices(rng::AbstractRNG, P::Type{<:BlasFloat}, p::Int, q::Int) @@ -815,6 +1100,15 @@ function blas_matrices(rng::AbstractRNG, P::Type{<:BlasFloat}, p::Int, q::Int) return Xs end +function invertible_blas_matrices(rng::AbstractRNG, P::Type{<:BlasFloat}, p::Int) + return map(blas_matrices(rng, P, p, p)) do A + U, _, V = svd(0.1 * A + I) + λs = p > 1 ? 
collect(range(1.0, 2.0; length=p)) : [1.0] + A .= collect(U * Diagonal(λs) * V') + return A + end +end + function blas_vectors(rng::AbstractRNG, P::Type{<:BlasFloat}, p::Int) xs = Any[ randn(rng, P, p), @@ -829,41 +1123,56 @@ end function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) t_flags = ['N', 'T', 'C'] - alphas = [1.0, -0.25] - betas = [0.0, 0.33] + αs = [1.0, -0.25] + dαs = [0.0, 0.44] + βs = [0.0, 0.33] + dβs = [0.0, -0.11] uplos = ['L', 'U'] dAs = ['N', 'U'] Ps = [Float64, Float32] rng = rng_ctor(123456) test_cases = vcat( + + # + # BLAS LEVEL 1 + # + # nrm2(x) map_prod([Ps..., ComplexF64, ComplexF32]) do (P,) return map([randn(rng, P, 105)]) do x - (false, :none, nothing, BLAS.nrm2, x) + (false, :stability, nothing, BLAS.nrm2, x) end end..., # nrm2(n, x, incx) map_prod([Ps..., ComplexF64, ComplexF32], [5, 3], [1, 2]) do (P, n, incx) return map([randn(rng, P, 105)]) do x - (false, :none, nothing, BLAS.nrm2, n, x, incx) + (false, :stability, nothing, BLAS.nrm2, n, x, incx) end end..., + map_prod(Ps, [1, 3, 11], [1, 2, 11]) do (P, n, incx) + flags = (false, :stability, nothing) + return (flags..., BLAS.scal!, n, randn(rng, P), randn(rng, P, n * incx), incx) + end, + + # + # BLAS LEVEL 2 + # # gemv! - map_prod(t_flags, [1, 3], [1, 2], Ps) do (tA, M, N, P) + map_prod(t_flags, [1, 3], [1, 2], Ps, αs, βs) do (tA, M, N, P, α, β) As = blas_matrices(rng, P, tA == 'N' ? M : N, tA == 'N' ? N : M) xs = blas_vectors(rng, P, N) ys = blas_vectors(rng, P, M) flags = (false, :stability, (lb=1e-3, ub=10.0)) return map(As, xs, ys) do A, x, y - (flags..., BLAS.gemv!, tA, randn(rng, P), A, x, randn(rng, P), y) + (flags..., BLAS.gemv!, tA, P(α), A, x, P(β), y) end end..., # symv! 
- map_prod(['L', 'U'], alphas, betas, Ps) do (uplo, α, β, P) + map_prod(['L', 'U'], αs, βs, Ps) do (uplo, α, β, P) As = blas_matrices(rng, P, 5, 5) ys = blas_vectors(rng, P, 5) xs = blas_vectors(rng, P, 5) @@ -881,18 +1190,25 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) end end..., + # # + # # BLAS LEVEL 3 + # # + # gemm! - map_prod(t_flags, t_flags, alphas, betas, Ps) do (tA, tB, a, b, P) + map_prod(t_flags, t_flags, αs, βs, Ps, dαs, dβs) do (tA, tB, α, β, P, dα, dβ) As = blas_matrices(rng, P, tA == 'N' ? 3 : 4, tA == 'N' ? 4 : 3) Bs = blas_matrices(rng, P, tB == 'N' ? 4 : 5, tB == 'N' ? 5 : 4) Cs = blas_matrices(rng, P, 3, 5) + return map(As, Bs, Cs) do A, B, C - (false, :stability, nothing, BLAS.gemm!, tA, tB, P(a), A, B, P(b), C) + a_da = CoDual(P(α), P(dα)) + b_db = CoDual(P(β), P(dβ)) + (false, :stability, nothing, BLAS.gemm!, tA, tB, a_da, A, B, b_db, C) end end..., # symm! - map_prod(['L', 'R'], ['L', 'U'], alphas, betas, Ps) do (side, ul, α, β, P) + map_prod(['L', 'R'], ['L', 'U'], αs, βs, Ps) do (side, ul, α, β, P) nA = side == 'L' ? 5 : 7 As = blas_matrices(rng, P, nA, nA) Bs = blas_matrices(rng, P, 5, 7) @@ -901,6 +1217,48 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) (false, :stability, nothing, BLAS.symm!, side, ul, P(α), A, B, P(β), C) end end..., + + # syrk! + map_prod(uplos, t_flags, Ps, dαs, dβs) do (uplo, t, P, dα, dβ) + As = blas_matrices(rng, P, t == 'N' ? 3 : 4, t == 'N' ? 4 : 3) + return map(As) do A + α_dα = CoDual(randn(rng, P), P(dα)) + β_dβ = CoDual(randn(rng, P), P(dβ)) + C = randn(rng, P, 3, 3) + (false, :stability, nothing, BLAS.syrk!, uplo, t, α_dα, A, β_dβ, C) + end + end..., + + # trmm! + map_prod( + ['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps, dαs + ) do (side, ul, tA, dA, M, N, P, dα) + t = tA == 'N' + R = side == 'L' ? 
M : N + As = blas_matrices(rng, P, R, R) + Bs = blas_matrices(rng, P, M, N) + return map(As, Bs) do A, B + α_dα = CoDual(randn(rng, P), P(dα)) + (false, :stability, nothing, BLAS.trmm!, side, ul, tA, dA, α_dα, A, B) + end + end..., + + # trsm! + map_prod( + ['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps + ) do (side, ul, tA, dA, M, N, P) + t = tA == 'N' + R = side == 'L' ? M : N + a = randn(rng, P) + As = map(blas_matrices(rng, P, R, R)) do A + A[diagind(A)] .+= 1 + return A + end + Bs = blas_matrices(rng, P, M, N) + return map(As, Bs) do A, B + (false, :stability, nothing, BLAS.trsm!, side, ul, tA, dA, a, A, B) + end + end..., ) memory = Any[] @@ -934,7 +1292,6 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) (flags..., BLAS.dot, 3, randn(rng, P, 6), 2, randn(rng, P, 4), 1), (flags..., BLAS.dot, 3, randn(rng, P, 6), 1, randn(rng, P, 9), 3), (flags..., BLAS.dot, 3, randn(rng, P, 12), 3, randn(rng, P, 9), 2), - (flags..., BLAS.scal!, 10, P(2.4), randn(rng, P, 30), 2), ] end..., @@ -953,47 +1310,11 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) end end..., - # syrk! - map_prod(uplos, t_flags, Ps) do (uplo, t, P) - As = blas_matrices(rng, P, t == 'N' ? 3 : 4, t == 'N' ? 4 : 3) - C = randn(rng, P, 3, 3) - a = randn(rng, P) - b = randn(rng, P) - return map(As) do A - (false, :none, nothing, BLAS.syrk!, uplo, t, a, A, b, C) - end - end..., - - # trmm! - map_prod( - ['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps - ) do (side, ul, tA, dA, M, N, P) - t = tA == 'N' - R = side == 'L' ? M : N - a = randn(rng, P) - As = blas_matrices(rng, P, R, R) - Bs = blas_matrices(rng, P, M, N) - return map(As, Bs) do A, B - (false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, a, A, B) - end - end..., + # + # Misc extra tests + # - # trsm! - map_prod( - ['L', 'R'], uplos, t_flags, dAs, [1, 3], [1, 2], Ps - ) do (side, ul, tA, dA, M, N, P) - t = tA == 'N' - R = side == 'L' ? 
M : N - a = randn(rng, P) - As = map(blas_matrices(rng, P, R, R)) do A - A[diagind(A)] .+= 1 - return A - end - Bs = blas_matrices(rng, P, M, N) - return map(As, Bs) do A, B - (false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, a, A, B) - end - end..., + (false, :none, nothing, x -> sum(complex(x) * x), rand(rng, 5, 5)), ) memory = Any[] return test_cases, memory diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index b78c83baf2..84613d6710 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -89,7 +89,9 @@ using Core: Intrinsics using Mooncake import ..Mooncake: rrule!!, + frule!!, CoDual, + Dual, primal, tangent, zero_tangent, @@ -107,7 +109,10 @@ import ..Mooncake: NoRData, rdata, increment_rdata!!, - zero_fcodual + zero_fcodual, + zero_dual, + NoTangent, + Mode using Core.Intrinsics: atomic_pointerref @@ -130,7 +135,11 @@ end macro intrinsic(name) expr = quote $name(x...) = Intrinsics.$name(x...) - (is_primitive)(::Type{MinimalCtx}, ::Type{<:Tuple{typeof($name),Vararg}}) = true + function is_primitive( + ::Type{MinimalCtx}, ::Type{<:Mode}, ::Type{<:Tuple{typeof($name),Vararg}} + ) + true + end translate(::Val{Intrinsics.$name}) = $name end return esc(expr) @@ -139,16 +148,28 @@ end macro inactive_intrinsic(name) expr = quote $name(x...) = Intrinsics.$name(x...) - (is_primitive)(::Type{MinimalCtx}, ::Type{<:Tuple{typeof($name),Vararg}}) = true + function is_primitive( + ::Type{MinimalCtx}, ::Type{<:Mode}, ::Type{<:Tuple{typeof($name),Vararg}} + ) + true + end translate(::Val{Intrinsics.$name}) = $name function rrule!!(f::CoDual{typeof($name)}, args::Vararg{Any,N}) where {N} return Mooncake.zero_adjoint(f, args...) 
end + function frule!!(f::Dual{typeof($name)}, args::Vararg{Dual,N}) where {N} + f_primal = primal(f) + args_primal = map(primal, args) + return zero_dual(f_primal(args_primal...)) + end end return esc(expr) end @intrinsic abs_float +function frule!!(::Dual{typeof(abs_float)}, x) + return Dual(abs_float(primal(x)), sign(primal(x)) * tangent(x)) +end function rrule!!(::CoDual{typeof(abs_float)}, x) abs_float_pullback!!(dy) = NoRData(), sign(primal(x)) * dy y = abs_float(primal(x)) @@ -156,6 +177,9 @@ function rrule!!(::CoDual{typeof(abs_float)}, x) end @intrinsic add_float +function frule!!(::Dual{typeof(add_float)}, a, b) + return Dual(add_float(primal(a), primal(b)), add_float(tangent(a), tangent(b))) +end function rrule!!(::CoDual{typeof(add_float)}, a, b) add_float_pb!!(c̄) = NoRData(), c̄, c̄ c = add_float(primal(a), primal(b)) @@ -163,6 +187,11 @@ function rrule!!(::CoDual{typeof(add_float)}, a, b) end @intrinsic add_float_fast +function frule!!(::Dual{typeof(add_float_fast)}, a, b) + c = add_float_fast(primal(a), primal(b)) + dc = add_float_fast(tangent(a), tangent(b)) + return Dual(c, dc) +end function rrule!!(::CoDual{typeof(add_float_fast)}, a, b) add_float_fast_pb!!(c̄) = NoRData(), c̄, c̄ c = add_float_fast(primal(a), primal(b)) @@ -205,6 +234,25 @@ end # atomic_pointerswap @intrinsic bitcast +function frule!!(f::Dual{typeof(bitcast)}, t::Dual{Type{T}}, x) where {T} + if T <: IEEEFloat + msg = + "It is not permissible to bitcast to a differentiable type during AD, as " * + "this risks dropping tangents, and therefore risks silently giving the wrong " * + "answer. If this call to bitcast appears as part of the implementation of a " * + "differentiable function, you should write a rule for this function, or modify " * + "its implementation to avoid the bitcast." 
+ throw(ArgumentError(msg)) + end + _x = primal(x) + v = bitcast(T, _x) + if T <: Ptr && _x isa Ptr + dv = bitcast(Ptr{tangent_type(eltype(T))}, tangent(x)) + else + dv = NoTangent() + end + return Dual(v, dv) +end function rrule!!(f::CoDual{typeof(bitcast)}, t::CoDual{Type{T}}, x) where {T} if T <: IEEEFloat msg = @@ -256,7 +304,14 @@ special handling of `cglobal` is used. __cglobal(::Val{s}, x::Vararg{Any,N}) where {s,N} = cglobal(s, x...) translate(::Val{Intrinsics.cglobal}) = __cglobal -Mooncake.is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(__cglobal),Vararg}}) = true +function Mooncake.is_primitive( + ::Type{MinimalCtx}, ::Type{<:Mode}, ::Type{<:Tuple{typeof(__cglobal),Vararg}} +) + return true +end +function frule!!(::Dual{typeof(__cglobal)}, args...) + return Mooncake.uninit_dual(__cglobal(map(primal, args)...)) +end function rrule!!(f::CoDual{typeof(__cglobal)}, args...) return Mooncake.uninit_fcodual(__cglobal(map(primal, args)...)), NoPullback(f, args...) end @@ -273,6 +328,11 @@ end @inactive_intrinsic checked_usub_int @intrinsic copysign_float +function frule!!(::Dual{typeof(copysign_float)}, x, y) + z = copysign_float(primal(x), primal(y)) + dz = sign(primal(y)) * tangent(x) + return Dual(z, dz) +end function rrule!!(::CoDual{typeof(copysign_float)}, x, y) _x = primal(x) _y = primal(y) @@ -286,6 +346,13 @@ end @inactive_intrinsic cttz_int @intrinsic div_float +function frule!!(::Dual{typeof(div_float)}, a, b) + c = div_float(primal(a), primal(b)) + da = tangent(a) + db = tangent(b) + dc = div_float(da, primal(b)) - div_float(primal(a) * db, primal(b)^2) + return Dual(c, dc) +end function rrule!!(::CoDual{typeof(div_float)}, a, b) _a = primal(a) _b = primal(b) @@ -295,6 +362,13 @@ function rrule!!(::CoDual{typeof(div_float)}, a, b) end @intrinsic div_float_fast +function frule!!(::Dual{typeof(div_float_fast)}, a, b) + c = div_float_fast(primal(a), primal(b)) + da = tangent(a) + db = tangent(b) + dc = div_float_fast(da, primal(b)) - 
div_float_fast(primal(a) * db, primal(b)^2) + return Dual(c, dc) +end function rrule!!(::CoDual{typeof(div_float_fast)}, a, b) _a = primal(a) _b = primal(b) @@ -312,6 +386,11 @@ end @inactive_intrinsic floor_llvm @intrinsic fma_float +function frule!!(::Dual{typeof(fma_float)}, x, y, z) + a = fma_float(primal(x), primal(y), primal(z)) + da = fma_float(tangent(x), primal(y), fma_float(primal(x), tangent(y), tangent(z))) + return Dual(a, da) +end function rrule!!(::CoDual{typeof(fma_float)}, x, y, z) _x = primal(x) _y = primal(y) @@ -320,6 +399,11 @@ function rrule!!(::CoDual{typeof(fma_float)}, x, y, z) end @intrinsic fpext +function frule!!( + ::Dual{typeof(fpext)}, ::Dual{Type{Pext}}, x::Dual{P} +) where {Pext<:IEEEFloat,P<:IEEEFloat} + return Dual(fpext(Pext, primal(x)), fpext(Pext, tangent(x))) +end function rrule!!( ::CoDual{typeof(fpext)}, ::CoDual{Type{Pext}}, x::CoDual{P} ) where {Pext<:IEEEFloat,P<:IEEEFloat} @@ -332,6 +416,11 @@ end @inactive_intrinsic fptoui @intrinsic fptrunc +function frule!!( + ::Dual{typeof(fptrunc)}, ::Dual{Type{Ptrunc}}, x::Dual{P} +) where {Ptrunc<:IEEEFloat,P<:IEEEFloat} + return Dual(fptrunc(Ptrunc, primal(x)), fptrunc(Ptrunc, tangent(x))) +end function rrule!!( ::CoDual{typeof(fptrunc)}, ::CoDual{Type{Ptrunc}}, x::CoDual{P} ) where {Ptrunc<:IEEEFloat,P<:IEEEFloat} @@ -350,6 +439,11 @@ end @inactive_intrinsic lt_float_fast @intrinsic mul_float +function frule!!(::Dual{typeof(mul_float)}, a, b) + p = mul_float(primal(a), primal(b)) + dp = add_float(mul_float(primal(a), tangent(b)), mul_float(primal(b), tangent(a))) + return Dual(p, dp) +end function rrule!!(::CoDual{typeof(mul_float)}, a, b) _a = primal(a) _b = primal(b) @@ -358,6 +452,11 @@ function rrule!!(::CoDual{typeof(mul_float)}, a, b) end @intrinsic mul_float_fast +function frule!!(::Dual{typeof(mul_float_fast)}, a, b) + c = mul_float_fast(primal(a), primal(b)) + dc = mul_float_fast(primal(a), tangent(b)) + mul_float_fast(tangent(a), primal(b)) + return Dual(c, dc) +end 
function rrule!!(::CoDual{typeof(mul_float_fast)}, a, b) _a = primal(a) _b = primal(b) @@ -368,6 +467,12 @@ end @inactive_intrinsic mul_int @intrinsic muladd_float +function frule!!(::Dual{typeof(muladd_float)}, x, y, z) + a = muladd_float(primal(x), primal(y), primal(z)) + dz = tangent(z) + da = muladd_float(tangent(x), primal(y), muladd_float(primal(x), tangent(y), dz)) + return Dual(a, da) +end function rrule!!(::CoDual{typeof(muladd_float)}, x, y, z) _x = primal(x) _y = primal(y) @@ -381,6 +486,7 @@ end @inactive_intrinsic ne_int @intrinsic neg_float +frule!!(::Dual{typeof(neg_float)}, x) = Dual(neg_float(primal(x)), neg_float(tangent(x))) function rrule!!(::CoDual{typeof(neg_float)}, x) _x = primal(x) neg_float_pullback!!(dy) = NoRData(), -dy @@ -388,6 +494,9 @@ function rrule!!(::CoDual{typeof(neg_float)}, x) end @intrinsic neg_float_fast +function frule!!(::Dual{typeof(neg_float_fast)}, x) + return Dual(neg_float_fast(primal(x)), neg_float_fast(tangent(x))) +end function rrule!!(::CoDual{typeof(neg_float_fast)}, x) _x = primal(x) neg_float_fast_pullback!!(dy) = NoRData(), -dy @@ -399,6 +508,11 @@ end @inactive_intrinsic or_int @intrinsic pointerref +function frule!!(::Dual{typeof(pointerref)}, x, y, z) + a = pointerref(primal(x), primal(y), primal(z)) + da = pointerref(tangent(x), primal(y), primal(z)) + return Dual(a, da) +end function rrule!!(::CoDual{typeof(pointerref)}, x, y, z) _x = primal(x) _y = primal(y) @@ -417,6 +531,11 @@ function rrule!!(::CoDual{typeof(pointerref)}, x, y, z) end @intrinsic pointerset +function frule!!(::Dual{typeof(pointerset)}, p, x, idx, z) + pointerset(primal(p), primal(x), primal(idx), primal(z)) + pointerset(tangent(p), tangent(x), primal(idx), primal(z)) + return p +end function rrule!!(::CoDual{typeof(pointerset)}, p, x, idx, z) _p = primal(p) _idx = primal(idx) @@ -444,22 +563,38 @@ end @inactive_intrinsic slt_int @intrinsic sqrt_llvm +function frule!!(::Dual{typeof(sqrt_llvm)}, x) + y = sqrt_llvm(primal(x)) + dy = 
tangent(x) / (2 * y) + return Dual(y, dy) +end function rrule!!(::CoDual{typeof(sqrt_llvm)}, x) _x = primal(x) - llvm_sqrt_pullback!!(dy) = NoRData(), dy * inv(2 * sqrt(_x)) - return CoDual(sqrt_llvm(_x), NoFData()), llvm_sqrt_pullback!! + _y = sqrt_llvm(primal(x)) + llvm_sqrt_pullback!!(dy) = NoRData(), dy / (2 * _y) + return CoDual(_y, NoFData()), llvm_sqrt_pullback!! end @intrinsic sqrt_llvm_fast +function frule!!(::Dual{typeof(sqrt_llvm_fast)}, x) + y = sqrt_llvm_fast(primal(x)) + dy = tangent(x) / (2 * y) + return Dual(y, dy) +end function rrule!!(::CoDual{typeof(sqrt_llvm_fast)}, x) - _x = primal(x) - llvm_sqrt_fast_pullback!!(dy) = NoRData(), dy * inv(2 * sqrt(_x)) - return CoDual(sqrt_llvm_fast(_x), NoFData()), llvm_sqrt_fast_pullback!! + _y = sqrt_llvm_fast(primal(x)) + llvm_sqrt_fast_pullback!!(dy) = NoRData(), dy / (2 * _y) + return CoDual(_y, NoFData()), llvm_sqrt_fast_pullback!! end @inactive_intrinsic srem_int @intrinsic sub_float +function frule!!(::Dual{typeof(sub_float)}, a, b) + c = sub_float(primal(a), primal(b)) + dc = sub_float(tangent(a), tangent(b)) + return Dual(c, dc) +end function rrule!!(::CoDual{typeof(sub_float)}, a, b) _a = primal(a) _b = primal(b) @@ -468,6 +603,11 @@ function rrule!!(::CoDual{typeof(sub_float)}, a, b) end @intrinsic sub_float_fast +function frule!!(::Dual{typeof(sub_float_fast)}, a, b) + c = sub_float_fast(primal(a), primal(b)) + dc = sub_float_fast(tangent(a), tangent(b)) + return Dual(c, dc) +end function rrule!!(::CoDual{typeof(sub_float_fast)}, a, b) _a = primal(a) _b = primal(b) @@ -499,8 +639,8 @@ end end # IntrinsicsWrappers -@zero_adjoint MinimalCtx Tuple{typeof(<:),Any,Any} -@zero_adjoint MinimalCtx Tuple{typeof(===),Any,Any} +@zero_derivative MinimalCtx Tuple{typeof(<:),Any,Any} +@zero_derivative MinimalCtx Tuple{typeof(===),Any,Any} # Core._abstracttype @@ -522,6 +662,14 @@ end __vec_to_tuple(v::Vector) = Tuple(v) @is_primitive MinimalCtx Tuple{typeof(__vec_to_tuple),Vector} +function 
frule!!(::Dual{typeof(__vec_to_tuple)}, v::Dual{<:Vector}) + x = __vec_to_tuple(primal(v)) + if tangent_type(_typeof(x)) == NoTangent + return zero_dual(x) + else + return Dual(x, __vec_to_tuple(tangent(v))) + end +end function rrule!!(::CoDual{typeof(__vec_to_tuple)}, v::CoDual{<:Vector}) dv = tangent(v) @@ -552,6 +700,15 @@ end # Core._structtype # Verify that there thing at the index is non-differentiable. Otherwise error. +function frule!!( + ::Dual{typeof(Core._svec_ref)}, v::Dual{Core.SimpleVector}, _ind::Dual{Int} +) + ind = primal(_ind) + pv = Core._svec_ref(primal(v), ind) + tv = getindex(tangent(v), ind) + isa(tv, NoTangent) || error("expected non-differentiable thing in SimpleVector") + return Dual(pv, tv) +end function rrule!!( f::CoDual{typeof(Core._svec_ref)}, v::CoDual{Core.SimpleVector}, _ind::CoDual{Int} ) @@ -563,16 +720,27 @@ function rrule!!( end # Core._typebody! - +function frule!!(::Dual{typeof(Core._typevar)}, args...) + return zero_dual(Core._typevar(map(primal, args)...)) +end function rrule!!(f::CoDual{typeof(Core._typevar)}, args...) return zero_fcodual(Core._typevar(map(primal, args)...)), NoPullback(f, args...) end +function frule!!(::Dual{typeof(Core.apply_type)}, args...) + return zero_dual(Core.apply_type(map(primal, args)...)) +end function rrule!!(f::CoDual{typeof(Core.apply_type)}, args...) T = Core.apply_type(tuple_map(primal, args)...) return CoDual{_typeof(T),NoFData}(T, NoFData()), NoPullback(f, args...) 
end +function frule!!(::Dual{typeof(compilerbarrier)}, setting::Dual{Symbol}, v::Dual) + return Dual( + compilerbarrier(primal(setting), primal(v)), + compilerbarrier(primal(setting), tangent(v)), + ) +end function rrule!!(::CoDual{typeof(compilerbarrier)}, setting::CoDual{Symbol}, val::CoDual) compilerbarrier_pb(dout) = NoRData(), NoRData(), dout return compilerbarrier(setting.x, val), compilerbarrier_pb @@ -582,6 +750,10 @@ end # Core.finalizer # Core.get_binding_type +function frule!!(::Dual{typeof(Core.ifelse)}, cond::Dual{Bool}, a::Dual, b::Dual) + _cond = primal(cond) + return Dual(ifelse(_cond, primal(a), primal(b)), ifelse(_cond, tangent(a), tangent(b))) +end function rrule!!(f::CoDual{typeof(Core.ifelse)}, cond, a::A, b::B) where {A,B} _cond = primal(cond) p_a = primal(a) @@ -607,15 +779,39 @@ function rrule!!(f::CoDual{typeof(Core.ifelse)}, cond, a::A, b::B) where {A,B} return CoDual(ifelse(_cond, p_a, p_b), ifelse(_cond, tangent(a), tangent(b))), pb!! end -@zero_adjoint MinimalCtx Tuple{typeof(Core.sizeof),Any} +@zero_derivative MinimalCtx Tuple{typeof(Core.sizeof),Any} # Core.svec -@zero_adjoint MinimalCtx Tuple{typeof(applicable),Vararg} -@zero_adjoint MinimalCtx Tuple{typeof(fieldtype),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(applicable),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(fieldtype),Vararg} +const StandardTangentType = Union{Tuple,NamedTuple,Tangent,MutableTangent,NoTangent} const StandardFDataType = Union{Tuple,NamedTuple,FData,MutableTangent,NoFData} +function frule!!( + ::Dual{typeof(getfield)}, x::Dual{P,<:StandardTangentType}, name::Dual +) where {P} + _name = primal(name) + if tangent_type(P) == NoTangent + return uninit_dual(getfield(primal(x), _name)) + else + return Dual(getfield(primal(x), _name), _get_tangent_field(tangent(x), _name)) + end +end +function frule!!( + ::Dual{typeof(getfield)}, x::Dual{P,<:StandardTangentType}, name::Dual, inbounds::Dual +) where {P} + _name = primal(name) + _inbounds = primal(inbounds) 
+ if tangent_type(P) == NoTangent + return uninit_dual(getfield(primal(x), _name, _inbounds)) + else + y = getfield(primal(x), _name, _inbounds) + dy = _get_tangent_field(tangent(x), _name, _inbounds) + return Dual(y, dy) + end +end function rrule!!( f::CoDual{typeof(getfield)}, x::CoDual{P,<:StandardFDataType}, name::CoDual ) where {P} @@ -676,19 +872,23 @@ is_homogeneous_and_immutable(::Any) = false # return y, pb!! # end -@zero_adjoint MinimalCtx Tuple{typeof(getglobal),Any,Any} +@zero_derivative MinimalCtx Tuple{typeof(getglobal),Any,Any} # invoke -@zero_adjoint MinimalCtx Tuple{typeof(isa),Any,Any} -@zero_adjoint MinimalCtx Tuple{typeof(isdefined),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(isa),Any,Any} +@zero_derivative MinimalCtx Tuple{typeof(isdefined),Vararg} # modifyfield! -@zero_adjoint MinimalCtx Tuple{typeof(nfields),Any} +@zero_derivative MinimalCtx Tuple{typeof(nfields),Any} # replacefield! +function frule!!(::Dual{typeof(setfield!)}, value::Dual, name::Dual, x::Dual) + literal_name = zero_dual(Val(primal(name))) + return frule!!(zero_dual(lsetfield!), value, literal_name, x) +end function rrule!!(::CoDual{typeof(setfield!)}, value::CoDual, name::CoDual, x::CoDual) literal_name = uninit_fcodual(Val(primal(name))) return rrule!!(uninit_fcodual(lsetfield!), value, literal_name, x) @@ -696,6 +896,7 @@ end # swapfield! +frule!!(::Dual{typeof(throw)}, args::Dual...) = throw(map(primal, args)...) rrule!!(::CoDual{typeof(throw)}, args::CoDual...) = throw(map(primal, args)...) struct TuplePullback{N} end @@ -710,6 +911,15 @@ end @inline tuple_pullback(dy::NoRData) = NoRData() +function frule!!(f::Dual{typeof(tuple)}, args::Vararg{Any,N}) where {N} + primal_output = tuple(map(primal, args)...) 
+ if tangent_type(_typeof(primal_output)) == NoTangent + return zero_dual(primal_output) + else + return Dual(primal_output, tuple(map(tangent, args)...)) + end +end + function rrule!!(f::CoDual{typeof(tuple)}, args::Vararg{Any,N}) where {N} primal_output = tuple(map(primal, args)...) if tangent_type(_typeof(primal_output)) == NoTangent @@ -723,12 +933,15 @@ function rrule!!(f::CoDual{typeof(tuple)}, args::Vararg{Any,N}) where {N} end end +function frule!!(::Dual{typeof(typeassert)}, x::Dual, type::Dual) + return Dual(typeassert(primal(x), primal(type)), tangent(x)) +end function rrule!!(::CoDual{typeof(typeassert)}, x::CoDual, type::CoDual) typeassert_pullback(dy) = NoRData(), dy, NoRData() return CoDual(typeassert(primal(x), primal(type)), tangent(x)), typeassert_pullback end -@zero_adjoint MinimalCtx Tuple{typeof(typeof),Any} +@zero_derivative MinimalCtx Tuple{typeof(typeof),Any} function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) _x = Ref(5.0) # data used in tests which aren't protected by GC. 
@@ -773,15 +986,15 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) # atomic_pointermodify -- NEEDS IMPLEMENTING AND TESTING # atomic_pointerref -- NEEDS IMPLEMENTING AND TESTING # atomic_pointerreplace -- NEEDS IMPLEMENTING AND TESTING - ( - true, - :none, - nothing, - IntrinsicsWrappers.atomic_pointerset, - CoDual(p, dp), - 1.0, - :monotonic, - ), + # ( + # true, + # :none, + # nothing, + # IntrinsicsWrappers.atomic_pointerset, + # CoDual(p, dp), + # 1.0, + # :monotonic, + # ), # atomic_pointerswap -- NEEDS IMPLEMENTING AND TESTING (false, :stability, nothing, IntrinsicsWrappers.bitcast, Int64, 5.0), (false, :stability, nothing, IntrinsicsWrappers.bswap_int, 5), @@ -927,6 +1140,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :none, nothing, __vec_to_tuple, [1.0]), (false, :none, nothing, __vec_to_tuple, Any[1.0]), (false, :none, nothing, __vec_to_tuple, Any[[1.0]]), + (false, :none, nothing, __vec_to_tuple, [1]), # Core._apply_pure -- NEEDS IMPLEMENTING AND TESTING # Core._call_in_world -- NEEDS IMPLEMENTING AND TESTING # Core._call_in_world_total -- NEEDS IMPLEMENTING AND TESTING diff --git a/src/rrules/dispatch_doctor.jl b/src/rrules/dispatch_doctor.jl index d807fbfc2f..425eccbbad 100644 --- a/src/rrules/dispatch_doctor.jl +++ b/src/rrules/dispatch_doctor.jl @@ -4,7 +4,7 @@ module DispatchDoctorRules # the logic here is the same as other DispatchDoctor extensions # for, e.g., Enzyme and ChainRulesCore. 
-import ..@zero_adjoint +import ..@zero_derivative import ..DefaultCtx import DispatchDoctor._RuntimeChecks: is_precompiling, checking_enabled @@ -16,16 +16,16 @@ import DispatchDoctor._Utils: type_instability, type_instability_limit_unions -@zero_adjoint DefaultCtx Tuple{typeof(_show_warning),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(_construct_pairs),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(_show_warning),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(_construct_pairs),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(specializing_typeof),Any} -@zero_adjoint DefaultCtx Tuple{typeof(map_specializing_typeof),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(_promote_op),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(type_instability),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(type_instability_limit_unions),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(specializing_typeof),Any} +@zero_derivative DefaultCtx Tuple{typeof(map_specializing_typeof),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(_promote_op),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(type_instability),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(type_instability_limit_unions),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(is_precompiling)} -@zero_adjoint DefaultCtx Tuple{typeof(checking_enabled)} +@zero_derivative DefaultCtx Tuple{typeof(is_precompiling)} +@zero_derivative DefaultCtx Tuple{typeof(checking_enabled)} end diff --git a/src/rrules/fastmath.jl b/src/rrules/fastmath.jl index 7acbb59abf..740d1dd6c9 100644 --- a/src/rrules/fastmath.jl +++ b/src/rrules/fastmath.jl @@ -1,4 +1,8 @@ @is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp_fast),IEEEFloat} +function frule!!(::Dual{typeof(Base.FastMath.exp_fast)}, x::Dual{P}) where {P<:IEEEFloat} + y = Base.FastMath.exp_fast(primal(x)) + return Dual(y, y * tangent(x)) +end function rrule!!( ::CoDual{typeof(Base.FastMath.exp_fast)}, x::CoDual{P} ) where {P<:IEEEFloat} @@ -8,6 +12,10 @@ function rrule!!( end @is_primitive 
MinimalCtx Tuple{typeof(Base.FastMath.exp2_fast),IEEEFloat} +function frule!!(::Dual{typeof(Base.FastMath.exp2_fast)}, x::Dual{P}) where {P<:IEEEFloat} + y = Base.FastMath.exp2_fast(primal(x)) + return Dual(y, y * tangent(x) * P(log(2))) +end function rrule!!( ::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{P} ) where {P<:IEEEFloat} @@ -17,6 +25,10 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp10_fast),IEEEFloat} +function frule!!(::Dual{typeof(Base.FastMath.exp10_fast)}, x::Dual{P}) where {P<:IEEEFloat} + y = Base.FastMath.exp10_fast(primal(x)) + return Dual(y, y * tangent(x) * P(log(10))) +end function rrule!!( ::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{P} ) where {P<:IEEEFloat} @@ -26,6 +38,11 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(Base.FastMath.sincos),IEEEFloat} +function frule!!(::Dual{typeof(Base.FastMath.sincos)}, x::Dual{P}) where {P<:IEEEFloat} + y = Base.FastMath.sincos(primal(x)) + dx = tangent(x) + return Dual(y, (y[2] * dx, -y[1] * dx)) +end function rrule!!(::CoDual{typeof(Base.FastMath.sincos)}, x::CoDual{P}) where {P<:IEEEFloat} y = Base.FastMath.sincos(primal(x)) sincos_fast_adj!!(dy::Tuple{P,P}) = NoRData(), dy[1] * y[2] - dy[2] * y[1] @@ -33,7 +50,7 @@ function rrule!!(::CoDual{typeof(Base.FastMath.sincos)}, x::CoDual{P}) where {P< end @is_primitive MinimalCtx Tuple{typeof(Base.log),Union{IEEEFloat,Int}} -@zero_adjoint MinimalCtx Tuple{typeof(log),Int} +@zero_derivative MinimalCtx Tuple{typeof(log),Int} function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:fastmath}) test_cases = reduce( diff --git a/src/rrules/foreigncall.jl b/src/rrules/foreigncall.jl index e50745b1e2..d27d854bf6 100644 --- a/src/rrules/foreigncall.jl +++ b/src/rrules/foreigncall.jl @@ -5,18 +5,25 @@ end Base.showerror(io::IO, err::MissingForeigncallRuleError) = print(io, err.msg) -# Fallback foreigncall rrule. 
This is a sufficiently common special case, that it's worth +# Fallback foreigncall rules. This is a sufficiently common special case, that it's worth # creating an informative error message, so that users have some chance of knowing why # they're not able to differentiate a piece of code. +function frule!!(::Dual{typeof(_foreigncall_)}, args...) + return throw_missing_foreigncall_rule_error(:frule!!, args...) +end function rrule!!(::CoDual{typeof(_foreigncall_)}, args...) + return throw_missing_foreigncall_rule_error(:rrule!!, args...) +end + +function throw_missing_foreigncall_rule_error(rule_name::Symbol, args...) throw( MissingForeigncallRuleError( - "No rrule!! available for foreigncall with primal argument types " * + "No $rule_name available for foreigncall with primal argument types " * "$(typeof(map(primal, args))). " * "This problem has most likely arisen because there is a ccall somewhere in the " * - "function you are trying to differentiate, for which an rrule!! has not been " * + "function you are trying to differentiate, for which an $rule_name has not been " * "explicitly written." * - "You have three options: write an rrule!! for this foreigncall, write an rrule!! " * + "You have three options: write an $rule_name for this foreigncall, write an $rule_name " * "for a Julia function that calls this foreigncall, or re-write your code to " * "avoid this foreigncall entirely. 
" * "If you believe that this error has arisen for some other reason than the above, " * @@ -72,11 +79,16 @@ end # Rules to handle / avoid foreigncall nodes # -@zero_adjoint MinimalCtx Tuple{typeof(Base.allocatedinline),Type} +@zero_derivative MinimalCtx Tuple{typeof(Base.allocatedinline),Type} -@zero_adjoint MinimalCtx Tuple{typeof(objectid),Any} +@zero_derivative MinimalCtx Tuple{typeof(objectid),Any} @is_primitive MinimalCtx Tuple{typeof(pointer_from_objref),Any} +function frule!!(::Dual{typeof(pointer_from_objref)}, x) + y = pointer_from_objref(primal(x)) + dy = bitcast(Ptr{tangent_type(Nothing)}, pointer_from_objref(tangent(x))) + return Dual(y, dy) +end function rrule!!(f::CoDual{typeof(pointer_from_objref)}, x) y = CoDual( pointer_from_objref(primal(x)), @@ -85,16 +97,19 @@ function rrule!!(f::CoDual{typeof(pointer_from_objref)}, x) return y, NoPullback(f, x) end -@zero_adjoint MinimalCtx Tuple{typeof(CC.return_type),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(CC.return_type),Vararg} @is_primitive MinimalCtx Tuple{typeof(Base.unsafe_pointer_to_objref),Ptr} +function frule!!(::Dual{typeof(Base.unsafe_pointer_to_objref)}, x::Dual{<:Ptr}) + return Dual(unsafe_pointer_to_objref(primal(x)), unsafe_pointer_to_objref(tangent(x))) +end function rrule!!(f::CoDual{typeof(Base.unsafe_pointer_to_objref)}, x::CoDual{<:Ptr}) y = CoDual(unsafe_pointer_to_objref(primal(x)), unsafe_pointer_to_objref(tangent(x))) return y, NoPullback(f, x) end -@zero_adjoint MinimalCtx Tuple{typeof(Threads.threadid)} -@zero_adjoint MinimalCtx Tuple{typeof(typeintersect),Any,Any} +@zero_derivative MinimalCtx Tuple{typeof(Threads.threadid)} +@zero_derivative MinimalCtx Tuple{typeof(typeintersect),Any,Any} function _increment_pointer!(x::Ptr{T}, y::Ptr{T}, N::Integer) where {T} increment!!(unsafe_wrap(Vector{T}, x, N), unsafe_wrap(Vector{T}, y, N)) @@ -105,6 +120,13 @@ end # Since we can't differentiate `memmove` (due to a lack of type information), it is # necessary to work with 
`unsafe_copyto!` instead. @is_primitive MinimalCtx Tuple{typeof(unsafe_copyto!),Ptr{T},Ptr{T},Any} where {T} +function frule!!( + ::Dual{typeof(unsafe_copyto!)}, dest::Dual{Ptr{T}}, src::Dual{Ptr{T}}, n::Dual +) where {T} + unsafe_copyto!(primal(dest), primal(src), primal(n)) + unsafe_copyto!(tangent(dest), tangent(src), primal(n)) + return dest +end function rrule!!( ::CoDual{typeof(unsafe_copyto!)}, dest::CoDual{Ptr{T}}, src::CoDual{Ptr{T}}, n::CoDual ) where {T} @@ -137,6 +159,22 @@ function rrule!!( return dest, unsafe_copyto!_pb!! end +function frule!!( + ::Dual{typeof(_foreigncall_)}, + ::Dual{Val{:jl_reshape_array}}, + ::Dual{Val{Array{P,M}}}, + ::Dual{Tuple{Val{Any},Val{Any},Val{Any}}}, + ::Dual, # nreq + ::Dual, # calling convention + x::Dual{Type{Array{P,M}}}, + a::Dual{Array{P,N},Array{T,N}}, + dims::Dual, +) where {P,T,M,N} + d = primal(dims) + y = ccall(:jl_reshape_array, Array{P,M}, (Any, Any, Any), Array{P,M}, primal(a), d) + dy = ccall(:jl_reshape_array, Array{T,M}, (Any, Any, Any), Array{T,M}, tangent(a), d) + return Dual(y, dy) +end function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_reshape_array}}, @@ -156,6 +194,23 @@ function rrule!!( return y, NoPullback(ntuple(_ -> NoRData(), 9)) end +function frule!!( + ::Dual{typeof(_foreigncall_)}, + ::Dual{Val{:jl_array_isassigned}}, + ::Dual{RT}, # return type is Int32 + arg_types::Dual{AT}, # arg types are (Any, UInt64) + ::Dual{nreq}, # nreq + ::Dual{calling_convention}, # calling convention + a::Dual{<:Array}, + ii::Dual{UInt}, + args..., +) where {RT,AT,nreq,calling_convention} + GC.@preserve args begin + y = ccall(:jl_array_isassigned, Cint, (Any, UInt), primal(a), primal(ii)) + end + return zero_dual(y) +end + function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_array_isassigned}}, @@ -173,6 +228,18 @@ function rrule!!( return zero_fcodual(y), NoPullback(ntuple(_ -> NoRData(), length(args) + 8)) end +function frule!!( + ::Dual{typeof(_foreigncall_)}, + 
::Dual{Val{:jl_type_unionall}}, + ::Dual{Val{Any}}, # return type + ::Dual{Tuple{Val{Any},Val{Any}}}, # arg types + ::Dual{Val{0}}, # number of required args + ::Dual{Val{:ccall}}, + a::Dual, + b::Dual, +) + return zero_dual(ccall(:jl_type_unionall, Any, (Any, Any), primal(a), primal(b))) +end function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_type_unionall}}, @@ -188,6 +255,7 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(deepcopy),Any} +frule!!(::Dual{typeof(deepcopy)}, x::Dual) = Dual(deepcopy(primal(x)), deepcopy(tangent(x))) function rrule!!(::CoDual{typeof(deepcopy)}, x::CoDual) fdx = tangent(x) dx = zero_rdata(primal(x)) @@ -200,10 +268,16 @@ function rrule!!(::CoDual{typeof(deepcopy)}, x::CoDual) return y, deepcopy_pb!! end -@zero_adjoint MinimalCtx Tuple{typeof(fieldoffset),DataType,Integer} -@zero_adjoint MinimalCtx Tuple{Type{UnionAll},TypeVar,Any} -@zero_adjoint MinimalCtx Tuple{Type{UnionAll},TypeVar,Type} -@zero_adjoint MinimalCtx Tuple{typeof(hash),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(fieldoffset),DataType,Integer} +@zero_derivative MinimalCtx Tuple{Type{UnionAll},TypeVar,Any} +@zero_derivative MinimalCtx Tuple{Type{UnionAll},TypeVar,Type} +@zero_derivative MinimalCtx Tuple{typeof(hash),Vararg} + +function frule!!( + ::Dual{typeof(_foreigncall_)}, ::Dual{Val{:jl_string_ptr}}, args::Vararg{Dual,N} +) where {N} + return uninit_dual(_foreigncall_(Val(:jl_string_ptr), tuple_map(primal, args)...)) +end function rrule!!( f::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_string_ptr}}, args::Vararg{CoDual,N} @@ -257,6 +331,9 @@ for name in [ ) where {RT,nreq,calling_convention} return unexpected_foreigncall_error($name) end + @eval function frule!!(::Dual{typeof(_foreigncall_)}, ::Dual{Val{$name}}, args...) + return unexpected_foreigncall_error($name) + end @eval function rrule!!(::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$name}}, args...) 
return unexpected_foreigncall_error($name) end diff --git a/src/rrules/iddict.jl b/src/rrules/iddict.jl index 069bdba946..d947f06c44 100644 --- a/src/rrules/iddict.jl +++ b/src/rrules/iddict.jl @@ -114,6 +114,11 @@ tangent(f::IdDict, ::NoRData) = f # standard built-in functionality on `IdDict`s. @is_primitive MinimalCtx Tuple{typeof(Base.rehash!),IdDict,Any} +function frule!!(::Dual{typeof(Base.rehash!)}, d::Dual{<:IdDict}, newsz::Dual) + Base.rehash!(primal(d), primal(newsz)) + Base.rehash!(tangent(d), primal(newsz)) + return d +end function rrule!!(::CoDual{typeof(Base.rehash!)}, d::CoDual{<:IdDict}, newsz::CoDual) Base.rehash!(primal(d), primal(newsz)) Base.rehash!(tangent(d), primal(newsz)) @@ -121,6 +126,11 @@ function rrule!!(::CoDual{typeof(Base.rehash!)}, d::CoDual{<:IdDict}, newsz::CoD end @is_primitive MinimalCtx Tuple{typeof(setindex!),IdDict,Any,Any} +function frule!!(::Dual{typeof(setindex!)}, d::Dual{IdDict{K,V}}, val, key) where {K,V} + setindex!(primal(d), primal(val), primal(key)) + setindex!(tangent(d), tangent(val), primal(key)) + return d +end function rrule!!(::CoDual{typeof(setindex!)}, d::CoDual{IdDict{K,V}}, val, key) where {K,V} k = primal(key) restore_state = in(k, keys(primal(d))) @@ -154,6 +164,13 @@ function rrule!!(::CoDual{typeof(setindex!)}, d::CoDual{IdDict{K,V}}, val, key) end @is_primitive MinimalCtx Tuple{typeof(get),IdDict,Any,Any} +function frule!!( + ::Dual{typeof(get)}, d::Dual{IdDict{K,V}}, key::Dual, default::Dual +) where {K,V} + x = get(primal(d), primal(key), primal(default)) + dx = get(tangent(d), primal(key), tangent(default)) + return Dual(x, dx) +end function rrule!!( ::CoDual{typeof(get)}, d::CoDual{IdDict{K,V}}, key::CoDual, default::CoDual ) where {K,V} @@ -177,6 +194,9 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(getindex),IdDict,Any} +function frule!!(::Dual{typeof(getindex)}, d::Dual{IdDict{K,V}}, key::Dual) where {K,V} + return Dual(getindex(primal(d), primal(key)), getindex(tangent(d), 
primal(key))) +end function rrule!!( ::CoDual{typeof(getindex)}, d::CoDual{IdDict{K,V}}, key::CoDual ) where {K,V} @@ -193,12 +213,18 @@ end for name in [:(:jl_idtable_rehash), :(:jl_eqtable_put), :(:jl_eqtable_get), :(:jl_eqtable_nextind)] + @eval function frule!!(::Dual{typeof(_foreigncall_)}, ::Dual{Val{$name}}, args...) + return unexpected_foreigncall_error($name) + end @eval function rrule!!(::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$name}}, args...) return unexpected_foreigncall_error($name) end end @is_primitive MinimalCtx Tuple{Type{IdDict{K,V}} where {K,V}} +function frule!!(::Dual{Type{IdDict{K,V}}}) where {K,V} + return Dual(IdDict{K,V}(), IdDict{K,tangent_type(V)}()) +end function rrule!!(f::CoDual{Type{IdDict{K,V}}}) where {K,V} return CoDual(IdDict{K,V}(), IdDict{K,tangent_type(V)}()), NoPullback(f) end diff --git a/src/rrules/lapack.jl b/src/rrules/lapack.jl index 4312e9cb56..ede01c2217 100644 --- a/src/rrules/lapack.jl +++ b/src/rrules/lapack.jl @@ -1,4 +1,11 @@ +# See https://sethaxen.com/blog/2021/02/differentiating-the-lu-decomposition/ for details. 
@is_primitive(MinimalCtx, Tuple{typeof(LAPACK.getrf!),AbstractMatrix{<:BlasRealFloat}}) +function frule!!( + ::Dual{typeof(LAPACK.getrf!)}, A_dA::Dual{<:AbstractMatrix{P}} +) where {P<:BlasRealFloat} + _, ipiv, info = LAPACK.getrf!(primal(A_dA)) + return _getrf_fwd(A_dA, ipiv, info) +end function rrule!!( ::CoDual{typeof(LAPACK.getrf!)}, _A::CoDual{<:AbstractMatrix{P}} ) where {P<:BlasRealFloat} @@ -25,6 +32,16 @@ end typeof(Core.kwcall),NamedTuple,typeof(LAPACK.getrf!),AbstractMatrix{<:BlasRealFloat} }, ) +function frule!!( + ::Dual{typeof(Core.kwcall)}, + _kwargs::Dual{<:NamedTuple}, + ::Dual{typeof(getrf!)}, + A_dA::Dual{<:AbstractMatrix{P}}, +) where {P<:BlasRealFloat} + check = primal(_kwargs).check + _, ipiv, info = LAPACK.getrf!(primal(A_dA); check) + return _getrf_fwd(A_dA, ipiv, info) +end function rrule!!( ::CoDual{typeof(Core.kwcall)}, _kwargs::CoDual{<:NamedTuple}, @@ -49,6 +66,19 @@ function rrule!!( return CoDual((_A.x, ipiv, code), (_A.dx, dipiv, NoFData())), getrf_pb!! end +function _getrf_fwd(A_dA, ipiv, info) + A, dA = arrayify(A_dA) + + # Compute Frechet derivative. + L = UnitLowerTriangular(A) + U = UpperTriangular(A) + p = LinearAlgebra.ipiv2perm(ipiv, size(A, 2)) + F = rdiv!(ldiv!(L, dA[p, :]), U) + dA .= L * tril(F, -1) + triu(F) * U + + return Dual((A, ipiv, info), (tangent(A_dA), zero_tangent(ipiv), NoTangent())) +end + function _getrf_pb!(A, dA, ipiv, A_copy) # Run reverse-pass. @@ -65,7 +95,6 @@ function _getrf_pb!(A, dA, ipiv, A_copy) dA .= (inv(L') * _dF * inv(U'))[invperm(p), :] # Restore initial state. - # ipiv .= ipiv_copy A .= A_copy return nothing @@ -77,6 +106,43 @@ end typeof(trtrs!),Char,Char,Char,AbstractMatrix{P},AbstractVecOrMat{P} } where {P<:BlasRealFloat}, ) +function frule!!( + ::Dual{typeof(trtrs!)}, + _uplo::Dual{Char}, + _trans::Dual{Char}, + _diag::Dual{Char}, + A_dA::Dual{<:AbstractMatrix{P}}, + B_dB::Dual{<:AbstractVecOrMat{P}}, +) where {P<:BlasRealFloat} + + # Extract data. 
+ uplo = primal(_uplo) + trans = primal(_trans) + diag = primal(_diag) + A, dA = arrayify(A_dA) + B, dB = arrayify(B_dB) + + # Compute Frechet derivative. + LAPACK.trtrs!(uplo, trans, diag, A, dB) + tmp = copy(B) + LAPACK.trtrs!(uplo, trans, diag, A, tmp) # tmp now contains inv(A) B. + + tmp2 = copy(tmp) + if diag == 'N' + a = uplo == 'L' ? LowerTriangular(dA) : UpperTriangular(dA) + lmul!(trans == 'N' ? a : a', tmp) + else + a = uplo == 'L' ? UnitLowerTriangular(dA) : UnitUpperTriangular(dA) + lmul!(trans == 'N' ? a : a', tmp) + tmp .-= tmp2 + end + LAPACK.trtrs!(uplo, trans, diag, A, tmp) # tmp is now α inv(A) dA inv(A) B. + dB .-= tmp + + # Run primal computation. + LAPACK.trtrs!(uplo, trans, diag, A, B) + return B_dB +end function rrule!!( ::CoDual{typeof(trtrs!)}, _uplo::CoDual{Char}, @@ -120,6 +186,41 @@ end typeof(getrs!),Char,AbstractMatrix{P},AbstractVector{Int},AbstractVecOrMat{P} } where {P<:BlasRealFloat} ) +function frule!!( + ::Dual{typeof(getrs!)}, + _trans::Dual{Char}, + A_dA::Dual{<:AbstractMatrix{P}}, + _ipiv::Dual{<:AbstractVector{Int}}, + B_dB::Dual{<:AbstractVecOrMat{P}}, +) where {P<:BlasRealFloat} + + # Extract data. + trans = primal(_trans) + A, dA = arrayify(A_dA) + ipiv = primal(_ipiv) + B, dB = arrayify(B_dB) + + # Run primal computation. + LAPACK.getrs!(trans, A, ipiv, B) + + # Compute Frechet derivative. 
+ L = UnitLowerTriangular(A) + dL_plus_I = UnitLowerTriangular(dA) + U = UpperTriangular(A) + dU = UpperTriangular(dA) + p = LinearAlgebra.ipiv2perm(ipiv, size(dB, 1)) + tmp = dL_plus_I * U + tmp .-= U + tmp2 = mul!(tmp, L, dU, one(P), one(P))[invperm(p), :] + if trans == 'N' + mul!(dB, tmp2, B, -one(P), one(P)) + else + mul!(dB, tmp2', B, -one(P), one(P)) + end + LAPACK.getrs!(trans, A, ipiv, dB) + + return B_dB +end function rrule!!( ::CoDual{typeof(getrs!)}, _trans::CoDual{Char}, @@ -162,7 +263,7 @@ function rrule!!( B2 .= B2[invperm(p), :] end - function trtrs_pb!!(::NoRData) + function getrs_pb!!(::NoRData) if trans == 'N' # Run pullback for inv(U) * B. @@ -194,12 +295,39 @@ function rrule!!( B .= B0 return tuple_fill(NoRData(), Val(5)) end - return _B, trtrs_pb!! + return _B, getrs_pb!! end @is_primitive( MinimalCtx, Tuple{typeof(getri!),AbstractMatrix{<:BlasRealFloat},AbstractVector{Int}}, ) +function frule!!( + ::Dual{typeof(getri!)}, + A_dA::Dual{<:AbstractMatrix{P}}, + _ipiv::Dual{<:AbstractVector{Int}}, +) where {P<:BlasRealFloat} + # Extract args. + A, dA = arrayify(A_dA) + ipiv = primal(_ipiv) + + # Compute part of Frechet derivative. + L = UnitLowerTriangular(A) + dL_plus_I = UnitLowerTriangular(dA) + U = UpperTriangular(A) + dU = UpperTriangular(dA) + p = LinearAlgebra.ipiv2perm(ipiv, size(dA, 1)) + tmp = dL_plus_I * U + tmp .-= U + tmp2 = mul!(tmp, L, dU, one(P), one(P))[invperm(p), :] + + # Perform primal computation. + LAPACK.getri!(A, ipiv) + + # Compute Frechet derivative. + dA .= (-A * tmp2 * A) + + return A_dA +end function rrule!!( ::CoDual{typeof(getri!)}, _A::CoDual{<:AbstractMatrix{<:BlasRealFloat}}, @@ -234,6 +362,35 @@ end __sym(X) = (X + X') / 2 @is_primitive(MinimalCtx, Tuple{typeof(potrf!),Char,AbstractMatrix{<:BlasRealFloat}}) +function frule!!( + ::Dual{typeof(potrf!)}, _uplo::Dual{Char}, A_dA::Dual{<:AbstractMatrix{<:BlasRealFloat}} +) + # Extract args and take a copy of A. 
+ uplo = primal(_uplo) + A, dA = arrayify(A_dA) + + # Run primal computation. + _, info = LAPACK.potrf!(uplo, A) + + # Compute Frechet derivative. + if uplo == 'L' + L = LowerTriangular(A) + tmp = LowerTriangular(ldiv!(L, Symmetric(dA, :L) / L')) + @inbounds for n in 1:size(A, 1) + tmp[n, n] = tmp[n, n] / 2 + end + _copytrito!(dA, lmul!(L, tmp), 'L') + else + U = UpperTriangular(A) + tmp = UpperTriangular(rdiv!(U' \ Symmetric(dA, :U), U)) + @inbounds for n in 1:size(A, 1) + tmp[n, n] = tmp[n, n] / 2 + end + _copytrito!(dA, rmul!(tmp, U), 'U') + end + + return Dual((A, info), (tangent(A_dA), NoTangent())) +end function rrule!!( ::CoDual{typeof(potrf!)}, _uplo::CoDual{Char}, @@ -278,6 +435,36 @@ end typeof(potrs!),Char,AbstractMatrix{P},AbstractVecOrMat{P} } where {P<:BlasRealFloat}, ) +function frule!!( + ::Dual{typeof(potrs!)}, + _uplo::Dual{Char}, + A_dA::Dual{<:AbstractMatrix{P}}, + B_dB::Dual{<:AbstractVecOrMat{P}}, +) where {P<:BlasRealFloat} + + # Extract args and take a copy of B. + uplo = primal(_uplo) + A, dA = arrayify(A_dA) + B, dB = arrayify(B_dB) + + # Run primal computation. + LAPACK.potrs!(uplo, A, B) + + # Compute Frechet derivative. + if uplo == 'L' + L = LowerTriangular(A) + dL = LowerTriangular(dA) + mul!(dB, Symmetric(dL * L' + L * dL'), B, -one(P), one(P)) + LAPACK.potrs!(uplo, A, dB) + else + U = UpperTriangular(A) + dU = UpperTriangular(dA) + mul!(dB, Symmetric(U'dU + dU'U), B, -one(P), one(P)) + LAPACK.potrs!(uplo, A, dB) + end + + return B_dB +end function rrule!!( ::CoDual{typeof(potrs!)}, _uplo::CoDual{Char}, @@ -322,44 +509,48 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) test_cases = vcat( # getrf! 
+ map_prod(Ps) do (P,) + As = blas_matrices(rng, P, 5, 5) + ipiv = Vector{Int}(undef, 5) + return map(As) do A + (false, :stability, nothing, getrf!, A) + end + end..., map_prod(bools, Ps) do (check, P) As = blas_matrices(rng, P, 5, 5) ipiv = Vector{Int}(undef, 5) return map(As) do A - (false, :none, nothing, getrf!, A) + (false, :stability, nothing, Core.kwcall, (; check), getrf!, A) end end..., - # trtrs + # trtrs! map_prod( - ['U', 'L'], ['N', 'T', 'C'], ['N', 'U'], [1, 3], [1, 2], Ps + ['U', 'L'], ['N', 'T', 'C'], ['N', 'U'], [1, 3], [-1, 1, 2], Ps ) do (ul, tA, diag, N, Nrhs, P) - As = blas_matrices(rng, P, N, N) - Bs = blas_matrices(rng, P, N, Nrhs) - return map(As, Bs) do A, B + As = invertible_blas_matrices(rng, P, N) + Bs = Nrhs == -1 ? blas_vectors(rng, P, N) : blas_matrices(rng, P, N, Nrhs) + Bs = filter(B -> stride(B, 1) == 1, Bs) + return map_prod(As, Bs) do (A, B) (false, :none, nothing, trtrs!, ul, tA, diag, A, B) end end..., # getrs - map_prod(['N', 'T'], [1, 9], [1, 2], Ps) do (trans, N, Nrhs, P) - As = map(blas_matrices(rng, P, N, N)) do A - A[diagind(A)] .+= 5 - return getrf!(A) - end - Bs = blas_matrices(rng, P, N, Nrhs) - return map(As, Bs) do (A, ipiv), B + map_prod(['N', 'T', 'C'], [1, 5], [-1, 1, 2], Ps) do (trans, N, Nrhs, P) + As = map(LAPACK.getrf!, invertible_blas_matrices(rng, P, N)) + Bs = Nrhs == -1 ? 
[randn(rng, P, N)] : blas_matrices(rng, P, N, Nrhs) + return map_prod(As, Bs) do ((A, _), B) + ipiv = fill(N, N) (false, :none, nothing, getrs!, trans, A, ipiv, B) end end..., # getri map_prod([1, 9], Ps) do (N, P) - As = map(blas_matrices(rng, P, N, N)) do A - A[diagind(A)] .+= 5 - return getrf!(A) - end - return map(As) do (A, ipiv) + As = map(LAPACK.getrf!, invertible_blas_matrices(rng, P, N)) + return map(As) do (A, _) + ipiv = fill(N, N) (false, :none, nothing, getri!, A, ipiv) end end..., @@ -370,18 +561,19 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) A .= A * A' + I return A end - return map(['L', 'U'], As) do uplo, A - return (false, :none, nothing, potrf!, uplo, A) + return map_prod(['L', 'U'], As) do (uplo, A) + return (false, :stability, nothing, potrf!, uplo, A) end end..., # potrs - map_prod([1, 3, 9], [1, 2], Ps) do (N, Nrhs, P) + map_prod([1, 3, 9], [-1, 1, 2], Ps) do (N, Nrhs, P) X = randn(rng, P, N, N) A = X * X' + I - Bs = blas_matrices(rng, P, N, Nrhs) - return map(['L', 'U'], Bs) do uplo, B - (false, :none, nothing, potrs!, uplo, potrf!(uplo, copy(A))[1], copy(B)) + Bs = Nrhs == -1 ? 
blas_vectors(rng, P, N) : blas_matrices(rng, P, N, Nrhs) + return map_prod(['L', 'U'], Bs) do (uplo, B) + tmp = potrf!(uplo, copy(A))[1] + (false, :none, nothing, potrs!, uplo, tmp, copy(B)) end end..., ) @@ -394,7 +586,6 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) getrf_wrapper!(x, check) = getrf!(x; check) test_cases = vcat(map_prod([false, true], [Float64, Float32]) do (check, P) As = blas_matrices(rng, P, 5, 5) - # ipiv = Vector{Int}(undef, 5) return map(As) do A (false, :none, nothing, getrf_wrapper!, A, check) end diff --git a/src/rrules/linear_algebra.jl b/src/rrules/linear_algebra.jl index a503fbb98f..fc31ce4265 100644 --- a/src/rrules/linear_algebra.jl +++ b/src/rrules/linear_algebra.jl @@ -12,6 +12,11 @@ function (pb::ExpPullback)(::NoRData) return NoRData(), NoRData() end +function frule!!(::Dual{typeof(exp)}, X_dX::Dual{Matrix{P}}) where {P<:IEEEFloat} + X = copy(primal(X_dX)) + dX = copy(tangent(X_dX)) + return Dual(ChainRules.frule((ChainRules.NoTangent(), dX), LinearAlgebra.exp!, X)...) 
+end function rrule!!(::CoDual{typeof(exp)}, X::CoDual{Matrix{P}}) where {P<:IEEEFloat} Y, pb = ChainRules.rrule(exp, X.x) Ȳ = zero(Y) diff --git a/src/rrules/low_level_maths.jl b/src/rrules/low_level_maths.jl index 60f1c093eb..a1d64e076e 100644 --- a/src/rrules/low_level_maths.jl +++ b/src/rrules/low_level_maths.jl @@ -11,6 +11,10 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) pb_name = Symbol("$(M).$(f)_pb!!") @eval begin @is_primitive MinimalCtx Tuple{typeof($M.$f),P} where {P<:IEEEFloat} + function frule!!(::Dual{typeof($M.$f)}, _x::Dual{P}) where {P<:IEEEFloat} + x = primal(_x) + return Dual(($M.$f)(x), tangent(_x) * $dx) + end function rrule!!(::CoDual{typeof($M.$f)}, _x::CoDual{P}) where {P<:IEEEFloat} x = primal(_x) # needed for dx expression $pb_name(ȳ::P) = NoRData(), ȳ * $dx @@ -22,6 +26,13 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) pb_name = Symbol("$(M).$(f)_pb!!") @eval begin @is_primitive MinimalCtx Tuple{typeof($M.$f),P,P} where {P<:IEEEFloat} + function frule!!( + ::Dual{typeof($M.$f)}, _a::Dual{P}, _b::Dual{P} + ) where {P<:IEEEFloat} + a = primal(_a) + b = primal(_b) + return Dual(($M.$f)(a, b), tangent(_a) * $da + tangent(_b) * $db) + end function rrule!!( ::CoDual{typeof($M.$f)}, _a::CoDual{P}, _b::CoDual{P} ) where {P<:IEEEFloat} @@ -34,30 +45,50 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) end end -@is_primitive MinimalCtx Tuple{typeof(sin),<:IEEEFloat} +@is_primitive MinimalCtx Mode Tuple{typeof(sin),<:IEEEFloat} +function frule!!(::Dual{typeof(sin)}, x::Dual{<:IEEEFloat}) + s, c = sincos(primal(x)) + return Dual(s, c * tangent(x)) +end function rrule!!(::CoDual{typeof(sin),NoFData}, x::CoDual{P,NoFData}) where {P<:IEEEFloat} s, c = sincos(primal(x)) sin_pullback!!(dy::P) = NoRData(), dy * c return CoDual(s, NoFData()), sin_pullback!! 
end -@is_primitive MinimalCtx Tuple{typeof(cos),<:IEEEFloat} +@is_primitive MinimalCtx Mode Tuple{typeof(cos),<:IEEEFloat} +function frule!!(::Dual{typeof(cos)}, x::Dual{<:IEEEFloat}) + s, c = sincos(primal(x)) + return Dual(c, -s * tangent(x)) +end function rrule!!(::CoDual{typeof(cos),NoFData}, x::CoDual{P,NoFData}) where {P<:IEEEFloat} s, c = sincos(primal(x)) cos_pullback!!(dy::P) = NoRData(), -dy * s return CoDual(c, NoFData()), cos_pullback!! end -@is_primitive MinimalCtx Tuple{typeof(exp),<:IEEEFloat} +@is_primitive MinimalCtx Mode Tuple{typeof(exp),<:IEEEFloat} +function frule!!(::Dual{typeof(exp)}, x::Dual{P}) where {P<:IEEEFloat} + y = exp(primal(x)) + return Dual(y, y * tangent(x)) +end function rrule!!(::CoDual{typeof(exp)}, x::CoDual{P}) where {P<:IEEEFloat} y = exp(primal(x)) exp_pb!!(dy::P) = NoRData(), dy * y return zero_fcodual(y), exp_pb!! end -@from_rrule MinimalCtx Tuple{typeof(^),P,P} where {P<:IEEEFloat} +@from_chainrules MinimalCtx Tuple{typeof(^),P,P} where {P<:IEEEFloat} +function frule!!(::Dual{typeof(^)}, x::Dual{P}, y::Dual{P}) where {P<:IEEEFloat} + t = (ChainRules.NoTangent(), tangent(x), tangent(y)) + z, dz = ChainRules.frule(t, ^, primal(x), primal(y)) + return Dual(z, dz) +end @is_primitive MinimalCtx Tuple{typeof(Base.eps),<:IEEEFloat} +function frule!!(::Dual{typeof(Base.eps)}, x::Dual{<:IEEEFloat}) + return Dual(eps(primal(x)), zero(primal(x))) +end function rrule!!(::CoDual{typeof(Base.eps)}, x::CoDual{P}) where {P<:IEEEFloat} y = Base.eps(primal(x)) eps_pb!!(dy::P) = NoRData(), zero(y) diff --git a/src/rrules/memory.jl b/src/rrules/memory.jl index f6311c97d3..64ba52d8bf 100644 --- a/src/rrules/memory.jl +++ b/src/rrules/memory.jl @@ -222,6 +222,16 @@ end @is_primitive( MinimalCtx, Tuple{typeof(unsafe_copyto!),MemoryRef{P},MemoryRef{P},Int} where {P} ) +function frule!!( + ::Dual{typeof(unsafe_copyto!)}, + dest::Dual{MemoryRef{P}}, + src::Dual{MemoryRef{P}}, + n::Dual{Int}, +) where {P} + unsafe_copyto!(primal(dest), 
primal(src), primal(n)) + unsafe_copyto!(tangent(dest), tangent(src), primal(n)) + return dest +end function rrule!!( ::CoDual{typeof(unsafe_copyto!)}, dest::CoDual{MemoryRef{P}}, @@ -359,7 +369,9 @@ _val(::Val{c}) where {c} = c using Core: memoryref_isassigned, memoryrefget, memoryrefset!, memoryrefnew, memoryrefoffset -@zero_adjoint(MinimalCtx, Tuple{typeof(memoryref_isassigned),GenericMemoryRef,Symbol,Bool}) +@zero_derivative( + MinimalCtx, Tuple{typeof(memoryref_isassigned),GenericMemoryRef,Symbol,Bool} +) @inline function lmemoryrefget( x::MemoryRef, ::Val{ordering}, ::Val{boundscheck} @@ -368,6 +380,18 @@ using Core: memoryref_isassigned, memoryrefget, memoryrefset!, memoryrefnew, mem end @is_primitive MinimalCtx Tuple{typeof(lmemoryrefget),MemoryRef,Val,Val} +@inline function frule!!( + ::Dual{typeof(lmemoryrefget)}, + x::Dual{<:MemoryRef}, + _ordering::Dual{<:Val}, + _boundscheck::Dual{<:Val}, +) + ordering = primal(_ordering) + bc = primal(_boundscheck) + y = memoryrefget(primal(x), _val(ordering), _val(bc)) + dy = memoryrefget(tangent(x), _val(ordering), _val(bc)) + return Dual(y, dy) +end @inline function rrule!!( ::CoDual{typeof(lmemoryrefget)}, x::CoDual{<:MemoryRef}, @@ -387,6 +411,18 @@ end return CoDual(y, dy), lmemoryrefget_adjoint end +@inline Base.@propagate_inbounds function frule!!( + ::Dual{typeof(memoryrefget)}, + x::Dual{<:MemoryRef}, + _ordering::Dual{Symbol}, + _boundscheck::Dual{Bool}, +) + ordering = primal(_ordering) + boundscheck = primal(_boundscheck) + y = memoryrefget(primal(x), ordering, boundscheck) + dy = memoryrefget(tangent(x), ordering, boundscheck) + return Dual(y, dy) +end @inline Base.@propagate_inbounds function rrule!!( ::CoDual{typeof(memoryrefget)}, x::CoDual{<:MemoryRef}, @@ -405,16 +441,32 @@ end # Core.memoryrefmodify! 
+@inline function frule!!(::Dual{typeof(memoryrefnew)}, x::Dual{<:Memory}) + return Dual(memoryrefnew(primal(x)), memoryrefnew(tangent(x))) +end @inline function rrule!!(f::CoDual{typeof(memoryrefnew)}, x::CoDual{<:Memory}) return CoDual(memoryrefnew(x.x), memoryrefnew(x.dx)), NoPullback(f, x) end +@inline function frule!!(::Dual{typeof(memoryrefnew)}, x::Dual{<:MemoryRef}, ii::Dual{Int}) + return Dual(memoryrefnew(primal(x), primal(ii)), memoryrefnew(tangent(x), primal(ii))) +end @inline function rrule!!( f::CoDual{typeof(memoryrefnew)}, x::CoDual{<:MemoryRef}, ii::CoDual{Int} ) return CoDual(memoryrefnew(x.x, ii.x), memoryrefnew(x.dx, ii.x)), NoPullback(f, x, ii) end +@inline function frule!!( + ::Dual{typeof(memoryrefnew)}, + x::Dual{<:MemoryRef}, + ii::Dual{Int}, + boundscheck::Dual{Bool}, +) + y = memoryrefnew(primal(x), primal(ii), primal(boundscheck)) + dy = memoryrefnew(tangent(x), primal(ii), primal(boundscheck)) + return Dual(y, dy) +end @inline function rrule!!( f::CoDual{typeof(memoryrefnew)}, x::CoDual{<:MemoryRef}, @@ -426,7 +478,7 @@ end return CoDual(y, dy), NoPullback(f, x, ii, boundscheck) end -@zero_adjoint MinimalCtx Tuple{typeof(memoryrefoffset),GenericMemoryRef} +@zero_derivative MinimalCtx Tuple{typeof(memoryrefoffset),GenericMemoryRef} # Core.memoryrefreplace! 
@@ -438,6 +490,17 @@ end @is_primitive MinimalCtx Tuple{typeof(lmemoryrefset!),MemoryRef,Any,Val,Val} +@inline function frule!!( + ::Dual{typeof(lmemoryrefset!)}, + x::Dual{<:MemoryRef{P},<:MemoryRef{V}}, + value::Dual, + ::Dual{Val{ordering}}, + ::Dual{Val{boundscheck}}, +) where {P,V,ordering,boundscheck} + memoryrefset!(primal(x), primal(value), ordering, boundscheck) + memoryrefset!(tangent(x), tangent(value), ordering, boundscheck) + return value +end @inline function rrule!!( ::CoDual{typeof(lmemoryrefset!)}, x::CoDual{<:MemoryRef{P},<:MemoryRef{V}}, @@ -490,6 +553,21 @@ function isbits_lmemoryrefset!_rule(x::CoDual, value::CoDual, ordering::Val, bc: return value, isbits_lmemoryrefset!_adjoint end +@inline function frule!!( + ::Dual{typeof(memoryrefset!)}, + x::Dual{<:MemoryRef{P},<:MemoryRef{V}}, + value::Dual, + ordering::Dual{Symbol}, + boundscheck::Dual{Bool}, +) where {P,V} + return frule!!( + zero_dual(lmemoryrefset!), + x, + value, + zero_dual(Val(primal(ordering))), + zero_dual(Val(primal(boundscheck))), + ) +end @inline function rrule!!( ::CoDual{typeof(memoryrefset!)}, x::CoDual{<:MemoryRef{P},<:MemoryRef{V}}, @@ -515,6 +593,10 @@ end # _new_ and _new_-adjacent rules for Memory, MemoryRef, and Array. 
@is_primitive MinimalCtx Tuple{Type{<:Memory},UndefInitializer,Int}
+function frule!!(::Dual{Type{Memory{P}}}, ::Dual{UndefInitializer}, n::Dual{Int}) where {P}
+    x = Memory{P}(undef, primal(n))
+    return Dual(x, zero_tangent_internal(x, NoCache()))
+end
 function rrule!!(
     ::CoDual{Type{Memory{P}}}, ::CoDual{UndefInitializer}, n::CoDual{Int}
 ) where {P}
@@ -534,6 +616,16 @@ function rrule!!(
     return CoDual(y, dy), NoPullback(ntuple(_ -> NoRData(), 4))
 end
 
+function frule!!(
+    ::Dual{typeof(_new_)},
+    ::Dual{Type{Array{P,N}}},
+    ref::Dual{MemoryRef{P}},
+    size::Dual{<:NTuple{N,Int}},
+) where {P,N}
+    y = _new_(Array{P,N}, primal(ref), primal(size))
+    dy = _new_(Array{tangent_type(P),N}, tangent(ref), primal(size))
+    return Dual(y, dy)
+end
 function rrule!!(
     ::CoDual{typeof(_new_)},
     ::CoDual{Type{Array{P,N}}},
@@ -545,6 +637,17 @@ function rrule!!(
     return CoDual(y, dy), NoPullback(ntuple(_ -> NoRData(), 4))
 end
 
+function frule!!(
+    ::Dual{typeof(_foreigncall_)},
+    ::Dual{Val{:jl_genericmemory_copy}},
+    ::Dual,
+    ::Dual{Tuple{Val{Any}}},
+    ::Dual{Val{0}},
+    ::Dual{Val{:ccall}},
+    x::Dual{<:Memory},
+)
+    return Dual(copy(primal(x)), copy(tangent(x)))
+end
 function rrule!!(
     ::CoDual{typeof(_foreigncall_)},
     ::CoDual{Val{:jl_genericmemory_copy}},
@@ -566,6 +669,17 @@ end
 
 # getfield / lgetfield rules for Memory, MemoryRef, and Array.
 
+function frule!!(
+    ::Dual{typeof(lgetfield)},
+    x::Dual{<:Memory,<:Memory},
+    ::Dual{Val{name}},
+    ::Dual{Val{order}},
+) where {name,order}
+    y = getfield(primal(x), name, order)
+    wants_length = name === 1 || name === :length
+    dy = wants_length ? 
NoTangent() : bitcast(Ptr{NoTangent}, tangent(x).ptr) + return Dual(y, dy) +end function rrule!!( ::CoDual{typeof(lgetfield)}, x::CoDual{<:Memory,<:Memory}, @@ -578,6 +692,17 @@ function rrule!!( return CoDual(y, dy), NoPullback(ntuple(_ -> NoRData(), 4)) end +function frule!!( + ::Dual{typeof(lgetfield)}, + x::Dual{<:MemoryRef,<:MemoryRef}, + ::Dual{Val{name}}, + ::Dual{Val{order}}, +) where {name,order} + y = getfield(primal(x), name, order) + wants_offset = name === 1 || name === :ptr_or_offset + dy = wants_offset ? bitcast(Ptr{NoTangent}, tangent(x).ptr_or_offset) : tangent(x).mem + return Dual(y, dy) +end function rrule!!( ::CoDual{typeof(lgetfield)}, x::CoDual{<:MemoryRef,<:MemoryRef}, @@ -590,6 +715,17 @@ function rrule!!( return CoDual(y, dy), NoPullback(ntuple(_ -> NoRData(), 4)) end +function frule!!( + ::Dual{typeof(lgetfield)}, + x::Dual{<:Array,<:Array}, + ::Dual{Val{name}}, + ::Dual{Val{order}}, +) where {name,order} + y = getfield(primal(x), name, order) + wants_size = name === 2 || name === :size + dy = wants_size ? 
NoTangent() : tangent(x).ref + return Dual(y, dy) +end function rrule!!( ::CoDual{typeof(lgetfield)}, x::CoDual{<:Array,<:Array}, @@ -604,6 +740,11 @@ end const _MemTypes = Union{Memory,MemoryRef,Array} +function frule!!( + f::Dual{typeof(lgetfield)}, x::Dual{<:_MemTypes,<:_MemTypes}, name::Dual{<:Val} +) + return frule!!(f, x, name, zero_dual(Val(:not_atomic))) +end function rrule!!( f::CoDual{typeof(lgetfield)}, x::CoDual{<:_MemTypes,<:_MemTypes}, name::CoDual{<:Val} ) @@ -612,6 +753,16 @@ function rrule!!( return y, ternary_lgetfield_adjoint end +function frule!!( + ::Dual{typeof(getfield)}, + x::Dual{<:_MemTypes,<:_MemTypes}, + name::Dual{<:Union{Int,Symbol}}, + order::Dual{Symbol}, +) + return frule!!( + zero_dual(lgetfield), x, zero_dual(Val(primal(name))), zero_dual(Val(primal(order))) + ) +end function rrule!!( ::CoDual{typeof(getfield)}, x::CoDual{<:_MemTypes,<:_MemTypes}, @@ -628,6 +779,13 @@ function rrule!!( return y, getfield_adjoint end +function frule!!( + ::Dual{typeof(getfield)}, + x::Dual{<:_MemTypes,<:_MemTypes}, + name::Dual{<:Union{Int,Symbol}}, +) + return frule!!(zero_dual(lgetfield), x, zero_dual(Val(primal(name)))) +end function rrule!!( f::CoDual{typeof(getfield)}, x::CoDual{<:_MemTypes,<:_MemTypes}, @@ -638,6 +796,13 @@ function rrule!!( return y, ternary_getfield_adjoint end +@inline function frule!!( + ::Dual{typeof(lsetfield!)}, value::Dual{<:Array,<:Array}, ::Dual{Val{name}}, x::Dual +) where {name} + setfield!(primal(value), name, primal(x)) + setfield!(tangent(value), name, (name === :size || name === 2) ? primal(x) : tangent(x)) + return x +end @inline function rrule!!( ::CoDual{typeof(lsetfield!)}, value::CoDual{<:Array,<:Array}, @@ -659,6 +824,7 @@ end # Misc. other rules which are required for correctness. 
@is_primitive MinimalCtx Tuple{typeof(copy),Array} +frule!!(::Dual{typeof(copy)}, a::Dual{<:Array}) = Dual(copy(primal(a)), copy(tangent(a))) function rrule!!(::CoDual{typeof(copy)}, a::CoDual{<:Array}) dx = tangent(a) dy = copy(dx) @@ -672,6 +838,11 @@ end @is_primitive MinimalCtx Tuple{typeof(fill!),Array{<:Union{UInt8,Int8}},Integer} @is_primitive MinimalCtx Tuple{typeof(fill!),Memory{<:Union{UInt8,Int8}},Integer} +function frule!!( + ::Dual{typeof(fill!)}, a::Dual{T}, x::Dual{<:Integer} +) where {V<:Union{UInt8,Int8},T<:Union{Array{V},Memory{V}}} + return Dual(fill!(primal(a), primal(x)), tangent(a)) +end function rrule!!( ::CoDual{typeof(fill!)}, a::CoDual{T}, x::CoDual{<:Integer} ) where {V<:Union{UInt8,Int8},T<:Union{Array{V},Memory{V}}} @@ -879,6 +1050,9 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:memory}) (false, :none, nothing, setfield!, randn(rng, 10), 1, randn(rng, 10).ref), (false, :none, nothing, setfield!, randn(rng, 10), :size, (10,)), (false, :none, nothing, setfield!, randn(rng, 10), 2, (10,)), + (false, :stability, nothing, copy, randn(10)), + (false, :stability, nothing, fill!, fill!(Memory{Int8}(undef, 5), 0), Int8(1)), + (false, :stability, nothing, fill!, fill!(Memory{UInt8}(undef, 5), 0), UInt8(1)), ) memory = Any[] return test_cases, memory diff --git a/src/rrules/misc.jl b/src/rrules/misc.jl index 56e6430b2a..fac109d797 100644 --- a/src/rrules/misc.jl +++ b/src/rrules/misc.jl @@ -6,37 +6,37 @@ # deduce that these bits of code are inactive though. 
# -@zero_adjoint DefaultCtx Tuple{typeof(in),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(iszero),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(isempty),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(isbitstype),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(sizeof),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(promote_type),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.elsize),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Core.Compiler.sizeof_nothrow),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_haspadding),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_nfields),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_pointerfree),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_alignment),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_fielddesc_type),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(LinearAlgebra.chkstride1),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Threads.nthreads),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.depwarn),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.reduced_indices),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.check_reducedims),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.throw_boundserror),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.Broadcast.eltypes),Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.eltype),Vararg} -@zero_adjoint MinimalCtx Tuple{typeof(Base.padding),DataType} -@zero_adjoint MinimalCtx Tuple{typeof(Base.padding),DataType,Int} -@zero_adjoint MinimalCtx Tuple{Type,TypeVar,Type} +@zero_derivative DefaultCtx Tuple{typeof(in),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(iszero),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(isempty),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(isbitstype),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(sizeof),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(promote_type),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.elsize),Vararg} +@zero_derivative 
DefaultCtx Tuple{typeof(Core.Compiler.sizeof_nothrow),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.datatype_haspadding),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.datatype_nfields),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.datatype_pointerfree),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.datatype_alignment),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.datatype_fielddesc_type),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(LinearAlgebra.chkstride1),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Threads.nthreads),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.depwarn),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.reduced_indices),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.check_reducedims),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.throw_boundserror),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.Broadcast.eltypes),Vararg} +@zero_derivative DefaultCtx Tuple{typeof(Base.eltype),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(Base.padding),DataType} +@zero_derivative MinimalCtx Tuple{typeof(Base.padding),DataType,Int} +@zero_derivative MinimalCtx Tuple{Type,TypeVar,Type} # Required to avoid an ambiguity. 
-@zero_adjoint MinimalCtx Tuple{Type{Symbol},TypeVar,Type} +@zero_derivative MinimalCtx Tuple{Type{Symbol},TypeVar,Type} @static if VERSION >= v"1.11-" - @zero_adjoint MinimalCtx Tuple{typeof(Random.hash_seed),Vararg} - @zero_adjoint MinimalCtx Tuple{typeof(Base.dataids),Memory} + @zero_derivative MinimalCtx Tuple{typeof(Random.hash_seed),Vararg} + @zero_derivative MinimalCtx Tuple{typeof(Base.dataids),Memory} end """ @@ -58,6 +58,24 @@ This approach is identical to the one taken by `Zygote.jl` to circumvent the sam lgetfield(x, ::Val{f}) where {f} = getfield(x, f) @is_primitive MinimalCtx Tuple{typeof(lgetfield),Any,Val} +@inline function frule!!( + ::Dual{typeof(lgetfield)}, x::Dual{P,T}, ::Dual{Val{f}} +) where {P,T<:StandardTangentType,f} + primal_field = getfield(primal(x), f) + if tangent_type(P) === NoTangent + return uninit_dual(primal_field) + else + Dual(primal_field, _get_tangent_field(tangent(x), f)) + end +end + +_get_tangent_field(f::Union{NamedTuple,Tuple}, name) = getfield(f, name) +_get_tangent_field(f::Union{NamedTuple,Tuple}, name, inbounds) = getfield(f, name, inbounds) +_get_tangent_field(f::Union{Tangent,MutableTangent}, name) = val(getfield(f.fields, name)) +function _get_tangent_field(f::Union{Tangent,MutableTangent}, name, inbounds) + return val(getfield(f.fields, name, inbounds)) +end + @inline function rrule!!( ::CoDual{typeof(lgetfield)}, x::CoDual{P,F}, ::CoDual{Val{f}} ) where {P,F<:StandardFDataType,f} @@ -100,6 +118,19 @@ end # code duplication, but it wound up not being any cleaner than this copy + pasted version. 
@is_primitive MinimalCtx Tuple{typeof(lgetfield),Any,Val,Val} +@inline function frule!!( + ::Dual{typeof(lgetfield)}, + x::Dual{P,<:StandardTangentType}, + ::Dual{Val{f}}, + ::Dual{Val{order}}, +) where {P,f,order} + primal_field = getfield(primal(x), f, order) + if tangent_type(P) === NoTangent + return uninit_dual(primal_field) + else + return Dual(primal_field, _get_tangent_field(tangent(x), f)) + end +end @inline function rrule!!( ::CoDual{typeof(lgetfield)}, x::CoDual{P,F}, ::CoDual{Val{f}}, ::CoDual{Val{order}} ) where {P,F<:StandardFDataType,f,order} @@ -121,12 +152,23 @@ end end @is_primitive MinimalCtx Tuple{typeof(lsetfield!),Any,Any,Any} +@inline function frule!!( + ::Dual{typeof(lsetfield!)}, value::Dual{P,T}, name::Dual, x::Dual +) where {P,T<:StandardTangentType} + return lsetfield_frule(value, name, x) +end @inline function rrule!!( ::CoDual{typeof(lsetfield!)}, value::CoDual{P,F}, name::CoDual, x::CoDual ) where {P,F<:StandardFDataType} return lsetfield_rrule(value, name, x) end +function lsetfield_frule(value::Dual{P,T}, ::Dual{Val{name}}, x::Dual) where {P,T,name} + setfield!(primal(value), name, primal(x)) + T !== NoTangent && set_tangent_field!(tangent(value), name, tangent(x)) + return x +end + function lsetfield_rrule( value::CoDual{P,F}, ::CoDual{Val{name}}, x::CoDual ) where {P,F,name} @@ -162,6 +204,9 @@ end @static if VERSION < v"1.11" @is_primitive MinimalCtx Tuple{typeof(copy),Dict} + function frule!!(::Dual{typeof(copy)}, a::Dual{<:Dict}) + return Dual(copy(primal(a)), _copy_dict_tangent(tangent(a))) + end function rrule!!(::CoDual{typeof(copy)}, a::CoDual{<:Dict}) dx = tangent(a) t = dx.fields diff --git a/src/rrules/misty_closures.jl b/src/rrules/misty_closures.jl new file mode 100644 index 0000000000..d5a3ea4bf4 --- /dev/null +++ b/src/rrules/misty_closures.jl @@ -0,0 +1,160 @@ +""" + MistyClosureTangent(captures_tangent::Any, dual_callable::Any) + +The tangent type for `MistyClosure`. 
`captures_tangent` contains the tangent to the captured +variables, and `dual_callable` contains a callable object which performs forwards-mode AD. + +That the field type of `captures_tangent` is `Any` is unavoidable since the `captures` +field of an `OpaqueClosure` has field type `Any`. + +That the field type of `dual_callable` is `Any` is a limitation of the current +implementation. The concrete type of `dual_callable` might be one of a couple of things, +notably either `typeof(frule!!)` or `DerivedFRule`. It might be possible to figure out which +it is, and use this information to improve the type stability of this function. +""" +struct MistyClosureTangent + captures_tangent::Any + dual_callable::Any +end + +_dual_mc(p::MistyClosure) = build_frule(get_interpreter(ForwardMode), p) + +tangent_type(::Type{<:MistyClosure}) = MistyClosureTangent + +function zero_tangent_internal(p::MistyClosure, d::MaybeCache) + return MistyClosureTangent(zero_tangent_internal(p.oc.captures, d), _dual_mc(p)) +end + +function randn_tangent_internal(rng::AbstractRNG, p::MistyClosure, d::MaybeCache) + return MistyClosureTangent(randn_tangent_internal(rng, p.oc.captures, d), _dual_mc(p)) +end + +function increment_internal!!(c::IncCache, t::T, s::T) where {T<:MistyClosureTangent} + new_captures_tangent = increment_internal!!(c, t.captures_tangent, s.captures_tangent) + return MistyClosureTangent(new_captures_tangent, t.dual_callable) +end + +function set_to_zero_internal!!(c::SetToZeroCache, t::MistyClosureTangent) + new_captures_tangent = set_to_zero_internal!!(c, t.captures_tangent) + return MistyClosureTangent(new_captures_tangent, t.dual_callable) +end + +function _add_to_primal_internal( + c::MaybeCache, p::MistyClosure, t::MistyClosureTangent, unsafe::Bool +) + new_captures = _add_to_primal_internal(c, p.oc.captures, t.captures_tangent, unsafe) + return replace_captures(p, new_captures) +end + +function _diff_internal(c::MaybeCache, p::P, q::P) where {P<:MistyClosure} + # Just 
assumes that the code associated to `p` is the same as that of `q`. + captures_tangent = _diff_internal(c, p.oc.captures, q.oc.captures) + return MistyClosureTangent(captures_tangent, _dual_mc(p)) +end + +function _dot_internal(c::MaybeCache, t::T, s::T) where {T<:MistyClosureTangent} + return _dot_internal(c, t.captures_tangent, s.captures_tangent) +end + +function _scale_internal(c::MaybeCache, a::Float64, t::T) where {T<:MistyClosureTangent} + captures_tangent = _scale_internal(c, a, t.captures_tangent) + return T(captures_tangent, t.dual_callable) +end + +import .TestUtils: populate_address_map_internal, AddressMap +function populate_address_map_internal( + m::AddressMap, p::MistyClosure, t::MistyClosureTangent +) + return populate_address_map_internal(m, p.oc.captures, t.captures_tangent) +end + +struct MistyClosureFData + captures_fdata::Any + dual_callable::Any +end + +struct MistyClosureRData{Tr} + captures_rdata::Tr +end + +_copy(r::MistyClosureRData) = MistyClosureRData(deepcopy(r.captures_rdata)) + +fdata_type(::Type{<:MistyClosureTangent}) = MistyClosureFData +function fdata(t::MistyClosureTangent) + return MistyClosureFData(fdata(t.captures_tangent), t.dual_callable) +end + +rdata_type(::Type{<:MistyClosureTangent}) = MistyClosureRData +rdata(t::MistyClosureTangent) = MistyClosureRData(rdata(t.captures_tangent)) + +@foldable function tangent_type(::Type{<:MistyClosureFData}, ::Type{<:MistyClosureRData}) + return MistyClosureTangent +end +function tangent(f::MistyClosureFData, r::MistyClosureRData) + return MistyClosureTangent(tangent(f.captures_fdata, r.captures_rdata), f.dual_callable) +end + +function __verify_fdata_value(::IdDict{Any,Nothing}, p::MistyClosure, t::MistyClosureFData) + return nothing +end +_verify_rdata_value(p::MistyClosure, r::MistyClosureRData) = nothing + +zero_rdata(p::MistyClosure) = MistyClosureRData(zero_rdata(p.oc.captures)) + +function increment!!(x::MistyClosureFData, y::MistyClosureFData) + return MistyClosureFData( + 
increment!!(x.captures_fdata, y.captures_fdata), x.dual_callable
+    )
+end
+
+function increment_internal!!(c::IncCache, x::MistyClosureRData, y::MistyClosureRData)
+    return MistyClosureRData(increment_internal!!(c, x.captures_rdata, y.captures_rdata))
+end
+
+function rrule!!(
+    ::CoDual{typeof(lgetfield)}, x::CoDual{P,F}, ::CoDual{Val{f}}
+) where {P<:MistyClosure,F<:MistyClosureFData,f}
+    misty_closure_getfield_rrule_exception()
+end
+
+function rrule!!(
+    ::CoDual{typeof(lgetfield)}, x::CoDual{P,F}, ::CoDual{Val{f}}, ::CoDual{Val{order}}
+) where {P<:MistyClosure,F<:MistyClosureFData,f,order}
+    misty_closure_getfield_rrule_exception()
+end
+
+function misty_closure_getfield_rrule_exception()
+    msg =
+        "rrule!! for `lgetfield` and `getfield` not implemented for " *
+        "`MistyClosure`s. That is, you cannot currently query a field of a " *
+        "`MistyClosure` in code which you differentiate. If this is a " *
+        "problem for your use-case, please open an issue on the Mooncake.jl " *
+        "repository."
+    throw(UnhandledLanguageFeatureException(msg))
+end
+
+function rrule!!(::CoDual{typeof(_new_)}, p::CoDual{<:MistyClosure}, x::Vararg{CoDual})
+    misty_closure_new_rrule_exception()
+end
+
+function misty_closure_new_rrule_exception()
+    msg =
+        "rrule!! for `_new_` not implemented for `MistyClosure`. That is, " *
+        "you cannot currently construct a `MistyClosure` in code that you " *
+        "differentiate. If this is a problem for your use-case, please open " *
+        "an issue on the Mooncake.jl repository."
+    throw(UnhandledLanguageFeatureException(msg))
+end
+
+@is_primitive MinimalCtx Tuple{MistyClosure,Vararg{Any,N}} where {N}
+function frule!!(f::Dual{<:MistyClosure}, x::Dual...)
+    dual_captures = Dual(primal(f).oc.captures, tangent(f).captures_tangent)
+    return tangent(f).dual_callable(dual_captures, x...)
+end
+function rrule!!(f::CoDual{<:MistyClosure}, x::CoDual...)
+    msg =
+        "Attempted to compute the adjoint associated to a `MistyClosure`. 
" * + "This is not currently supported. Please open an issue if you need " * + "this functionality." + throw(ArgumentError(msg)) +end diff --git a/src/rrules/new.jl b/src/rrules/new.jl index de141bd8e9..fee1a9be9c 100644 --- a/src/rrules/new.jl +++ b/src/rrules/new.jl @@ -1,5 +1,16 @@ @is_primitive MinimalCtx Tuple{typeof(_new_),Vararg} +function frule!!(f::Dual{typeof(_new_)}, p::Dual{Type{P}}, x::Vararg{Dual,N}) where {P,N} + y = _new_(P, tuple_map(primal, x)...) + T = tangent_type(P) + dy = if T == NoTangent + NoTangent() + else + build_output_tangent(P, tuple_map(primal, x), tuple_map(tangent, x)) + end + return Dual(y, dy) +end + function rrule!!( f::CoDual{typeof(_new_)}, p::CoDual{Type{P}}, x::Vararg{CoDual,N} ) where {P,N} @@ -33,6 +44,21 @@ function rrule!!( return CoDual(y, dy), pb!! end +@generated function build_output_tangent(::Type{P}, x::Tuple, t::Tuple) where {P} + names = fieldnames(P) + tangent_exprs = map(eachindex(names)) do n + F = tangent_field_types(P)[n] + if n <= length(t.parameters) + data_expr = Expr(:call, __get_data, P, :x, :t, n) + return F <: PossiblyUninitTangent ? 
Expr(:call, F, data_expr) : data_expr + else + return :($F()) + end + end + T_out = tangent_type(P) + return :($T_out(NamedTuple{$names}($(Expr(:call, tuple, tangent_exprs...))))) +end + @inline function build_fdata(::Type{P}, x::Tuple, fdata::Tuple) where {P} return _build_fdata_cartesian(P, x, fdata, Val(fieldcount(P)), Val(fieldnames(P))) end diff --git a/src/rrules/performance_patches.jl b/src/rrules/performance_patches.jl index ce26596919..4c59cacd92 100644 --- a/src/rrules/performance_patches.jl +++ b/src/rrules/performance_patches.jl @@ -17,6 +17,9 @@ # Performance issue: https://github.com/chalk-lab/Mooncake.jl/issues/156 @is_primitive(DefaultCtx, Tuple{typeof(sum),Array{<:IEEEFloat}}) +function frule!!(::Dual{typeof(sum)}, x::Dual{<:Array{P}}) where {P<:IEEEFloat} + return Dual(sum(primal(x)), sum(tangent(x))) +end function rrule!!(::CoDual{typeof(sum)}, x::CoDual{<:Array{P}}) where {P<:IEEEFloat} dx = x.dx function sum_pb!!(dz::P) @@ -28,6 +31,11 @@ end # Performance issue: https://github.com/chalk-lab/Mooncake.jl/issues/156 @is_primitive(DefaultCtx, Tuple{typeof(sum),typeof(abs2),Array{<:IEEEFloat}}) +function frule!!( + ::Dual{typeof(sum)}, ::Dual{typeof(abs2)}, x::Dual{<:Array{P}} +) where {P<:IEEEFloat} + return Dual(sum(abs2, primal(x)), 2 * dot(primal(x), tangent(x))) +end function rrule!!( ::CoDual{typeof(sum)}, ::CoDual{typeof(abs2)}, x::CoDual{<:Array{P}} ) where {P<:IEEEFloat} diff --git a/src/rrules/random.jl b/src/rrules/random.jl index 5df9a6c43d..3c8e10cf47 100644 --- a/src/rrules/random.jl +++ b/src/rrules/random.jl @@ -1,15 +1,22 @@ # Contains a ccall, which must be avoided. 
-@zero_adjoint MinimalCtx Tuple{Type{MersenneTwister},Any} +@zero_derivative MinimalCtx Tuple{Type{MersenneTwister},Any} const KnownRNGs = Union{MersenneTwister,RandomDevice,TaskLocalRNG,Xoshiro} -@zero_adjoint MinimalCtx Tuple{typeof(randn),KnownRNGs} -@zero_adjoint MinimalCtx Tuple{typeof(randexp),KnownRNGs} -@zero_adjoint MinimalCtx Tuple{typeof(randn),KnownRNGs,Type{<:IEEEFloat}} -@zero_adjoint MinimalCtx Tuple{typeof(randexp),KnownRNGs,Type{<:IEEEFloat}} +@zero_derivative MinimalCtx Tuple{typeof(randn),KnownRNGs} +@zero_derivative MinimalCtx Tuple{typeof(randexp),KnownRNGs} +@zero_derivative MinimalCtx Tuple{typeof(randn),KnownRNGs,Type{<:IEEEFloat}} +@zero_derivative MinimalCtx Tuple{typeof(randexp),KnownRNGs,Type{<:IEEEFloat}} const SpecialisedRNGs = Union{MersenneTwister,TaskLocalRNG,Xoshiro} for f in [randn!, randexp!] @eval @is_primitive MinimalCtx Tuple{typeof($f),SpecialisedRNGs,Array{Float64}} + @eval function frule!!( + ::Dual{typeof($f)}, rng::Dual{<:SpecialisedRNGs}, x::Dual{<:Array{Float64}} + ) + $f(primal(rng), primal(x)) + tangent(x) .= 0 + return x + end @eval function rrule!!( ::CoDual{typeof($f)}, rng::CoDual{<:SpecialisedRNGs}, x::CoDual{<:Array{Float64}} ) diff --git a/src/rrules/tasks.jl b/src/rrules/tasks.jl index 42b29ca8c9..88047576f8 100644 --- a/src/rrules/tasks.jl +++ b/src/rrules/tasks.jl @@ -49,13 +49,21 @@ rdata_type(::Type{TaskTangent}) = NoRData tangent(t::TaskTangent, ::NoRData) = t +@inline function _get_tangent_field(t::TaskTangent, f) + f === :rngState0 && return NoTangent() + f === :rngState1 && return NoTangent() + f === :rngState2 && return NoTangent() + f === :rngState3 && return NoTangent() + f === :rngState4 && return NoTangent() + return error("Unhandled field $f") +end @inline function _get_fdata_field(_, t::TaskTangent, f) f === :rngState0 && return NoFData() f === :rngState1 && return NoFData() f === :rngState2 && return NoFData() f === :rngState3 && return NoFData() f === :rngState4 && return NoFData() - 
throw(error("Unhandled field $f")) + return error("Unhandled field $f") end @inline increment_field_rdata!(::TaskTangent, ::NoRData, ::Val) = nothing @@ -69,8 +77,12 @@ function get_tangent_field(t::TaskTangent, f) throw(error("Unhandled field $f")) end +const TaskDual = Dual{Task,TaskTangent} const TaskCoDual = CoDual{Task,TaskTangent} +function frule!!(::Dual{typeof(lgetfield)}, x::TaskDual, ::Dual{Val{f}}) where {f} + return Dual(getfield(primal(x), f), _get_tangent_field(tangent(x), f)) +end function rrule!!(::CoDual{typeof(lgetfield)}, x::TaskCoDual, ::CoDual{Val{f}}) where {f} dx = x.dx function mutable_lgetfield_pb!!(dy) @@ -81,17 +93,23 @@ function rrule!!(::CoDual{typeof(lgetfield)}, x::TaskCoDual, ::CoDual{Val{f}}) w return y, mutable_lgetfield_pb!! end +function frule!!(::Dual{typeof(getfield)}, x::TaskDual, f::Dual) + return Dual(getfield(primal(x), primal(f)), _get_tangent_field(tangent(x), primal(f))) +end function rrule!!(::CoDual{typeof(getfield)}, x::TaskCoDual, f::CoDual) return rrule!!(zero_fcodual(lgetfield), x, zero_fcodual(Val(primal(f)))) end +function frule!!(::Dual{typeof(lsetfield!)}, task::TaskDual, name::Dual, val::Dual) + return lsetfield_frule(task, name, val) +end function rrule!!(::CoDual{typeof(lsetfield!)}, task::TaskCoDual, name::CoDual, val::CoDual) return lsetfield_rrule(task, name, val) end set_tangent_field!(t::TaskTangent, f, ::NoTangent) = NoTangent() -@zero_adjoint MinimalCtx Tuple{typeof(current_task)} +@zero_derivative MinimalCtx Tuple{typeof(current_task)} __verify_fdata_value(::IdDict{Any,Nothing}, ::Task, ::TaskTangent) = nothing diff --git a/src/rrules/twice_precision.jl b/src/rrules/twice_precision.jl index 40c6572e81..ffb7fad0f9 100644 --- a/src/rrules/twice_precision.jl +++ b/src/rrules/twice_precision.jl @@ -62,6 +62,13 @@ zero_rdata_from_type(P::Type{<:TWP{F}}) where {F} = P(zero(F), zero(F)) # @is_primitive MinimalCtx Tuple{typeof(_new_),<:TWP,IEEEFloat,IEEEFloat} +function frule!!( + ::Dual{typeof(_new_)}, 
::Dual{Type{TWP{P}}}, hi::Dual{P}, lo::Dual{P} +) where {P<:IEEEFloat} + x = _new_(TWP{P}, primal(hi), primal(lo)) + dx = _new_(TWP{P}, tangent(hi), tangent(lo)) + return Dual(x, dx) +end function rrule!!( ::CoDual{typeof(_new_)}, ::CoDual{Type{TWP{P}}}, hi::CoDual{P}, lo::CoDual{P} ) where {P<:IEEEFloat} @@ -70,6 +77,13 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(twiceprecision),IEEEFloat,Integer} +function frule!!( + ::Dual{typeof(twiceprecision)}, val::Dual{P}, nb::Dual{<:Integer} +) where {P<:IEEEFloat} + x = twiceprecision(primal(val), primal(nb)) + dx = twiceprecision(tangent(val), primal(nb)) + return Dual(x, dx) +end function rrule!!( ::CoDual{typeof(twiceprecision)}, val::CoDual{P}, nb::CoDual{<:Integer} ) where {P<:IEEEFloat} @@ -78,6 +92,13 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(twiceprecision),TWP,Integer} +function frule!!( + ::Dual{typeof(twiceprecision)}, val::Dual{P}, nb::Dual{<:Integer} +) where {P<:TWP} + x = twiceprecision(primal(val), primal(nb)) + dx = twiceprecision(tangent(val), primal(nb)) + return Dual(x, dx) +end function rrule!!( ::CoDual{typeof(twiceprecision)}, val::CoDual{P}, nb::CoDual{<:Integer} ) where {P<:TWP} @@ -86,18 +107,25 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{Type{<:IEEEFloat},TWP} +function frule!!(::Dual{Type{P}}, x::Dual{S}) where {P<:IEEEFloat,S<:TWP} + return Dual(P(primal(x)), P(tangent(x))) +end function rrule!!(::CoDual{Type{P}}, x::CoDual{S}) where {P<:IEEEFloat,S<:TWP} float_from_twice_precision_pb(dy::P) = NoRData(), S(dy) return zero_fcodual(P(x.x)), float_from_twice_precision_pb end @is_primitive MinimalCtx Tuple{typeof(-),TWP} +frule!!(::Dual{typeof(-)}, x::Dual{P}) where {P<:TWP} = Dual(-primal(x), -tangent(x)) function rrule!!(::CoDual{typeof(-)}, x::CoDual{P}) where {P<:TWP} negate_twice_precision_pb(dy::P) = NoRData(), -dy return zero_fcodual(-(x.x)), negate_twice_precision_pb end @is_primitive MinimalCtx Tuple{typeof(+),TWP,IEEEFloat} +function 
frule!!(::Dual{typeof(+)}, x::Dual{P}, y::Dual{S}) where {P<:TWP,S<:IEEEFloat} + return Dual(primal(x) + primal(y), tangent(x) + tangent(y)) +end function rrule!!( ::CoDual{typeof(+)}, x::CoDual{P}, y::CoDual{S} ) where {P<:TWP,S<:IEEEFloat} @@ -106,12 +134,20 @@ function rrule!!( end @is_primitive(MinimalCtx, Tuple{typeof(+),P,P} where {P<:TWP}) +function frule!!(::Dual{typeof(+)}, x::Dual{P}, y::Dual{P}) where {P<:TWP} + return Dual(primal(x) + primal(y), tangent(x) + tangent(y)) +end function rrule!!(::CoDual{typeof(+)}, x::CoDual{P}, y::CoDual{P}) where {P<:TWP} plus_pullback(dz::P) = NoRData(), dz, dz return zero_fcodual(x.x + y.x), plus_pullback end @is_primitive MinimalCtx Tuple{typeof(*),TWP,IEEEFloat} +function frule!!(::Dual{typeof(*)}, x::Dual{P}, y::Dual{S}) where {P<:TWP,S<:IEEEFloat} + z = primal(x) * primal(y) + dz = primal(x) * tangent(y) + tangent(x) * primal(y) + return Dual(z, dz) +end function rrule!!( ::CoDual{typeof(*)}, x::CoDual{P}, y::CoDual{S} ) where {P<:TWP,S<:IEEEFloat} @@ -121,6 +157,9 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(*),TWP,Integer} +function frule!!(::Dual{typeof(*)}, x::Dual{P}, y::Dual{<:Integer}) where {P<:TWP} + return Dual(primal(x) * primal(y), tangent(x) * primal(y)) +end function rrule!!(::CoDual{typeof(*)}, x::CoDual{P}, y::CoDual{<:Integer}) where {P<:TWP} _y = y.x mul_twice_precision_and_int_pb(dz::P) = NoRData(), dz * _y, NoRData() @@ -128,6 +167,11 @@ function rrule!!(::CoDual{typeof(*)}, x::CoDual{P}, y::CoDual{<:Integer}) where end @is_primitive MinimalCtx Tuple{typeof(/),TWP,IEEEFloat} +function frule!!(::Dual{typeof(/)}, x::Dual{P}, y::Dual{S}) where {P<:TWP,S<:IEEEFloat} + z = primal(x) / primal(y) + dz = tangent(x) / primal(y) - tangent(y) * primal(x) / primal(y)^2 + return Dual(z, dz) +end function rrule!!( ::CoDual{typeof(/)}, x::CoDual{P}, y::CoDual{S} ) where {P<:TWP,S<:IEEEFloat} @@ -137,6 +181,9 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(/),TWP,Integer} 
+function frule!!(::Dual{typeof(/)}, x::Dual{P}, y::Dual{<:Integer}) where {P<:TWP} + return Dual(primal(x) / primal(y), tangent(x) / primal(y)) +end function rrule!!(::CoDual{typeof(/)}, x::CoDual{P}, y::CoDual{<:Integer}) where {P<:TWP} _y = y.x div_twice_precision_and_int_pb(dz::P) = NoRData(), dz / _y, NoRData() @@ -145,13 +192,13 @@ end # Primitives -@zero_adjoint MinimalCtx Tuple{Type{<:TwicePrecision},Tuple{Integer,Integer},Integer} -@zero_adjoint MinimalCtx Tuple{typeof(Base.splitprec),Type,Integer} -@zero_adjoint( +@zero_derivative MinimalCtx Tuple{Type{<:TwicePrecision},Tuple{Integer,Integer},Integer} +@zero_derivative MinimalCtx Tuple{typeof(Base.splitprec),Type,Integer} +@zero_derivative( MinimalCtx, Tuple{typeof(Base.floatrange),Type{<:IEEEFloat},Integer,Integer,Integer,Integer}, ) -@zero_adjoint( +@zero_derivative( MinimalCtx, Tuple{typeof(Base._linspace),Type{<:IEEEFloat},Integer,Integer,Integer,Integer}, ) @@ -160,6 +207,14 @@ using Base: range_start_step_length @is_primitive( MinimalCtx, Tuple{typeof(range_start_step_length),T,T,Integer} where {T<:IEEEFloat} ) +function frule!!( + ::Dual{typeof(range_start_step_length)}, a::Dual{T}, st::Dual{T}, len::Dual{<:Integer} +) where {T<:IEEEFloat} + x = range_start_step_length(primal(a), primal(st), primal(len)) + Tx = tangent_type(typeof(x)) + dx = Tx((ref=tangent(a), step=tangent(st), len=NoTangent(), offset=NoTangent())) + return Dual(x, dx) +end function rrule!!( ::CoDual{typeof(range_start_step_length)}, a::CoDual{T}, @@ -173,6 +228,15 @@ end using Base: unsafe_getindex const TWPStepRangeLen = StepRangeLen{<:Any,<:TWP,<:TWP} @is_primitive(MinimalCtx, Tuple{typeof(unsafe_getindex),TWPStepRangeLen,Integer}) +function frule!!( + ::Dual{typeof(unsafe_getindex)}, r::Dual{P}, i::Dual{<:Integer} +) where {P<:TWPStepRangeLen} + x = unsafe_getindex(primal(r), primal(i)) + dref = _get_tangent_field(tangent(r), :ref) + dstep = _get_tangent_field(tangent(r), :step) + dx = eltype(P)(dref + dstep * (primal(i) - 
primal(r).offset)) + return Dual(x, dx) +end function rrule!!( ::CoDual{typeof(unsafe_getindex)}, r::CoDual{P}, i::CoDual{<:Integer} ) where {P<:TWPStepRangeLen} @@ -190,6 +254,16 @@ end using Base: _getindex_hiprec @is_primitive(MinimalCtx, Tuple{typeof(_getindex_hiprec),TWPStepRangeLen,Integer}) +function frule!!( + ::Dual{typeof(_getindex_hiprec)}, r::Dual{P}, i::Dual{<:Integer} +) where {P<:TWPStepRangeLen} + x = _getindex_hiprec(primal(r), primal(i)) + offset = primal(r).offset + dstep = _get_tangent_field(tangent(r), :step) + dref = _get_tangent_field(tangent(r), :ref) + dx = (primal(i) - offset) * dstep + dref + return Dual(x, dx) +end function rrule!!( ::CoDual{typeof(_getindex_hiprec)}, r::CoDual{P}, i::CoDual{<:Integer} ) where {P<:TWPStepRangeLen} @@ -205,6 +279,14 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(:),P,P,P} where {P<:IEEEFloat} +function frule!!( + ::Dual{typeof(:)}, start::Dual{P}, step::Dual{P}, stop::Dual{P} +) where {P<:IEEEFloat} + x = (:)(primal(start), primal(step), primal(stop)) + T = tangent_type(typeof(x)) + dx = T((ref=tangent(start), step=tangent(step), len=NoTangent(), offset=NoTangent())) + return Dual(x, dx) +end function rrule!!( ::CoDual{typeof(:)}, start::CoDual{P}, step::CoDual{P}, stop::CoDual{P} ) where {P<:IEEEFloat} @@ -213,6 +295,15 @@ function rrule!!( end @is_primitive MinimalCtx Tuple{typeof(sum),TWPStepRangeLen} +function frule!!(::Dual{typeof(sum)}, x::Dual{P}) where {P<:TWPStepRangeLen} + y = sum(primal(x)) + l = primal(x).len + offset = primal(x).offset + dref = _get_tangent_field(tangent(x), :ref) + dstep = _get_tangent_field(tangent(x), :step) + dy = dref * l + dstep * (0.5 * l * (l + 1) - l * offset) + return Dual(y, typeof(y)(dy)) +end function rrule!!(::CoDual{typeof(sum)}, x::CoDual{P}) where {P<:TWPStepRangeLen} l = x.x.len offset = x.x.offset @@ -230,6 +321,20 @@ end MinimalCtx, Tuple{typeof(Base.range_start_stop_length),P,P,Integer} where {P<:IEEEFloat}, ) +function frule!!( + 
::Dual{typeof(Base.range_start_stop_length)}, + start::Dual{P}, + stop::Dual{P}, + length::Dual{<:Integer}, +) where {P<:IEEEFloat} + l = primal(length) - 1 + y = Base.range_start_stop_length(primal(start), primal(stop), primal(length)) + T = tangent_type(typeof(y)) + dref = tangent(start) + dstep = (tangent(stop) - tangent(start)) / l + dy = T((ref=dref, step=dstep, len=NoTangent(), offset=NoTangent())) + return Dual(y, dy) +end function rrule!!( ::CoDual{typeof(Base.range_start_stop_length)}, start::CoDual{P}, @@ -250,6 +355,12 @@ end @is_primitive MinimalCtx Tuple{ typeof(Base._exp_allowing_twice64),TwicePrecision{Float64} } + function frule!!( + ::Dual{typeof(Base._exp_allowing_twice64)}, x::Dual{TwicePrecision{Float64}} + ) + y = Base._exp_allowing_twice64(primal(x)) + return Dual(y, typeof(y)(y * tangent(x))) + end function rrule!!( ::CoDual{typeof(Base._exp_allowing_twice64)}, x::CoDual{TwicePrecision{Float64}} ) @@ -259,6 +370,10 @@ end end @is_primitive(MinimalCtx, Tuple{typeof(Base._log_twice64_unchecked),Float64}) + function frule!!(::Dual{typeof(Base._log_twice64_unchecked)}, x::Dual{Float64}) + y = Base._log_twice64_unchecked(primal(x)) + return Dual(y, typeof(y)(tangent(x) / primal(x))) + end function rrule!!(::CoDual{typeof(Base._log_twice64_unchecked)}, x::CoDual{Float64}) _x = x.x _log_twice64_pb(dy::TwicePrecision{Float64}) = NoRData(), Float64(dy) / _x @@ -299,16 +414,8 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:twice_precisi (false, :stability_and_allocs, nothing, Base.splitprec, Float16, 5), (false, :stability_and_allocs, nothing, Base.floatrange, Float64, 5, 6, 7, 8), (false, :stability_and_allocs, nothing, Base._linspace, Float64, 5, 6, 7, 8), - (false, :stability_and_allocs, nothing, Base.range_start_step_length, 5.0, 6.0, 10), - ( - false, - :stability_and_allocs, - nothing, - Base.range_start_step_length, - 5.0, - Float64(π), - 10, - ), + (false, :allocs, nothing, Base.range_start_step_length, 5.0, 6.0, 10), + 
(false, :allocs, nothing, Base.range_start_step_length, 5.0, Float64(π), 10), ( false, :stability_and_allocs, @@ -325,26 +432,10 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:twice_precisi StepRangeLen(TwicePrecision(-0.45), TwicePrecision(0.98), 10, 3), 5, ), - (false, :stability_and_allocs, nothing, (:), -0.1, 0.99, 5.1), + (false, :allocs, nothing, (:), -0.1, 0.99, 5.1), (false, :stability_and_allocs, nothing, sum, range(-0.1, 9.9; length=51)), - ( - false, - :stability_and_allocs, - nothing, - Base.range_start_stop_length, - -0.5, - 11.7, - 7, - ), - ( - false, - :stability_and_allocs, - nothing, - Base.range_start_stop_length, - -0.5, - -11.7, - 11, - ), + (false, :allocs, nothing, Base.range_start_stop_length, -0.5, 11.7, 7), + (false, :allocs, nothing, Base.range_start_stop_length, -0.5, -11.7, 11), ] @static if VERSION >= v"1.11" extra_test_cases = Any[ diff --git a/src/test_resources.jl b/src/test_resources.jl index c5888a8e9a..541fa78dab 100644 --- a/src/test_resources.jl +++ b/src/test_resources.jl @@ -10,6 +10,7 @@ module TestResources using ..Mooncake using ..Mooncake: CoDual, + Dual, Tangent, MutableTangent, NoTangent, @@ -576,6 +577,9 @@ end @noinline edge_case_tester(x::Int) = 10 @noinline edge_case_tester(x::String) = "hi" @is_primitive MinimalCtx Tuple{typeof(edge_case_tester),Float64} +function Mooncake.frule!!(::Dual{typeof(edge_case_tester)}, x::Dual{Float64}) + return Dual(5 * primal(x), 5 * tangent(x)) +end function Mooncake.rrule!!(::CoDual{typeof(edge_case_tester)}, x::CoDual{Float64}) edge_case_tester_pb!!(dy) = Mooncake.NoRData(), 5 * dy return Mooncake.zero_fcodual(5 * primal(x)), edge_case_tester_pb!! 
diff --git a/src/test_utils.jl b/src/test_utils.jl index 0e88b3b8f9..6802424343 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -11,18 +11,7 @@ module TestTypes using Base.Iterators: product using Core: svec using ExprTools: combinedef -using ..Mooncake: - NoTangent, - tangent_type, - _typeof, - set_to_zero!!, - increment!!, - is_primitive, - randn_tangent, - _scale, - _add_to_primal, - _diff, - _dot +using ..Mooncake: NoTangent const PRIMALS = Tuple{Bool,Any,Tuple}[] @@ -105,6 +94,7 @@ using Mooncake: PossiblyUninitTangent, Tangent, MutableTangent, + frule!!, rrule!!, build_rrule, tangent_type, @@ -114,7 +104,6 @@ using Mooncake: is_init, zero_codual, DefaultCtx, - @is_primitive, val, is_always_fully_initialised, get_tangent_field, @@ -133,6 +122,8 @@ using Mooncake: instantiate, can_produce_zero_rdata_from_type, increment_rdata!!, + dual_type, + randn_dual, fcodual_type, verify_fdata_type, verify_rdata_type, @@ -148,7 +139,6 @@ using Mooncake: CC, set_to_zero!!, increment!!, - is_primitive, randn_tangent, _scale, _add_to_primal, @@ -159,7 +149,15 @@ using Mooncake: fdata, NoRData, rdata_type, - rdata + rdata, + Dual, + Mode, + ForwardMode, + ReverseMode, + DebugRRule, + build_frule, + build_rrule, + get_interpreter struct Shim end @@ -433,7 +431,81 @@ function address_maps_are_consistent(x::AddressMap, y::AddressMap) end # Assumes that the interface has been tested, and we can simply check for numerical issues. -function test_rule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb::Bool) +function test_frule_correctness(rng::AbstractRNG, x_ẋ...; frule, unsafe_perturb::Bool) + @nospecialize rng x_ẋ + + x_ẋ = map(_deepcopy, x_ẋ) # defensive copy + + # Run original function on deep-copies of inputs. + x = map(primal, x_ẋ) + ẋ = map(tangent, x_ẋ) + x_primal = _deepcopy(x) + y_primal = x_primal[1](x_primal[2:end]...) + + # Use finite differences to estimate Frechet derivative. Compute the estimate at a range + # of different step sizes. 
We'll just require that one of them ends up being close to + what AD gives. + ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] + fd_results = Vector{Any}(undef, length(ε_list)) + for (n, ε) in enumerate(ε_list) + x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) + y′_l = x′_l[1](x′_l[2:end]...) + x′_r = _add_to_primal(x, _scale(-ε, ẋ), unsafe_perturb) + y′_r = x′_r[1](x′_r[2:end]...) + fd_results[n] = ( + ẏ=_scale(1 / 2ε, _diff(y′_l, y′_r)), + ẋ=map((_x′, _x_p) -> _scale(1 / 2ε, _diff(_x′, _x_p)), x′_l, x′_r), + ) + end + + # Use AD to compute Frechet derivative at ẋ. + x_ẋ_rule = map((x, ẋ) -> dual_type(_typeof(x))(_deepcopy(x), ẋ), x, ẋ) + inputs_address_map = populate_address_map( + map(primal, x_ẋ_rule), map(tangent, x_ẋ_rule) + ) + y_ẏ_rule = frule(x_ẋ_rule...) + ẋ_ad = map(tangent, x_ẋ_rule) + ẏ_ad = tangent(y_ẏ_rule) + + # Verify that inputs / outputs are the same under `f` and its frule. + @test has_equal_data(x_primal, map(primal, x_ẋ_rule)) + @test has_equal_data(y_primal, primal(y_ẏ_rule)) + + # Query both `x_ẋ` and `y`, because `x_ẋ` may have been mutated by `f`. + outputs_address_map = populate_address_map( + (map(primal, x_ẋ_rule)..., primal(y_ẏ_rule)), + (map(tangent, x_ẋ_rule)..., tangent(y_ẏ_rule)), + ) + + # Check that all aliasing structure is correct. + @test address_maps_are_consistent(inputs_address_map, outputs_address_map) + + # Any linear projection of the outputs ought to do. Require only one + # precision to be close to the answer AD gives. i.e. prove that there exists a step size + # such that AD and central differences agree on the answer. 
+ x̄ = map(Base.Fix1(randn_tangent, rng), x_primal) + ȳ = randn_tangent(rng, y_primal) + isapprox_results = map(fd_results) do result + ẏ_fd, ẋ_fd = result + return isapprox( + _dot(ȳ, ẏ_fd) + _dot(x̄, ẋ_fd), + _dot(ȳ, ẏ_ad) + _dot(x̄, ẋ_ad); + rtol=1e-3, + atol=1e-3, + ) + end + if !any(isapprox_results) + vals = map(fd_results) do result + ẏ_fd, ẋ_fd = result + (_dot(ȳ, ẏ_fd) + _dot(x̄, ẋ_fd), _dot(ȳ, ẏ_ad) + _dot(x̄, ẋ_ad)) + end + display(vals) + end + @test any(isapprox_results) +end + +# Assumes that the interface has been tested, and we can simply check for numerical issues. +function test_rrule_correctness(rng::AbstractRNG, x_x̄...; rrule, unsafe_perturb::Bool) @nospecialize rng x_x̄ x_x̄ = map(_deepcopy, x_x̄) # defensive copy @@ -446,9 +518,13 @@ function test_rule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb: x_primal = _deepcopy(x) y_primal = x_primal[1](x_primal[2:end]...) + # Construct tangent to inputs, and normalise to be of unit length. + ẋ_unnormalised = map(_x -> randn_tangent(rng, _x), x) + nrm = sqrt(sum(x -> _dot(x, x), ẋ_unnormalised)) + ẋ = map(_x -> _scale(1 / nrm, _x), ẋ_unnormalised) + # Use finite differences to estimate vjps. Compute the estimate at a range of different # step sizes. We'll just require that one of them ends up being close to what AD gives. - ẋ = map(_x -> randn_tangent(rng, _x), x) ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] fd_results = Vector{Any}(undef, length(ε_list)) for (n, ε) in enumerate(ε_list) @@ -470,7 +546,7 @@ function test_rule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb: inputs_address_map = populate_address_map( map(primal, x_x̄_rule), map(tangent, x_x̄_rule) ) - y_ȳ_rule, pb!! = rule(x_x̄_rule...) + y_ȳ_rule, pb!! = rrule(x_x̄_rule...) # Verify that inputs / outputs are the same under `f` and its rrule. 
@test has_equal_data(x_primal, map(primal, x_x̄_rule)) @@ -510,6 +586,14 @@ function test_rule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb: atol=1e-3, ) end + if !any(isapprox_results) + vals = map(fd_results) do result + ẏ, ẋ_post = result + (_dot(ȳ_delta, ẏ) + _dot(x̄_delta, ẋ_post), _dot(x̄, ẋ)) + end + display(vals) + println() + end + @test any(isapprox_results) end @@ -520,7 +604,42 @@ _deepcopy(x::Module) = x rrule_output_type(::Type{Ty}) where {Ty} = Tuple{Mooncake.fcodual_type(Ty),Any} -function test_rrule_interface(f_f̄, x_x̄...; rule) +function test_frule_interface(x_ẋ...; frule) + @nospecialize x_ẋ + + # Pull out primals and run primal computation. + x_ẋ = map(_deepcopy, x_ẋ) + x = map(primal, x_ẋ) + + # Run the primal programme. Bail out early if this doesn't work. + y = try + # Note: the function itself occasionally contains a `module`. Since `module`s + # cannot be `deepcopy`-ed, we do not do so. This will cause some trouble if the + # function mutates itself during execution. + x[1](deepcopy(x[2:end])...) + catch + throw(ArgumentError("Primal does not run, signature is $(_typeof(x_ẋ)).")) + end + + # Check that input types are valid. + for x_ẋ_component in x_ẋ + @test Mooncake.verify_dual_type(x_ẋ_component) + end + + # Run the frule, check it has output a thing of the correct type, and extract results. + # Throw a meaningful exception if the frule doesn't run at all. + y_ẏ = try + frule(x_ẋ...) + catch + throw(ArgumentError("frule does not run, signature is $(_typeof(x_ẋ)).")) + end + + # Check that the returned value is a `Dual` of valid type. + @test y_ẏ isa Dual + @test Mooncake.verify_dual_type(y_ẏ) +end + +function test_rrule_interface(f_f̄, x_x̄...; rrule) + @nospecialize f_f̄ x_x̄ # Pull out primals and run primal computation. @@ -532,6 +651,9 @@ function test_rrule_interface(f_f̄, x_x̄...; rule) # Run the primal programme. Bail out early if this doesn't work. y = try + # Note: the function itself occasionally contains a `module`. 
Since `module`s + # cannot be `deepcopy`-ed, we do not do so. This will cause some trouble if the + # function mutates itself during execution. f(deepcopy(x)...) catch e display(e) @@ -553,7 +675,7 @@ function test_rrule_interface(f_f̄, x_x̄...; rule) # Throw a meaningful exception if the rrule doesn't run at all. x_addresses = map(get_address, x) rrule_ret = try - rule(f_fwds, x_fwds...) + rrule(f_fwds, x_fwds...) catch e display(e) println() @@ -597,11 +719,52 @@ function test_rrule_interface(f_f̄, x_x̄...; rule) @test all(map((a, b) -> _typeof(a) == _typeof(rdata(b)), x̄_new, x̄)) end +__forwards(frule::F, x_ẋ::Vararg{Any,N}) where {F,N} = frule(x_ẋ...) + @noinline function __forwards_and_backwards(rule::R, x_x̄::Vararg{Any,N}) where {R,N} out, pb!! = rule(x_x̄...) return pb!!(Mooncake.zero_rdata(primal(out))) end +function test_frule_performance( + performance_checks_flag::Symbol, rule::R, f_ḟ::F, x_ẋ::Vararg{Any,N} +) where {R,F,N} + x_ẋ = _deepcopy(x_ẋ) + + # Verify that a valid performance flag has been passed. + valid_flags = (:none, :stability, :allocs, :stability_and_allocs) + if !in(performance_checks_flag, valid_flags) + throw( + ArgumentError( + "performance_checks=$performance_checks_flag. Must be one of $valid_flags" + ), + ) + end + performance_checks_flag == :none && return nothing + + if performance_checks_flag in (:stability, :stability_and_allocs) + + # Test primal stability. + test_opt(primal(f_ḟ), map(_typeof ∘ primal, x_ẋ)) + + # Test forwards-mode stability. + test_opt(rule, (_typeof(f_ḟ), map(_typeof, x_ẋ)...)) + end + + if performance_checks_flag in (:allocs, :stability_and_allocs) + f = primal(f_ḟ) + x = map(primal, x_ẋ) + + # Test allocations in primal. + f(x...) + @test (@allocations f(x...)) == 0 + + # Test allocations in forwards-mode. + __forwards(rule, f_ḟ, x_ẋ...) 
+ @test (@allocations __forwards(rule, f_ḟ, x_ẋ...)) == 0 + end +end + function test_rrule_performance( performance_checks_flag::Symbol, rule::R, f_f̄::F, x_x̄::Vararg{Any,N} ) where {R,F,N} @@ -648,17 +811,19 @@ function test_rrule_performance( end end -__get_primals(xs) = map(x -> x isa CoDual ? primal(x) : x, xs) +__get_primals(xs) = map(x -> x isa Union{Dual,CoDual} ? primal(x) : x, xs) """ test_rule( - rng, x...; - interface_only=false, + rng::AbstractRNG, + x...; + interface_only::Bool=false, is_primitive::Bool=true, perf_flag::Symbol=:none, - interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(), + mode::Union{Nothing,Type{ForwardMode},Type{ReverseMode}}=nothing, debug_mode::Bool=false, unsafe_perturb::Bool=false, + print_results=true, ) Run standardised tests on the `rule` for `x`. @@ -670,8 +835,8 @@ though, in partcular `Ptr`s. In this case, the argument for which `randn_tangent readily defined should be a `CoDual` containing the primal, and a _manually_ constructed tangent field. -This function uses [`Mooncake.build_rrule`](@ref) to construct a rule. This will use an -`rrule!!` if one exists, and derive a rule otherwise. +This function is intended for use with both hand-written rules and derived rules. If the +signature associated to `x` corresponds to a primitive, a hand-written rule will be used. # Arguments - `rng::AbstractRNG`: a random number generator @@ -699,8 +864,8 @@ This function uses [`Mooncake.build_rrule`](@ref) to construct a rule. This will this to `:stability` (at present we cannot verify whether a derived rule is type stable for technical reasons). If you believe that a hand-written rule should be _both_ allocation-free and type-stable, set this to `:stability_and_allocs`. -- `interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter()`: the abstract - interpreter to be used when testing this rule. The default should generally be used. 
+- `mode::Union{Nothing,Type{ForwardMode},Type{ReverseMode}}=nothing`: the mode of AD to + test. If `mode===nothing` (default), then both forward and reverse mode are tested. - `debug_mode::Bool=false`: whether or not the rule should be tested in debug mode. Typically this should be left at its default `false` value, but if you are finding that the tests are failing for a given rule, you may wish to temporarily set it to `true` in @@ -715,18 +880,30 @@ function test_rule( interface_only::Bool=false, is_primitive::Bool=true, perf_flag::Symbol=:none, - interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(), + mode::Union{Nothing,Type{ForwardMode},Type{ReverseMode}}=nothing, debug_mode::Bool=false, unsafe_perturb::Bool=false, + print_results=true, ) + # Take a copy of `x` to ensure that we do not mutate the original. + x = deepcopy(x) + # Construct the rule. sig = _typeof(__get_primals(x)) - rule = Mooncake.build_rrule(interp, sig; debug_mode) + test_fwd = mode in [nothing, ForwardMode] + test_rvs = mode in [nothing, ReverseMode] + fwd_interp = test_fwd ? get_interpreter(ForwardMode) : missing + rvs_interp = test_rvs ? get_interpreter(ReverseMode) : missing + frule = test_fwd ? build_frule(fwd_interp, sig; debug_mode) : missing + rrule = test_rvs ? build_rrule(rvs_interp, sig; debug_mode) : missing # If something is primitive, then the rule should be `rrule!!`. - is_primitive && @test rule == (debug_mode ? Mooncake.DebugRRule(rrule!!) : rrule!!) + test_fwd && is_primitive && @test frule == frule!! + test_rvs && is_primitive && @test rrule == (debug_mode ? DebugRRule(rrule!!) : rrule!!) # Generate random tangents for anything that is not already a CoDual. + x_ẋ = map(x -> x isa CoDual ? Dual(primal(x), tangent(x)) : randn_dual(rng, x), x) + x_x̄ = map(x -> if x isa CoDual x elseif interface_only @@ -735,20 +912,57 @@ function test_rule( zero_codual(x) end, x) - # Test that the interface is basically satisfied (checks types / memory addresses). 
- test_rrule_interface(x_x̄...; rule) + redirector = print_results ? ((f, x) -> f()) : redirect_stdout + ts = redirector(devnull) do + @testset "$(typeof(x))" begin + # Test that the interface is basically satisfied (checks types / memory addresses). + @testset "Interface (1)" begin + test_fwd && test_frule_interface(x_ẋ...; frule) + test_rvs && test_rrule_interface(x_x̄...; rrule) + end - # Test that answers are numerically correct / consistent. - interface_only || test_rule_correctness(rng, x_x̄...; rule, unsafe_perturb) + # Test that answers are numerically correct / consistent. + @testset "Correctness" begin + if test_fwd && !interface_only + test_frule_correctness(rng, x_ẋ...; frule, unsafe_perturb) + end + if test_rvs && !interface_only + test_rrule_correctness(rng, x_x̄...; rrule, unsafe_perturb) + end + end - # Test the performance of the rule. - test_rrule_performance(perf_flag, rule, x_x̄...) + # Test the performance of the rule. + @testset "Performance" begin + test_fwd && test_frule_performance(perf_flag, frule, x_ẋ...) + test_rvs && test_rrule_performance(perf_flag, rrule, x_x̄...) + end + + # Verify that rules have been cached. + @testset "Caching" begin + if test_fwd + C_fwd = Mooncake.context_type(fwd_interp) + if !Mooncake.is_primitive(C_fwd, ForwardMode, sig) + cache_key = (sig, false, :forward) + k = Mooncake.ClosureCacheKey(fwd_interp.world, cache_key) + @test haskey(fwd_interp.oc_cache, k) + end + end + if test_rvs + C_rvs = Mooncake.context_type(rvs_interp) + if !Mooncake.is_primitive(C_rvs, ReverseMode, sig) + cache_key = (sig, false, :reverse) + k = Mooncake.ClosureCacheKey(rvs_interp.world, cache_key) + @test haskey(rvs_interp.oc_cache, k) + end + end + end + end + end - # Test the interface again, in order to verify that caching is working correctly. 
- return test_rrule_interface(x_x̄...; rule=Mooncake.build_rrule(interp, sig; debug_mode)) + return ts end -function run_hand_written_rrule!!_test_cases(rng_ctor, v::Val) +function run_hand_written_rule_test_cases(rng_ctor, v::Val, mode::Type{<:Mode}) test_cases, memory = test_hook(Mooncake.generate_hand_written_rrule!!_test_cases, rng_ctor, v) do Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, v) @@ -757,26 +971,36 @@ function run_hand_written_rrule!!_test_cases(rng_ctor, v::Val) interface_only, perf_flag, _, f, x... ) in test_cases - test_rule(rng_ctor(123), f, x...; interface_only, perf_flag) + test_rule(rng_ctor(123), f, x...; interface_only, perf_flag, mode) + test_rule(rng_ctor(123), f, x...; interface_only, perf_flag, mode) end end -function run_derived_rrule!!_test_cases(rng_ctor, v::Val) +function run_derived_rule_test_cases(rng_ctor, v::Val, mode::Type{<:Mode}) test_cases, memory = - test_hook(Mooncake.generate_derived_rrule!!_test_cases, rng_ctor, v) do + test_hook(Mooncake.generate_derived_rrule!!_test_cases, rng_ctor, v, mode) do Mooncake.generate_derived_rrule!!_test_cases(rng_ctor, v) end - GC.@preserve memory @testset "$f, $(typeof(x))" for ( + GC.@preserve memory @testset "$mode, $f, $(typeof(x))" for ( interface_only, perf_flag, _, f, x... 
) in test_cases - test_rule(rng_ctor(123), f, x...; interface_only, perf_flag, is_primitive=false) + test_rule( + rng_ctor(123), f, x...; interface_only, perf_flag, is_primitive=false, mode + ) end end -function run_rrule!!_test_cases(rng_ctor, v::Val) - run_hand_written_rrule!!_test_cases(rng_ctor, v) - return run_derived_rrule!!_test_cases(rng_ctor, v) +function run_rule_test_cases(rng_ctor, v::Val, mode=nothing) + if mode in [nothing, ForwardMode] + run_hand_written_rule_test_cases(rng_ctor, v, ForwardMode) + run_derived_rule_test_cases(rng_ctor, v, ForwardMode) + end + if mode in [nothing, ReverseMode] + run_hand_written_rule_test_cases(rng_ctor, v, ReverseMode) + run_derived_rule_test_cases(rng_ctor, v, ReverseMode) + end + return nothing end """ @@ -1269,7 +1493,7 @@ function _test_tangent_splitting_internal( lazy_rzero = @inferred lazy_zero_rdata(p) @test instantiate(lazy_rzero) isa R - # Check incrementing the fdata component of a tangnet yields the correct type. + # Check incrementing the fdata component of a tangent yields the correct type. @test increment!!(f, f) isa F # Check incrementing the rdata component of a tangent yields the correct type. @@ -1328,7 +1552,7 @@ function test_rule_and_type_interactions(rng::AbstractRNG, p::P) where {P} interface_only=true, is_primitive=true, perf_flag=:none, - interp=Mooncake.get_interpreter(), + mode=ReverseMode, ) end end diff --git a/src/tools_for_rules.jl b/src/tools_for_rules.jl index c0478c6dde..f747121154 100644 --- a/src/tools_for_rules.jl +++ b/src/tools_for_rules.jl @@ -18,8 +18,15 @@ return arg_type_symbols, where_params end -function construct_def(arg_names, arg_types, where_params, body) - name = :(Mooncake.rrule!!) 
+function construct_rrule_def(arg_names, arg_types, where_params, body) + return construct_rule_def(:(Mooncake.rrule!!), arg_names, arg_types, where_params, body) +end + +function construct_frule_def(arg_names, arg_types, where_params, body) + return construct_rule_def(:(Mooncake.frule!!), arg_names, arg_types, where_params, body) +end + +function construct_rule_def(name, arg_names, arg_types, where_params, body) arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) def = Dict(:head => :function, :name => name, :args => arg_exprs, :body => body) where_params !== nothing && setindex!(def, where_params, :whereparams) @@ -111,8 +118,8 @@ end """ zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N} -Utility functionality for constructing `rrule!!`s for functions which produce adjoints which -always return zero. +Utility functionality for constructing `rrule!!`s for functions whose adjoints always return +zero. NOTE: you should only make use of this function if you cannot make use of the [`@zero_adjoint`](@ref) macro. @@ -143,39 +150,78 @@ may be required if it is not. end """ - @zero_adjoint ctx sig + zero_derivative(f::Dual, x::Vararg{Dual,N}) where {N} + +Utility functionality for constructing `frule!!`s for functions whose derivatives always +return zero. + +NOTE: you should only make use of this function if you cannot make use of the +[`@zero_derivative`](@ref) macro. + +You make use of this functionality by writing a method of `Mooncake.frule!!`, and +passing all of its arguments (including the function itself) to this function. 
For example: +```jldoctest +julia> import Mooncake: zero_derivative, DefaultCtx, zero_dual, frule!!, Dual + +julia> foo(x::Vararg{Int}) = 5 +foo (generic function with 1 method) + +julia> frule!!(f::Dual{typeof(foo)}, x::Vararg{Dual{Int}}) = zero_derivative(f, x...); + +julia> frule!!(zero_dual(foo), zero_dual(3), zero_dual(2)) +Dual{Int64, NoTangent}(5, NoTangent()) +``` +""" +@inline function zero_derivative(f::Dual, x::Vararg{Dual,N}) where {N} + return zero_dual(primal(f)(map(primal, x)...)) +end + +""" + @zero_derivative ctx sig [mode=Mode] + +Declares that the derivative for `sig` is always zero, for all arguments. This +also implies that the adjoint of the derivative is always zero for all arguments. + +Accordingly, if `mode===Mode` (the default) this macro creates a method of +[`is_primitive`](@ref) which returns `true` for `ctx`, `sig`, and both [`ForwardMode`](@ref) +and [`ReverseMode`](@ref). It additionally creates methods of [`frule!!`](@ref) and +[`rrule!!`](@ref) which always return zero / do not increment tangents and fdata. -Defines `is_primitive(context_type, sig) = true`, and defines a method of -`Mooncake.rrule!!` which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality -- it is morally the same as `ChainRulesCore.@non_differentiable`. 
For example: ```jldoctest -julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive +julia> using Mooncake: @zero_derivative, DefaultCtx, zero_dual, zero_fcodual, frule!!, rrule!!, is_primitive, ForwardMode, ReverseMode julia> foo(x) = 5 foo (generic function with 1 method) -julia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any} +julia> @zero_derivative DefaultCtx Tuple{typeof(foo), Any} + +julia> is_primitive(DefaultCtx, ForwardMode, Tuple{typeof(foo), Any}) +true -julia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any}) +julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo), Any}) true +julia> frule!!(zero_dual(foo), zero_dual(3.0)) +Mooncake.Dual{Int64, NoTangent}(5, NoTangent()) + julia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData()) (NoRData(), 0.0) ``` Limited support for `Vararg`s is also available. For example ```jldoctest -julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive +julia> using Mooncake: @zero_derivative, DefaultCtx, zero_fcodual, rrule!!, is_primitive, ReverseMode julia> foo_varargs(x...) = 5 foo_varargs (generic function with 1 method) -julia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg} +julia> @zero_derivative DefaultCtx Tuple{typeof(foo_varargs), Vararg} -julia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int}) +julia> is_primitive(DefaultCtx, ReverseMode, Tuple{typeof(foo_varargs), Any, Float64, Int}) true julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData()) @@ -194,11 +240,18 @@ made a mistake. # Signatures Unsupported By This Macro -If the signature you wish to apply `@zero_adjoint` to is not supported, for example because +If the signature you wish to apply `@zero_derivative` to is not supported, for example because it uses a `Vararg` with a type parameter, you can still make use of -[`zero_adjoint`](@ref). +[`zero_derivative`](@ref). 
+ """ -macro zero_adjoint(ctx, sig) +macro zero_derivative(ctx, sig, mode=Mode) + mode = mode == :ForwardMode ? ForwardMode : mode + mode = mode == :ReverseMode ? ReverseMode : mode + return _zero_derivative_impl(ctx, sig, mode) +end + +function _zero_derivative_impl(ctx, sig, mode) # Parse the signature, and construct the rule definition. If it is a vararg definition, # then the last argument requires special treatment. @@ -206,23 +259,57 @@ macro zero_adjoint(ctx, sig) arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols)) is_vararg = arg_type_symbols[end] == Expr(:escape, :Vararg) if is_vararg - arg_types = vcat( + arg_types_deriv = vcat( + map(t -> :(Mooncake.Dual{<:$t}), arg_type_symbols[1:(end - 1)]), + :(Vararg{Mooncake.Dual}), + ) + arg_types_adjoint = vcat( map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols[1:(end - 1)]), :(Vararg{Mooncake.CoDual}), ) splat_symbol = Expr(Symbol("..."), arg_names[end]) - body = Expr(:call, Mooncake.zero_adjoint, arg_names[1:(end - 1)]..., splat_symbol) + tmp = arg_names[1:(end - 1)] + body_deriv = Expr(:call, Mooncake.zero_derivative, tmp..., splat_symbol) + body_adjoint = Expr(:call, Mooncake.zero_adjoint, tmp..., splat_symbol) else - arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols) - body = Expr(:call, Mooncake.zero_adjoint, arg_names...) + arg_types_deriv = map(t -> :(Mooncake.Dual{<:$t}), arg_type_symbols) + arg_types_adjoint = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols) + body_deriv = Expr(:call, Mooncake.zero_derivative, arg_names...) + body_adjoint = Expr(:call, Mooncake.zero_adjoint, arg_names...) end - # Return code to create a method of is_primitive and a rule. - ex = quote - Mooncake.is_primitive(::Type{$(esc(ctx))}, ::Type{<:$(esc(sig))}) = true - $(construct_def(arg_names, arg_types, where_params, body)) + # Construct is_primitive statement. If no mode is provided, then construct a statement + # which does not escape the mode argument. 
This will work even if the names `Mooncake` + # or `Mooncake.Mode` are not available in the scope which calls this macro. + is_primitive_ex = quote + const M = $mode + function Mooncake.is_primitive( + ::Type{$(esc(ctx))}, ::Type{<:M}, ::Type{<:$(esc(sig))} + ) + return true + end end - return ex + + # Figuring out which mode argument was actually provided is going to be very hard in + # general, and rather error prone, because the mode might appear as a `Type`, one of + # several `Symbol`s, or possibly something else not considered. As a result, we always + # define both the frule and rrule, and rely on the method of `is_primitive` defined + # above to determine whether or not they do anything. This might inflate the method + # table a bit for `frule!!` and `rrule!!` unnecessarily, but it will be robust. + frule_ex = construct_frule_def(arg_names, arg_types_deriv, where_params, body_deriv) + rrule_ex = construct_rrule_def(arg_names, arg_types_adjoint, where_params, body_adjoint) + + return Expr(:block, is_primitive_ex, frule_ex, rrule_ex) +end + +""" + @zero_adjoint ctx sig + +Equivalent to `@zero_derivative ctx sig ReverseMode`. Consult the docstring for +[`@zero_derivative`](@ref) for more information. +""" +macro zero_adjoint(ctx, sig) + return _zero_derivative_impl(ctx, sig, ReverseMode) end # @@ -242,6 +329,25 @@ to_cr_tangent(t::Tangent) = CRC.Tangent{Any}(; map(to_cr_tangent, t.fields)...) to_cr_tangent(t::MutableTangent) = CRC.Tangent{Any}(; map(to_cr_tangent, t.fields)...) to_cr_tangent(t::Tuple) = CRC.Tangent{Any}(map(to_cr_tangent, t)...) +""" + mooncake_tangent(p, cr_tangent) + +For primal `p` and a tangent used by ChainRules `cr_tangent`, returns the tangent of type +`tangent_type(typeof(p))`. Useful for converting the result of a `ChainRules.frule` into +something that Mooncake can use. 
+""" +mooncake_tangent(p, ::CRC.NoTangent) = NoTangent() +mooncake_tangent(p, t::IEEEFloat) = t +mooncake_tangent(p::Array, t::Array{<:IEEEFloat}) = t +mooncake_tangent(p::Array, t::Array) = map(mooncake_tangent, p, t) +mooncake_tangent(p, t::CRC.ZeroTangent) = zero_tangent(p) +function mooncake_tangent(p::P, t::T) where {P,T<:Tuple} + return tangent_type(P) == NoTangent ? NoTangent() : map(mooncake_tangent, p, t) +end +function mooncake_tangent(p::P, t::T) where {P<:Tuple,T<:CRC.Tangent} + return tangent_type(P) == NoTangent ? NoTangent() : map(mooncake_tangent, p, t.backing) +end + """ increment_and_get_rdata!(fdata, zero_rdata, cr_tangent) @@ -258,6 +364,29 @@ function increment_and_get_rdata!(f, r, t::CRC.Thunk) return increment_and_get_rdata!(f, r, CRC.unthunk(t)) end +""" + frule_wrapper(f::Dual, args::Dual...) + +Implements an `frule!!` for `f` applied to `args` by calling `ChainRulesCore.frule`. +""" +function frule_wrapper(fargs::Vararg{Dual,N}) where {N} + tangents = tuple_map(to_cr_tangent ∘ tangent, fargs) + Ω, dΩ = CRC.frule(tangents, tuple_map(primal, fargs)...) + return Dual(Ω, mooncake_tangent(Ω, dΩ)) +end + +function frule_wrapper(::Dual{typeof(Core.kwcall)}, fargs::Vararg{Dual,N}) where {N} + primals = map(primal, fargs) + tangents = map(to_cr_tangent ∘ tangent, fargs[2:end]) + Ω, dΩ = Core.kwcall(primals[1], CRC.frule, tangents, primals[2:end]...) + return Dual(Ω, mooncake_tangent(Ω, dΩ)) +end + +function construct_frule_wrapper_def(arg_names, arg_types, where_params) + body = Expr(:call, frule_wrapper, arg_names...) + return construct_frule_def(arg_names, arg_types, where_params, body) +end + """ rrule_wrapper(f::CoDual, args::CoDual...) @@ -334,53 +463,61 @@ end function construct_rrule_wrapper_def(arg_names, arg_types, where_params) body = Expr(:call, rrule_wrapper, arg_names...) 
-    return construct_def(arg_names, arg_types, where_params, body)
+    return construct_rrule_def(arg_names, arg_types, where_params, body)
 end
 
 """
-    @from_rrule ctx sig [has_kwargs=false]
+    @from_chainrules ctx sig [has_kwargs=false mode=Mode]
 
-Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s.
+Convenience functionality to assist in using `ChainRulesCore.frule`s and
+`ChainRulesCore.rrule`s to write `frule!!`s and `rrule!!`s.
 
 # Arguments
 - `ctx`: A Mooncake context type
 - `sig`: the signature which you wish to assert should be a primitive in `Mooncake.jl`, and
-    use an existing `ChainRulesCore.rrule` to implement this functionality.
-- `has_kwargs`: a `Bool` state whether or not the function has keyword arguments. This
-    feature has the same limitations as `ChainRulesCore.rrule` -- the derivative w.r.t. all
-    kwargs must be zero.
+    use an existing `ChainRulesCore.rrule` or `ChainRulesCore.frule` to implement this functionality.
+- `has_kwargs=false`: a `Bool` stating whether or not the function has keyword arguments.
+    This feature has the same limitations as `ChainRulesCore.frule` and
+    `ChainRulesCore.rrule` -- the derivative w.r.t. all kwargs must be zero.
+- `mode=Mode`: the mode to produce rules for. By default, produces rules for both forward
+    and reverse mode. If `mode=ForwardMode` only rules for forward mode are produced. If
+    `mode=ReverseMode` only rules for reverse mode are produced.
# Example Usage ## A Basic Example ```jldoctest -julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils +julia> using Mooncake: @from_chainrules, DefaultCtx, frule!!, rrule!!, Dual, zero_dual, zero_fcodual, TestUtils julia> import ChainRulesCore julia> foo(x::Real) = 5x; +julia> ChainRulesCore.frule((df, dx), ::typeof(foo), x::Real) = 5x, 5dx; + julia> function ChainRulesCore.rrule(::typeof(foo), x::Real) foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω return foo(x), foo_pb end; -julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} +julia> @from_chainrules DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} + +julia> frule!!(zero_dual(foo), Dual(5.0, 2.0)) +Dual{Float64, Float64}(25.0, 10.0) julia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0) (NoRData(), 5.0) -julia> # Check that the rule works as intended. - TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true) -Test Passed +julia> # Check that the rule works as intended. Put this in your test suite. + TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true, print_results=false); ``` -## An Example with Keyword Arguments +## An Example with Keyword Arguments and ReverseMode ```jldoctest -julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils +julia> using Mooncake: @from_chainrules, DefaultCtx, rrule!!, zero_fcodual, TestUtils, ReverseMode julia> import ChainRulesCore @@ -391,7 +528,7 @@ julia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool) return foo(x; cond), foo_pb end; -julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true +julia> @from_chainrules DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true ReverseMode julia> _, pb = rrule!!( zero_fcodual(Core.kwcall), @@ -403,11 +540,11 @@ julia> _, pb = rrule!!( julia> pb(3.0) (NoRData(), NoRData(), NoRData(), 12.0) -julia> # Check that the rule works as intended. +julia> # Check that the rule works as intended. Put this in your test suite. 
        TestUtils.test_rule(
-           Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true
-       )
-Test Passed
+           Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0;
+           is_primitive=true, print_results=false, mode=Mooncake.ReverseMode,
+       );
 ```
 Notice that, in order to access the kwarg method we must call the method of `Core.kwcall`,
 as Mooncake's `rrule!!` does not itself permit the use of kwargs.
@@ -431,7 +568,7 @@ Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}
 There are a variety of reasons for this way of doing things, and whether it is a good idea
 to write rules for such generic objects has been debated at length.
-Suffice it to say, you should not write rules for _this_ package which are so generically
+Suffice it to say, you should not write rules for Mooncake which are so generically
 typed.
 Rather, you should create rules for the subset of types for which you believe that the
 `ChainRulesCore.rrule` will work correctly, and leave this package to derive rules for the
@@ -443,40 +580,80 @@ someone might define.
 
 # Conversions Between Different Tangent Type Systems
 
-Under the hood, this functionality relies on two functions: `Mooncake.to_cr_tangent`, and
-`Mooncake.increment_and_get_rdata!`. These two functions handle conversion to / from
-`Mooncake` tangent types and `ChainRulesCore` tangent types. This functionality is known to
-work well for simple types, but has not been tested to a great extent on complicated
-composite types. If `@from_rrule` does not work in your case because the required method of
-either of these functions does not exist, please open an issue.
+Under the hood, this functionality relies on three functions: `Mooncake.mooncake_tangent`,
+`Mooncake.to_cr_tangent`, and `Mooncake.increment_and_get_rdata!`. These three functions
+handle conversion to / from `Mooncake` tangent types and `ChainRulesCore` tangent types.
+This functionality is known to work well for simple types, but has not been tested to a +great extent on complicated composite types. If `@from_chainrules` does not work in your +case because the required method of either of these functions does not exist, please open an +issue. """ -macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) +macro from_chainrules(ctx, sig::Expr, has_kwargs::Bool=false, mode=Mode) + return _from_chainrules_impl(ctx, sig, has_kwargs, mode) +end + +function _from_chainrules_impl(ctx, sig::Expr, has_kwargs::Bool, mode) arg_type_syms, where_params = parse_signature_expr(sig) arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_syms)) - arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_syms) - rule_expr = construct_rrule_wrapper_def(arg_names, arg_types, where_params) + dual_arg_types = map(t -> :(Mooncake.Dual{<:$t}), arg_type_syms) + codual_arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_syms) + frule_expr = construct_frule_wrapper_def(arg_names, dual_arg_types, where_params) + rrule_expr = construct_rrule_wrapper_def(arg_names, codual_arg_types, where_params) if has_kwargs kw_sig = Expr(:curly, :Tuple, :(typeof(Core.kwcall)), :NamedTuple, arg_type_syms...) kw_sig = where_params === nothing ? kw_sig : Expr(:where, kw_sig, where_params...) - kw_is_primitive = :(Mooncake.is_primitive(::Type{$ctx}, ::Type{<:$kw_sig}) = true) - kwcall_type = :(Mooncake.CoDual{typeof(Core.kwcall)}) - nt_type = :(Mooncake.CoDual{<:NamedTuple}) - kwargs_rule_expr = construct_rrule_wrapper_def( + # Type M will be available later on, and will be the mode type. 
+ kw_is_primitive = quote + function Mooncake.is_primitive(::Type{$(esc(ctx))}, ::Type{<:M}, ::Type{<:$kw_sig}) + return true + end + end + kwargs_frule_expr = construct_frule_wrapper_def( vcat(:_kwcall, :kwargs, arg_names), - vcat(kwcall_type, nt_type, arg_types), + vcat( + :(Mooncake.Dual{typeof(Core.kwcall)}), + :(Mooncake.Dual{<:NamedTuple}), + dual_arg_types, + ), + where_params, + ) + kwargs_rrule_expr = construct_rrule_wrapper_def( + vcat(:_kwcall, :kwargs, arg_names), + vcat( + :(Mooncake.CoDual{typeof(Core.kwcall)}), + :(Mooncake.CoDual{<:NamedTuple}), + codual_arg_types, + ), where_params, ) else kw_is_primitive = nothing - kwargs_rule_expr = nothing + kwargs_frule_expr = nothing + kwargs_rrule_expr = nothing end - ex = quote - Mooncake.is_primitive(::Type{$(esc(ctx))}, ::Type{<:($(esc(sig)))}) = true - $rule_expr + return quote + const M = $mode + function Mooncake.is_primitive( + ::Type{$(esc(ctx))}, ::Type{<:M}, ::Type{<:($(esc(sig)))} + ) + return true + end + $frule_expr + $rrule_expr $kw_is_primitive - $kwargs_rule_expr + $kwargs_frule_expr + $kwargs_rrule_expr end - return ex +end + +""" + @from_rrule ctx sig [has_kwargs=false] + +Equivalent to `@from_chainrules ctx sig has_kwargs ReverseMode`. See +[`@from_chainrules`](@ref) for more information. +""" +macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) + return _from_chainrules_impl(ctx, sig, has_kwargs, ReverseMode) end diff --git a/src/utils.jl b/src/utils.jl index 682c65d4ff..c2218efea7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -159,7 +159,7 @@ end Returns a 2-tuple. The first element is true if `m` is a vararg method, and false if not. The second element contains the names of the static parameters associated to `m`. 
""" -is_vararg_and_sparam_names(m::Method) = m.isva, sparam_names(m) +is_vararg_and_sparam_names(m::Method)::Tuple{Bool,Vector{Symbol}} = m.isva, sparam_names(m) """ is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}} @@ -185,6 +185,14 @@ function is_vararg_and_sparam_names(mi::Core.MethodInstance)::Tuple{Bool,Vector{ return is_vararg_and_sparam_names(mi.def) end +""" + is_vararg_and_sparam_names(mc::MistyClosure)::Tuple{Bool,Vector{Symbol}} + +Basic implementation for MistyClosure. Assumes is not a varargs function, and has no static +parameter names appearing in the source. +""" +is_vararg_and_sparam_names(::MistyClosure)::Tuple{Bool,Vector{Symbol}} = false, Symbol[] + """ sparam_names(m::Core.Method)::Vector{Symbol} @@ -357,3 +365,56 @@ function misty_closure( ) return MistyClosure(opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir)) end + +""" + _copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar) + +Literally just copied over from Julia's LinearAlgebra std lib, in order to produce a method +which works on 1.10. 
+""" +function _copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar) + @static if VERSION >= v"1.11" + return copytrito!(B, A, uplo) + end + Base.require_one_based_indexing(A, B) + BLAS.chkuplo(uplo) + m, n = size(A) + m1, n1 = size(B) + A = Base.unalias(B, A) + if uplo == 'U' + if n < m + (m1 < n || n1 < n) && throw( + DimensionMismatch( + lazy"B of size ($m1,$n1) should have at least size ($n,$n)" + ), + ) + else + (m1 < m || n1 < n) && throw( + DimensionMismatch( + lazy"B of size ($m1,$n1) should have at least size ($m,$n)" + ), + ) + end + for j in 1:n, i in 1:min(j, m) + @inbounds B[i, j] = A[i, j] + end + else # uplo == 'L' + if m < n + (m1 < m || n1 < m) && throw( + DimensionMismatch( + lazy"B of size ($m1,$n1) should have at least size ($m,$m)" + ), + ) + else + (m1 < m || n1 < n) && throw( + DimensionMismatch( + lazy"B of size ($m1,$n1) should have at least size ($m,$n)" + ), + ) + end + for j in 1:n, i in j:m + @inbounds B[i, j] = A[i, j] + end + end + return B +end diff --git a/test/developer_tools.jl b/test/developer_tools.jl index 7c65bfd67e..3d5b3f27ec 100644 --- a/test/developer_tools.jl +++ b/test/developer_tools.jl @@ -1,6 +1,8 @@ @testset "developer_tools" begin sig = Tuple{typeof(sin),Float64} - @test Mooncake.primal_ir(sig) isa CC.IRCode + @test Mooncake.primal_ir(Mooncake.MooncakeInterpreter(ForwardMode), sig) isa CC.IRCode + @test Mooncake.primal_ir(Mooncake.MooncakeInterpreter(ReverseMode), sig) isa CC.IRCode + @test Mooncake.dual_ir(sig) isa CC.IRCode @test Mooncake.fwd_ir(sig) isa CC.IRCode @test Mooncake.rvs_ir(sig) isa CC.IRCode end diff --git a/test/ext/cuda/cuda.jl b/test/ext/cuda/cuda.jl index d246a4e54d..357814773f 100644 --- a/test/ext/cuda/cuda.jl +++ b/test/ext/cuda/cuda.jl @@ -21,6 +21,7 @@ using Mooncake.TestUtils: test_tangent_interface, test_tangent_splitting, test_r interface_only=true, is_primitive=true, debug_mode=true, + mode=Mooncake.ReverseMode, ) else println("Tests are skipped since no CUDA device was 
found. ") diff --git a/test/ext/dynamic_expressions/dynamic_expressions.jl b/test/ext/dynamic_expressions/dynamic_expressions.jl index c154bc56fb..7c761af9e1 100644 --- a/test/ext/dynamic_expressions/dynamic_expressions.jl +++ b/test/ext/dynamic_expressions/dynamic_expressions.jl @@ -187,6 +187,7 @@ end perf_flag=:none, is_primitive=false, unsafe_perturb=true, + mode=Mooncake.ReverseMode, ) end @@ -201,6 +202,7 @@ end perf_flag=:none, is_primitive=false, unsafe_perturb=true, + mode=Mooncake.ReverseMode, ) end diff --git a/test/ext/flux/flux.jl b/test/ext/flux/flux.jl index ea12979fac..609e2fa1f9 100644 --- a/test/ext/flux/flux.jl +++ b/test/ext/flux/flux.jl @@ -20,6 +20,8 @@ using Mooncake.TestUtils: test_rule ) end, ) - test_rule(StableRNG(123), f, fargs...; interface_only, perf_flag, is_primitive) + rng = StableRNG(123) + mode = Mooncake.ReverseMode + test_rule(rng, f, fargs...; interface_only, perf_flag, is_primitive, mode) end end diff --git a/test/ext/luxlib/luxlib.jl b/test/ext/luxlib/luxlib.jl index 56a5a8c586..58e6c35145 100644 --- a/test/ext/luxlib/luxlib.jl +++ b/test/ext/luxlib/luxlib.jl @@ -85,6 +85,7 @@ using Mooncake.TestUtils: test_rule end, ), ) - test_rule(StableRNG(123), fargs...; perf_flag, is_primitive, interface_only) + mode = Mooncake.ReverseMode + test_rule(StableRNG(123), fargs...; perf_flag, is_primitive, interface_only, mode) end end diff --git a/test/ext/nnlib/nnlib.jl b/test/ext/nnlib/nnlib.jl index f8e74b5044..c65056b4e1 100644 --- a/test/ext/nnlib/nnlib.jl +++ b/test/ext/nnlib/nnlib.jl @@ -138,6 +138,7 @@ dropout_tester_3(Trng, x, p) = dropout(Trng(1), x, p; dims=(1, 2)) @info "$(typeof(fargs))" perf_flag = cuda ? 
:none : perf_flag - test_rule(StableRNG(123), fargs...; perf_flag, is_primitive, interface_only) + mode = Mooncake.ReverseMode + test_rule(StableRNG(123), fargs...; perf_flag, is_primitive, interface_only, mode) end end diff --git a/test/ext/special_functions/special_functions.jl b/test/ext/special_functions/special_functions.jl index 613ceba10d..42eb71bd7f 100644 --- a/test/ext/special_functions/special_functions.jl +++ b/test/ext/special_functions/special_functions.jl @@ -3,6 +3,7 @@ Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using AllocCheck, JET, Mooncake, SpecialFunctions, StableRNGs, Test +using Mooncake: ForwardMode, ReverseMode using Mooncake.TestUtils: test_rule # Rules in this file are only lightly tester, because they are all just @from_rrule rules. diff --git a/test/front_matter.jl b/test/front_matter.jl index a42b810f38..d1071a611f 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -38,6 +38,7 @@ using Mooncake: _add_to_primal, _diff, _dot, + zero_dual, zero_codual, codual_type, rrule!!, @@ -50,7 +51,10 @@ using Mooncake: rdata_type, fdata, rdata, - get_interpreter + get_interpreter, + Mode, + ForwardMode, + ReverseMode using Mooncake: CC, diff --git a/test/integration_testing/array/array.jl b/test/integration_testing/array/array.jl index 4fd4ab190d..0c0a1f7fcc 100644 --- a/test/integration_testing/array/array.jl +++ b/test/integration_testing/array/array.jl @@ -662,14 +662,6 @@ _getter() = 5.0 ) @testset for (interface_only, perf_flag, f, x...) 
in test_cases @info Mooncake._typeof((f, x...)) - test_rule( - sr(123456), - f, - x...; - interface_only, - is_primitive=false, - debug_mode=false, - perf_flag, - ) + test_rule(sr(123456), f, x...; interface_only, is_primitive=false, perf_flag) end end diff --git a/test/integration_testing/bijectors/bijectors.jl b/test/integration_testing/bijectors/bijectors.jl index 924fda5fb2..998c1f34d0 100644 --- a/test/integration_testing/bijectors/bijectors.jl +++ b/test/integration_testing/bijectors/bijectors.jl @@ -24,7 +24,7 @@ function b_binv_test_case(bijector, dim; name=nothing, rng=StableRNG(23)) if name === nothing name = string(bijector) end - return TestCase(x -> bijector(inverse(bijector)(x)), randn(rng, dim); name=name) + return TestCase(x -> bijector(inverse(bijector)(x)), randn(rng, dim); name) end @testset "Bijectors integration tests" begin @@ -50,7 +50,7 @@ end 1 0 0 0 0 1 ]), (3, 3)), - b_binv_test_case(Bijectors.PlanarLayer(3), (3, 3)), + # b_binv_test_case(Bijectors.PlanarLayer(3), (3, 3)), b_binv_test_case(Bijectors.RadialLayer(3), 3), b_binv_test_case(Bijectors.Reshape((2, 3), (3, 2)), (2, 3)), b_binv_test_case(Bijectors.Scale(0.2), 3), @@ -111,15 +111,15 @@ end ), ] - @testset "$(case.name)" for case in test_cases - if case.broken + @testset "$(c.name)" for c in test_cases + if c.broken @test_broken begin - test_rule(StableRNG(123456), case.func, case.arg; is_primitive=false) + test_rule(StableRNG(123456), c.func, c.arg; is_primitive=false) true end else rng = StableRNG(123456) - test_rule(rng, case.func, case.arg; is_primitive=false, unsafe_perturb=true) + test_rule(rng, c.func, c.arg; is_primitive=false, unsafe_perturb=true) end end end diff --git a/test/integration_testing/diff_tests/diff_tests.jl b/test/integration_testing/diff_tests/diff_tests.jl index 16f884bb68..ac66aa1d7e 100644 --- a/test/integration_testing/diff_tests/diff_tests.jl +++ b/test/integration_testing/diff_tests/diff_tests.jl @@ -3,71 +3,69 @@ Pkg.activate(@__DIR__) Pkg.develop(; 
path=joinpath(@__DIR__, "..", "..", "..")) using DiffTests, LinearAlgebra, Mooncake, Random, StableRNGs, Test -using Mooncake.TestUtils: test_rule -# Tests brought in from DiffTests.jl -const _rng = Xoshiro(123456) +sr(n::Int) = StableRNG(n) const TEST_CASES = vcat( tuple.( fill(false, length(DiffTests.NUMBER_TO_NUMBER_FUNCS)), DiffTests.NUMBER_TO_NUMBER_FUNCS, - rand(_rng, length(DiffTests.NUMBER_TO_NUMBER_FUNCS)) .+ 1e-1, + rand(sr(1), length(DiffTests.NUMBER_TO_NUMBER_FUNCS)) .+ 1e-1, ), tuple.( fill(false, length(DiffTests.NUMBER_TO_ARRAY_FUNCS)), DiffTests.NUMBER_TO_ARRAY_FUNCS, - [rand(_rng) + 1e-1 for _ in DiffTests.NUMBER_TO_ARRAY_FUNCS], + [rand(sr(2)) + 1e-1 for _ in DiffTests.NUMBER_TO_ARRAY_FUNCS], ), tuple.( fill(false, length(DiffTests.INPLACE_NUMBER_TO_ARRAY_FUNCS)), DiffTests.INPLACE_NUMBER_TO_ARRAY_FUNCS, - [rand(_rng, 5) .+ 1e-1 for _ in DiffTests.INPLACE_ARRAY_TO_ARRAY_FUNCS], - [rand(_rng) + 1e-1 for _ in DiffTests.INPLACE_ARRAY_TO_ARRAY_FUNCS], + [rand(sr(3), 5) .+ 1e-1 for _ in DiffTests.INPLACE_ARRAY_TO_ARRAY_FUNCS], + [rand(sr(4)) + 1e-1 for _ in DiffTests.INPLACE_ARRAY_TO_ARRAY_FUNCS], ), tuple.( fill(false, length(DiffTests.VECTOR_TO_NUMBER_FUNCS)), DiffTests.VECTOR_TO_NUMBER_FUNCS, - [rand(_rng, 5) .+ 1e-1 for _ in DiffTests.VECTOR_TO_NUMBER_FUNCS], + [rand(sr(5), 5) .+ 1e-1 for _ in DiffTests.VECTOR_TO_NUMBER_FUNCS], ), tuple.( fill(false, length(DiffTests.MATRIX_TO_NUMBER_FUNCS)), DiffTests.MATRIX_TO_NUMBER_FUNCS, - [rand(_rng, 5, 5) .+ 1e-1 for _ in DiffTests.MATRIX_TO_NUMBER_FUNCS], + [rand(sr(6), 5, 5) .+ 1e-1 for _ in DiffTests.MATRIX_TO_NUMBER_FUNCS], ), tuple.( fill(false, length(DiffTests.BINARY_MATRIX_TO_MATRIX_FUNCS)), DiffTests.BINARY_MATRIX_TO_MATRIX_FUNCS, - [rand(_rng, 5, 5) .+ 1e-1 + I for _ in DiffTests.BINARY_MATRIX_TO_MATRIX_FUNCS], - [rand(_rng, 5, 5) .+ 1e-1 + I for _ in DiffTests.BINARY_MATRIX_TO_MATRIX_FUNCS], + [rand(sr(7), 5, 5) .+ 1e-1 + I for _ in DiffTests.BINARY_MATRIX_TO_MATRIX_FUNCS], + [rand(sr(8), 5, 5) .+ 
1e-1 + I for _ in DiffTests.BINARY_MATRIX_TO_MATRIX_FUNCS], ), tuple.( fill(false, length(DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS)), DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS, - [rand(_rng, 5, 5) .+ 1e-1 for _ in DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS], - [rand(_rng, 5, 5) .+ 1e-1 for _ in DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS], - [rand(_rng, 5, 5) .+ 1e-1 for _ in DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS], + [rand(sr(9), 5, 5) .+ 1e-1 for _ in DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS], + [rand(sr(10), 5, 5) .+ 1e-1 for _ in DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS], + [rand(sr(11), 5, 5) .+ 1e-1 for _ in DiffTests.TERNARY_MATRIX_TO_NUMBER_FUNCS], ), tuple.( fill(false, length(DiffTests.INPLACE_ARRAY_TO_ARRAY_FUNCS)), DiffTests.INPLACE_ARRAY_TO_ARRAY_FUNCS, - [rand(_rng, 26) .+ 1e-1 for _ in DiffTests.INPLACE_ARRAY_TO_ARRAY_FUNCS], - [rand(_rng, 26) .+ 1e-1 for _ in DiffTests.INPLACE_ARRAY_TO_ARRAY_FUNCS], + [rand(sr(12), 26) .+ 1e-1 for _ in DiffTests.INPLACE_ARRAY_TO_ARRAY_FUNCS], + [rand(sr(13), 26) .+ 1e-1 for _ in DiffTests.INPLACE_ARRAY_TO_ARRAY_FUNCS], ), tuple.( fill(false, length(DiffTests.VECTOR_TO_VECTOR_FUNCS)), DiffTests.VECTOR_TO_VECTOR_FUNCS, - [rand(_rng, 26) .+ 1e-1 for _ in DiffTests.VECTOR_TO_VECTOR_FUNCS], + [rand(sr(14), 26) .+ 1e-1 for _ in DiffTests.VECTOR_TO_VECTOR_FUNCS], ), tuple.( fill(false, length(DiffTests.ARRAY_TO_ARRAY_FUNCS)), DiffTests.ARRAY_TO_ARRAY_FUNCS, - [rand(_rng, 26) .+ 1e-1 for _ in DiffTests.ARRAY_TO_ARRAY_FUNCS], + [rand(sr(15), 26) .+ 1e-1 for _ in DiffTests.ARRAY_TO_ARRAY_FUNCS], ), tuple.( fill(false, length(DiffTests.MATRIX_TO_MATRIX_FUNCS)), DiffTests.MATRIX_TO_MATRIX_FUNCS, - [rand(_rng, 5, 5) .+ 1e-1 for _ in DiffTests.MATRIX_TO_MATRIX_FUNCS], + [rand(sr(16), 5, 5) .+ 1e-1 for _ in DiffTests.MATRIX_TO_MATRIX_FUNCS], ), ) @@ -76,6 +74,6 @@ const TEST_CASES = vcat( vcat(TEST_CASES[1:66], TEST_CASES[68:89], TEST_CASES[91:end]) ) @info "$n: $(typeof((f, x...)))" - test_rule(StableRNG(123456), f, x...; 
is_primitive=false) + Mooncake.TestUtils.test_rule(StableRNG(123456), f, x...; is_primitive=false) end end diff --git a/test/integration_testing/distributions/distributions.jl b/test/integration_testing/distributions/distributions.jl index efbee197f9..06179fad19 100644 --- a/test/integration_testing/distributions/distributions.jl +++ b/test/integration_testing/distributions/distributions.jl @@ -284,11 +284,11 @@ sr(n::Int) = StableRNG(n) @testset "$(typeof(d))" for (perf_flag, d, x) in logpdf_test_cases @info "$(map(typeof, (d, x)))" - test_rule(StableRNG(123456), logpdf, d, x; perf_flag, is_primitive=false) + test_rule(StableRNG(123546), logpdf, d, x; perf_flag, is_primitive=false) end @testset "$name" for (perf_flag, name, f, x) in work_around_test_cases @info "$name" - test_rule(StableRNG(123456), f, x...; perf_flag=perf_flag, is_primitive=false) + test_rule(StableRNG(123456), f, x...; perf_flag, is_primitive=false) end end diff --git a/test/integration_testing/flux/flux.jl b/test/integration_testing/flux/flux.jl index b7d8176fa6..45f15e3daf 100644 --- a/test/integration_testing/flux/flux.jl +++ b/test/integration_testing/flux/flux.jl @@ -9,7 +9,7 @@ using Bijectors, Flux, Mooncake, StableRNGs # This example below tests a bug found at https://github.com/chalk-lab/Mooncake.jl/issues/661 # -# just define a MLP +# just define an MLP function mlp3( input_dim::Int, hidden_dims::Int, @@ -27,7 +27,7 @@ end inputdim = 4 mask_idx = 1:2:inputdim -# creat a masking layer +# create a masking layer mask = Bijectors.PartitionMask(inputdim, mask_idx) cdim = length(mask_idx) @@ -56,6 +56,7 @@ Mooncake.TestUtils.test_rule( is_primitive=false, interface_only=true, unsafe_perturb=true, + mode=Mooncake.ReverseMode, ) struct ACL @@ -87,4 +88,5 @@ Mooncake.TestUtils.test_rule( is_primitive=false, interface_only=true, unsafe_perturb=true, + mode=Mooncake.ReverseMode, ) diff --git a/test/integration_testing/gp/gp.jl b/test/integration_testing/gp/gp.jl index 01358badf9..70144bdb0d 
100644 --- a/test/integration_testing/gp/gp.jl +++ b/test/integration_testing/gp/gp.jl @@ -41,7 +41,7 @@ using Mooncake.TestUtils: test_rule ], ) fx = GP(k)(x1, 1.1) - @testset "$(typeof(args))" for args in Any[ + @testset "$(typeof(x))" for x in Any[ (kernelmatrix, k, x1, x2), (kernelmatrix_diag, k, x1, x2), (kernelmatrix, k, x1), @@ -49,8 +49,8 @@ using Mooncake.TestUtils: test_rule (fx -> rand(StableRNG(123456), fx), fx), (logpdf, fx, rand(rng, fx)), ] - @info typeof(args) - test_rule(rng, args...; is_primitive=false, unsafe_perturb=true) + @info typeof(x) + test_rule(rng, x...; is_primitive=false, unsafe_perturb=true) end end end diff --git a/test/integration_testing/lux/lux.jl b/test/integration_testing/lux/lux.jl index 06e3543421..27a8b663ca 100644 --- a/test/integration_testing/lux/lux.jl +++ b/test/integration_testing/lux/lux.jl @@ -123,9 +123,10 @@ sr(x) = StableRNG(x) @info "$(typeof((f, x_f32...)))" rng = sr(123546) ps, st = f32(Lux.setup(rng, f)) - x = f32(x_f32) + mode = Mooncake.ReverseMode + fargs = (f, f32(x_f32), ps, st) test_rule( - rng, f, x, ps, st; is_primitive=false, interface_only, unsafe_perturb=true + rng, fargs...; is_primitive=false, interface_only, unsafe_perturb=true, mode ) end end diff --git a/test/integration_testing/temporalgps/temporalgps.jl b/test/integration_testing/temporalgps/temporalgps.jl index a9d7e91224..afc39b3b18 100644 --- a/test/integration_testing/temporalgps/temporalgps.jl +++ b/test/integration_testing/temporalgps/temporalgps.jl @@ -3,7 +3,6 @@ Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using AbstractGPs, KernelFunctions, Mooncake, StableRNGs, TemporalGPs, Test -using Mooncake.TestUtils: test_rule build_gp(k) = to_sde(GP(k), SArrayStorage(Float64)) @@ -23,6 +22,6 @@ temporalgps_logpdf_tester(k, x, y, s) = logpdf(build_gp(k)(x, s), y) f = temporalgps_logpdf_tester sig = typeof((temporalgps_logpdf_tester, k, x, y, s)) @info "$sig" - test_rule(StableRNG(123456), f, k, x, y, s; 
is_primitive=false) + Mooncake.TestUtils.test_rule(StableRNG(123456), f, k, x, y, s; is_primitive=false) end end diff --git a/test/interface.jl b/test/interface.jl index 0dba6466d8..276a589e4e 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -210,4 +210,17 @@ using Mooncake.TestUtils: count_allocs end end end + @testset "forwards mode ($kwargs)" for kwargs in [ + (;), + (; debug_mode=true), + (; debug_mode=false), + (; debug_mode=true, silence_debug_messages=true), + ] + f = (x, y) -> x * y + cos(x) + fx = (f, 5.0, 4.0) + rule = Mooncake.prepare_derivative_cache(fx...; kwargs...) + z = Mooncake.value_and_derivative!!(rule, map(zero_dual, fx)...) + @test z isa Mooncake.Dual + @test primal(z) == f(5.0, 4.0) + end end diff --git a/test/interpreter/abstract_interpretation.jl b/test/interpreter/abstract_interpretation.jl index a007054a2c..f992ec0d70 100644 --- a/test/interpreter/abstract_interpretation.jl +++ b/test/interpreter/abstract_interpretation.jl @@ -1,9 +1,13 @@ a_primitive(x) = sin(x) non_primitive(x) = sin(x) -Mooncake.is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(a_primitive),Any}}) = true function Mooncake.is_primitive( - ::Type{DefaultCtx}, ::Type{<:Tuple{typeof(non_primitive),Any}} + ::Type{DefaultCtx}, ::Type{ReverseMode}, ::Type{<:Tuple{typeof(a_primitive),Any}} +) + return true +end +function Mooncake.is_primitive( + ::Type{DefaultCtx}, ::Type{ReverseMode}, ::Type{<:Tuple{typeof(non_primitive),Any}} ) return false end @@ -28,7 +32,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @assert stmt(usual_ir.stmts)[invoke_line].args[2] == GlobalRef(Main, :sin) # Should continue to inline away under AD compilation. 
- interp = Mooncake.MooncakeInterpreter(DefaultCtx) + interp = Mooncake.MooncakeInterpreter(DefaultCtx, ReverseMode) ad_ir = Base.code_ircode_by_type(sig; interp)[1][1] invoke_line = findfirst(x -> Meta.isexpr(x, :invoke), stmt(ad_ir.stmts)) @test stmt(ad_ir.stmts)[invoke_line].args[2] == GlobalRef(Main, :sin) @@ -45,7 +49,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @assert stmt(usual_ir.stmts)[invoke_line].args[2] == GlobalRef(Main, :sin) # Should not inline away under AD compilation. - interp = Mooncake.MooncakeInterpreter(DefaultCtx) + interp = Mooncake.MooncakeInterpreter(DefaultCtx, ReverseMode) ad_ir = Base.code_ircode_by_type(sig; interp)[1][1] invoke_line = findfirst(x -> Meta.isexpr(x, :invoke), stmt(ad_ir.stmts)) @test stmt(ad_ir.stmts)[invoke_line].args[2] == GlobalRef(Main, :a_primitive) @@ -64,7 +68,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @assert stmt(usual_ir.stmts)[invoke_line].args[2] == GlobalRef(Main, :sin) # Should not inline away under AD compilation. 
- interp = Mooncake.MooncakeInterpreter(DefaultCtx) + interp = Mooncake.MooncakeInterpreter(DefaultCtx, ReverseMode) ad_ir = Base.code_ircode_by_type(sig; interp)[1][1] invoke_line = findfirst(x -> Meta.isexpr(x, :invoke), stmt(ad_ir.stmts)) @test stmt(ad_ir.stmts)[invoke_line].args[2] == GlobalRef(Main, :a_primitive) diff --git a/test/interpreter/contexts.jl b/test/interpreter/contexts.jl index 01f6b69c46..1ba066d3de 100644 --- a/test/interpreter/contexts.jl +++ b/test/interpreter/contexts.jl @@ -9,6 +9,9 @@ foo(x) = x end @testset "contexts" begin - @test Mooncake.is_primitive(DefaultCtx, Tuple{typeof(ContextsTestModule.foo),Float64}) - @test !Mooncake.is_primitive(DefaultCtx, Tuple{typeof(ContextsTestModule.foo),Real}) + @testset "$mode" for mode in [Mooncake.ForwardMode, Mooncake.ReverseMode] + Tf = typeof(ContextsTestModule.foo) + @test Mooncake.is_primitive(DefaultCtx, mode, Tuple{Tf,Float64}) + @test !Mooncake.is_primitive(DefaultCtx, mode, Tuple{Tf,Real}) + end end diff --git a/test/interpreter/forward_mode.jl b/test/interpreter/forward_mode.jl new file mode 100644 index 0000000000..197d585f43 --- /dev/null +++ b/test/interpreter/forward_mode.jl @@ -0,0 +1,32 @@ +function foo(x) + y = 0.0 + try + if x > 0 + error("") + end + y = x + catch + y = 2x + end + return y +end + +@testset "s2s_forward_mode_ad" begin + test_cases = collect(enumerate(TestResources.generate_test_functions())) + @testset "$n - $(_typeof((fx)))" for (n, (int_only, pf, _, fx...)) in test_cases + @info "$n: $(_typeof(fx))" + rng = Xoshiro(123546) + mode = ForwardMode + TestUtils.test_rule( + rng, fx...; perf_flag=pf, interface_only=int_only, is_primitive=false, mode + ) + end + + # Try try-catch statements. 
+ rng = StableRNG(123) + perf_flag = :none + interface_only = false + is_primitive = false + mode = ForwardMode + TestUtils.test_rule(rng, foo, 5.0; perf_flag, interface_only, is_primitive, mode) +end; diff --git a/test/interpreter/ir_utils.jl b/test/interpreter/ir_utils.jl index 5bbc382ebb..39a812c4fb 100644 --- a/test/interpreter/ir_utils.jl +++ b/test/interpreter/ir_utils.jl @@ -57,4 +57,13 @@ end Mooncake.replace_uses_with!(stmt, SSAValue(1), 5.0) @test stmt.args[end] == 5.0 end + @testset "characeterise_used_ssas" begin + stmts = Any[ + Expr(:call, sin, Argument(1)), + Expr(:call, sin, SSAValue(1)), + Expr(:call, sin, SSAValue(1)), + ReturnNode(SSAValue(3)), + ] + @test Mooncake.characterised_used_ssas(stmts) == [true, false, true, false] + end end diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/reverse_mode.jl similarity index 98% rename from test/interpreter/s2s_reverse_mode_ad.jl rename to test/interpreter/reverse_mode.jl index 569152f73e..8b930f2a4b 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/reverse_mode.jl @@ -35,7 +35,7 @@ end is_used_dict = Dict{ID,Bool}(id_ssa_1 => true, id_ssa_2 => true) rdata_ref = Ref{Tuple{map(Mooncake.lazy_zero_rdata_type, (Float64, Int))...}}() info = ADInfo( - get_interpreter(), + get_interpreter(ReverseMode), arg_types, ssa_insts, is_used_dict, @@ -90,7 +90,7 @@ end id_line_1 = ID() id_line_2 = ID() info = ADInfo( - get_interpreter(), + get_interpreter(ReverseMode), Dict{Argument,Any}(Argument(1) => typeof(sin), Argument(2) => Float64), Dict{ID,CC.NewInstruction}( id_line_1 => new_inst(Expr(:invoke, nothing, cos, Argument(2)), Float64), @@ -277,7 +277,7 @@ end ], debug_mode in [true, false] - interp = get_interpreter() + interp = get_interpreter(ReverseMode) rule = Mooncake.build_rrule(interp, sig; debug_mode) @test rule isa Mooncake.rule_type(interp, sig; debug_mode) end @@ -290,8 +290,9 @@ end ) sig = _typeof((f, x...)) @info "$n: $sig" + mode = ReverseMode 
TestUtils.test_rule( - Xoshiro(123456), f, x...; perf_flag, interface_only, is_primitive=false + Xoshiro(123456), f, x...; perf_flag, interface_only, is_primitive=false, mode ) # TestUtils.test_rule( # Xoshiro(123456), @@ -303,7 +304,7 @@ end # debug_mode=true, # ) - # interp = Mooncake.get_interpreter() + # interp = Mooncake.get_interpreter(ReverseMode) # codual_args = map(zero_codual, (f, x...)) # fwds_args = map(Mooncake.to_fwds, codual_args) # rule = Mooncake.build_rrule(interp, sig) @@ -344,6 +345,7 @@ end ones(3); interface_only=false, is_primitive=false, + mode=Mooncake.ReverseMode, ) # BenchmarkTools not working due to world age problems. Provided that this code diff --git a/test/rrules/array_legacy.jl b/test/rrules/array_legacy.jl index b97c40598a..50209b25a6 100644 --- a/test/rrules/array_legacy.jl +++ b/test/rrules/array_legacy.jl @@ -1,3 +1,3 @@ @testset "array_legacy" begin - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:array_legacy)) + TestUtils.run_rule_test_cases(StableRNG, Val(:array_legacy)) end diff --git a/test/rrules/avoiding_non_differentiable_code.jl b/test/rrules/avoiding_non_differentiable_code.jl index bd73b570dd..d69e9372f6 100644 --- a/test/rrules/avoiding_non_differentiable_code.jl +++ b/test/rrules/avoiding_non_differentiable_code.jl @@ -1,3 +1,3 @@ @testset "avoiding_non_differentiable_code" begin - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:avoiding_non_differentiable_code)) + TestUtils.run_rule_test_cases(StableRNG, Val(:avoiding_non_differentiable_code)) end diff --git a/test/rrules/blas.jl b/test/rrules/blas.jl index 50cf1b15b6..09629a18e4 100644 --- a/test/rrules/blas.jl +++ b/test/rrules/blas.jl @@ -1,10 +1,4 @@ @testset "blas" begin - @test_throws ErrorException Mooncake.arrayify(5, 4) - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:blas)) - - @testset "mixed complex-real" begin - TestUtils.test_rule( - StableRNG(123456), x -> sum(complex(x) * x), rand(5, 5); is_primitive=false - ) - end + @test_throws "Encountered 
unexpected array type" Mooncake.arrayify(5, 4) + TestUtils.run_rule_test_cases(StableRNG, Val(:blas)) end diff --git a/test/rrules/builtins.jl b/test/rrules/builtins.jl index 60fa1e0ec3..216bb7e5a3 100644 --- a/test/rrules/builtins.jl +++ b/test/rrules/builtins.jl @@ -29,7 +29,7 @@ foo_throws(e) = throw(e) @test (@allocations Mooncake.is_homogeneous_and_immutable(x)) == 0 end - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:builtins)) + TestUtils.run_rule_test_cases(StableRNG, Val(:builtins)) # Unhandled built-in throws an intelligible error. @test_throws( diff --git a/test/rrules/fastmath.jl b/test/rrules/fastmath.jl index 2c8d8f82dc..9c2c3e77f5 100644 --- a/test/rrules/fastmath.jl +++ b/test/rrules/fastmath.jl @@ -1,3 +1,3 @@ @testset "fastmath" begin - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:fastmath)) + TestUtils.run_rule_test_cases(StableRNG, Val(:fastmath)) end diff --git a/test/rrules/foreigncall.jl b/test/rrules/foreigncall.jl index 73ad7e499a..a11776cd94 100644 --- a/test/rrules/foreigncall.jl +++ b/test/rrules/foreigncall.jl @@ -1,5 +1,5 @@ @testset "foreigncall" begin - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:foreigncall)) + TestUtils.run_rule_test_cases(StableRNG, Val(:foreigncall)) @testset "foreigncalls that should never be hit: $name" for name in [ :jl_alloc_array_1d, @@ -25,6 +25,10 @@ :memhash32_seed, :jl_get_field_offset, ] + @test_throws( + ErrorException, + Mooncake.frule!!(zero_dual(Mooncake._foreigncall_), zero_dual(Val(name))), + ) @test_throws( ErrorException, Mooncake.rrule!!(zero_codual(Mooncake._foreigncall_), zero_codual(Val(name))), diff --git a/test/rrules/function_wrappers.jl b/test/rrules/function_wrappers.jl index 27dfbcd9a0..2f621ad646 100644 --- a/test/rrules/function_wrappers.jl +++ b/test/rrules/function_wrappers.jl @@ -12,5 +12,5 @@ t = zero_tangent(p) @test Mooncake.to_cr_tangent(t) === t end - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:function_wrappers)) + TestUtils.run_rule_test_cases(StableRNG, 
Val(:function_wrappers), ReverseMode) end diff --git a/test/rrules/iddict.jl b/test/rrules/iddict.jl index d67de79a9e..1268a3e47f 100644 --- a/test/rrules/iddict.jl +++ b/test/rrules/iddict.jl @@ -5,5 +5,5 @@ TestUtils.test_tangent(sr(123456), p, T; interface_only=false, perf=false) TestUtils.test_tangent_splitting(sr(123456), p) end - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:iddict)) + TestUtils.run_rule_test_cases(StableRNG, Val(:iddict)) end diff --git a/test/rrules/lapack.jl b/test/rrules/lapack.jl index 6f45f41f71..dc8f8d1cef 100644 --- a/test/rrules/lapack.jl +++ b/test/rrules/lapack.jl @@ -1,3 +1,3 @@ @testset "lapack" begin - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:lapack)) + TestUtils.run_rule_test_cases(StableRNG, Val(:lapack)) end diff --git a/test/rrules/linear_algebra.jl b/test/rrules/linear_algebra.jl index f7e1ac0ed3..b479e989e4 100644 --- a/test/rrules/linear_algebra.jl +++ b/test/rrules/linear_algebra.jl @@ -1,3 +1,3 @@ @testset "linear_algebra" begin - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:linear_algebra)) + TestUtils.run_rule_test_cases(StableRNG, Val(:linear_algebra)) end diff --git a/test/rrules/low_level_maths.jl b/test/rrules/low_level_maths.jl index 1b5a13f95e..63b52e5dd8 100644 --- a/test/rrules/low_level_maths.jl +++ b/test/rrules/low_level_maths.jl @@ -1,20 +1,23 @@ @testset "low_level_maths" begin - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:low_level_maths)) + TestUtils.run_rule_test_cases(StableRNG, Val(:low_level_maths)) # These are all examples of signatures which we do _not_ want to make primitives, # because they are very shallow wrappers around lower-level primitives for which we # already have rules. 
- @testset "$T, $C" for T in [Float16, Float32, Float64], C in [DefaultCtx, MinimalCtx] - @test !is_primitive(C, Tuple{typeof(+),T}) - @test !is_primitive(C, Tuple{typeof(-),T}) - @test !is_primitive(C, Tuple{typeof(abs2),T}) - @test !is_primitive(C, Tuple{typeof(inv),T}) - @test !is_primitive(C, Tuple{typeof(abs),T}) + @testset "$T, $C, $M" for T in [Float16, Float32, Float64], + C in [DefaultCtx, MinimalCtx], + M in [ForwardMode, ReverseMode] - @test !is_primitive(C, Tuple{typeof(+),T,T}) - @test !is_primitive(C, Tuple{typeof(-),T,T}) - @test !is_primitive(C, Tuple{typeof(*),T,T}) - @test !is_primitive(C, Tuple{typeof(/),T,T}) - @test !is_primitive(C, Tuple{typeof(\),T,T}) + @test !is_primitive(C, M, Tuple{typeof(+),T}) + @test !is_primitive(C, M, Tuple{typeof(-),T}) + @test !is_primitive(C, M, Tuple{typeof(abs2),T}) + @test !is_primitive(C, M, Tuple{typeof(inv),T}) + @test !is_primitive(C, M, Tuple{typeof(abs),T}) + + @test !is_primitive(C, M, Tuple{typeof(+),T,T}) + @test !is_primitive(C, M, Tuple{typeof(-),T,T}) + @test !is_primitive(C, M, Tuple{typeof(*),T,T}) + @test !is_primitive(C, M, Tuple{typeof(/),T,T}) + @test !is_primitive(C, M, Tuple{typeof(\),T,T}) end end diff --git a/test/rrules/memory.jl b/test/rrules/memory.jl index 0fd9807bc1..946278d7ef 100644 --- a/test/rrules/memory.jl +++ b/test/rrules/memory.jl @@ -8,7 +8,7 @@ end @testset "$(typeof(p))" for p in generate_data_test_cases(StableRNG, Val(:memory)) TestUtils.test_data(sr(123), p) end - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:memory)) + TestUtils.run_rule_test_cases(StableRNG, Val(:memory)) # Check that the rule for `Memory{P}` only produces two allocations. 
generate_mem() diff --git a/test/rrules/misc.jl b/test/rrules/misc.jl index 8f2c4a88ae..ca7a57ce68 100644 --- a/test/rrules/misc.jl +++ b/test/rrules/misc.jl @@ -1,12 +1,4 @@ @testset "misc" begin - @testset "misc utility" begin - x = randn(4, 5) - p = Base.unsafe_convert(Ptr{Float64}, x) - @test Mooncake.wrap_ptr_as_view(p, 4, 4, 5) == x - @test Mooncake.wrap_ptr_as_view(p, 4, 2, 5) == x[1:2, :] - @test Mooncake.wrap_ptr_as_view(p, 4, 2, 3) == x[1:2, 1:3] - end - @testset "lgetfield" begin x = (5.0, 4) @test lgetfield(x, Val(1)) == getfield(x, 1) @@ -26,5 +18,5 @@ @test x.b === new_b end - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:misc)) + TestUtils.run_rule_test_cases(StableRNG, Val(:misc)) end diff --git a/test/rrules/misty_closures.jl b/test/rrules/misty_closures.jl new file mode 100644 index 0000000000..c715da5bb9 --- /dev/null +++ b/test/rrules/misty_closures.jl @@ -0,0 +1,104 @@ +mc_foo(x) = 2x + +run_misty_closure(mc::Mooncake.MistyClosure, x::Float64) = mc(x) + +struct Foo + y::Float64 +end + +(f::Foo)(x) = getfield(f, 1) + x + +# Test cases for second derivative computation. +quadratic(x) = x^3 +function low_level_gradient(rrule, f, x::Float64) + _, pb!! = rrule(zero_fcodual(f), zero_fcodual(x)) + return pb!!(1.0)[2] +end + +@testset "misty_closures" begin + + # Construct a sample MistyClosure. + ir = Base.code_ircode_by_type(Tuple{typeof(mc_foo),Float64})[1][1] + ir.argtypes[1] = Any + mc = Mooncake.MistyClosure(ir) + + @testset "tangent interface etc" begin + rng = StableRNG(123456) + TestUtils.test_tangent_interface(rng, mc) + # Do not run the `test_rule_and_type_interactions` test suite for + # `MistyClosure`s as we do not implement rules for `getfield` / `_new_`. 
+ end + + TestUtils.test_rule( + StableRNG(123), + mc, + 5.0; + interface_only=false, + is_primitive=true, + perf_flag=:none, + unsafe_perturb=true, + mode=ForwardMode, + ) + TestUtils.test_rule( + StableRNG(123), + run_misty_closure, + mc, + 5.0; + interface_only=false, + is_primitive=false, + perf_flag=:none, + unsafe_perturb=true, + mode=ForwardMode, + ) + + # Construct a MistyClosure which accesses its captures. We achieve this by collecting + # the IR associated to a callable type, and manipulating the types of various fields to + # ensure that the MistyClosure produced using its `IRCode` is valid. + ir = Base.code_ircode_by_type(Tuple{Foo,Float64})[1][1] + ir.argtypes[1] = Tuple{Float64} + mc2 = Mooncake.MistyClosure(ir, 5.0) + @test mc2(4.0) == 9.0 + TestUtils.test_rule( + StableRNG(123), + mc2, + 4.0; + interface_only=false, + is_primitive=true, + perf_flag=:none, + unsafe_perturb=true, + mode=ForwardMode, + ) + + # Construct a callable which performs reverse-mode, and apply forwards-mode over it. + rule = Mooncake.build_rrule(Tuple{typeof(quadratic),Float64}) + TestUtils.test_rule( + StableRNG(123), + low_level_gradient, + rule, + quadratic, + 5.0; + interface_only=true, + is_primitive=false, + perf_flag=:none, + unsafe_perturb=true, + mode=ForwardMode, + ) + + # Check that the rrule for evaluating `MistyClosure`s errors appropriately. + args = (low_level_gradient, rule, quadratic, 5.0) + higher_rule = Mooncake.build_rrule(args...) + @test_throws(ArgumentError, higher_rule(map(zero_fcodual, args)...)) + + # Manually test that this correctly computes the second derivative. 
+ frule = Mooncake.build_frule( + Mooncake.get_interpreter(Mooncake.ForwardMode), + Tuple{typeof(low_level_gradient),typeof(rule),typeof(quadratic),Float64}, + ) + result = frule( + zero_dual(low_level_gradient), + zero_dual(rule), + zero_dual(quadratic), + Mooncake.Dual(5.0, 1.0), + ) + @test tangent(result) == 6 * 5.0 +end diff --git a/test/rrules/new.jl b/test/rrules/new.jl index 8de3ef5882..79e12f7174 100644 --- a/test/rrules/new.jl +++ b/test/rrules/new.jl @@ -1,4 +1,4 @@ @testset "new" begin - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:new)) + TestUtils.run_rule_test_cases(StableRNG, Val(:new)) include("build_fdata_world_age_regression.jl") end diff --git a/test/rrules/performance_patches.jl b/test/rrules/performance_patches.jl index 5d6f61ae88..38758f5e5a 100644 --- a/test/rrules/performance_patches.jl +++ b/test/rrules/performance_patches.jl @@ -1,3 +1,3 @@ @testset "performance_patches" begin - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:performance_patches)) + TestUtils.run_rule_test_cases(StableRNG, Val(:performance_patches)) end diff --git a/test/rrules/random.jl b/test/rrules/random.jl index 7b1783328c..af8f023202 100644 --- a/test/rrules/random.jl +++ b/test/rrules/random.jl @@ -1,3 +1,3 @@ @testset "randn" begin - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:random)) + TestUtils.run_rule_test_cases(StableRNG, Val(:random)) end diff --git a/test/rrules/tasks.jl b/test/rrules/tasks.jl index 38b6f38383..5d082cbd28 100644 --- a/test/rrules/tasks.jl +++ b/test/rrules/tasks.jl @@ -4,5 +4,5 @@ T = Mooncake.TaskTangent TestUtils.test_tangent(sr(123456), p, T; interface_only=false, perf=false) end - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:tasks)) + TestUtils.run_rule_test_cases(StableRNG, Val(:tasks)) end diff --git a/test/rrules/twice_precision.jl b/test/rrules/twice_precision.jl index 824c1a5733..671eb5a897 100644 --- a/test/rrules/twice_precision.jl +++ b/test/rrules/twice_precision.jl @@ -3,5 +3,5 @@ p = 
Base.TwicePrecision{Float64}(5.0, 4.0) TestUtils.test_tangent_interface(rng, p) TestUtils.test_tangent_splitting(rng, p) - TestUtils.run_rrule!!_test_cases(StableRNG, Val(:twice_precision)) + TestUtils.run_rule_test_cases(StableRNG, Val(:twice_precision)) end diff --git a/test/runtests.jl b/test/runtests.jl index 48431136d7..cf544605a8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,7 +18,8 @@ include("front_matter.jl") include(joinpath("interpreter", "bbcode.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) include(joinpath("interpreter", "zero_like_rdata.jl")) - include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) + include(joinpath("interpreter", "forward_mode.jl")) + include(joinpath("interpreter", "reverse_mode.jl")) end include("tools_for_rules.jl") include("interface.jl") @@ -49,6 +50,8 @@ include("front_matter.jl") include(joinpath("rrules", "low_level_maths.jl")) elseif test_group == "rrules/misc" include(joinpath("rrules", "misc.jl")) + elseif test_group == "rrules/misty_closures" + include(joinpath("rrules", "misty_closures.jl")) elseif test_group == "rrules/new" include(joinpath("rrules", "new.jl")) elseif test_group == "rrules/random" diff --git a/test/tools_for_rules.jl b/test/tools_for_rules.jl index 9cd066efe5..1229c07735 100644 --- a/test/tools_for_rules.jl +++ b/test/tools_for_rules.jl @@ -4,51 +4,73 @@ module ToolsForRulesResources # correctly if `Mooncake` is not in scope. 
using ChainRulesCore, LinearAlgebra using Base: IEEEFloat -using Mooncake: @mooncake_overlay, @zero_adjoint, @from_rrule, MinimalCtx, DefaultCtx +using Mooncake: + @mooncake_overlay, + @zero_derivative, + @from_rrule, + MinimalCtx, + DefaultCtx, + @from_chainrules, + ForwardMode, + ReverseMode + +const CRC = ChainRulesCore local_function(x) = 3x overlay_tester(x) = 2x @mooncake_overlay overlay_tester(x) = local_function(x) zero_tester(x) = 0 -@zero_adjoint MinimalCtx Tuple{typeof(zero_tester),Float64} +@zero_derivative MinimalCtx Tuple{typeof(zero_tester),Float64} vararg_zero_tester(x...) = 0 -@zero_adjoint MinimalCtx Tuple{typeof(vararg_zero_tester),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(vararg_zero_tester),Vararg} + +zero_tester_forward_only(x) = 0 +@zero_derivative MinimalCtx Tuple{typeof(zero_tester_forward_only),Float64} ForwardMode + +zero_tester_reverse_only(x) = 0 +@zero_derivative MinimalCtx Tuple{typeof(zero_tester_reverse_only),Float64} ReverseMode # Test case with isbits data. bleh(x::Float64, y::Int) = x * y -function ChainRulesCore.rrule(::typeof(bleh), x::Float64, y::Int) - return x * y, dz -> (ChainRulesCore.NoTangent(), dz * y, ChainRulesCore.NoTangent()) +CRC.frule((_, dx, _), ::typeof(bleh), x::Float64, y::Int) = x * y, dx * y + +function CRC.rrule(::typeof(bleh), x::Float64, y::Int) + return x * y, dz -> (CRC.NoTangent(), dz * y, CRC.NoTangent()) end -@from_rrule DefaultCtx Tuple{typeof(bleh),Float64,Int} false +@from_chainrules DefaultCtx Tuple{typeof(bleh),Float64,Int} false # Test case with heap-allocated input. 
test_sum(x) = sum(x) -function ChainRulesCore.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) - test_sum_pb(dy::Real) = ChainRulesCore.NoTangent(), fill(dy, size(x)) +CRC.frule((_, dx), ::typeof(test_sum), x::AbstractArray{<:Real}) = sum(x), sum(dx) + +function CRC.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) + test_sum_pb(dy::Real) = CRC.NoTangent(), fill(dy, size(x)) return test_sum(x), test_sum_pb end -@from_rrule DefaultCtx Tuple{typeof(test_sum),Array{<:Base.IEEEFloat}} false +@from_chainrules DefaultCtx Tuple{typeof(test_sum),Array{<:Base.IEEEFloat}} false # Test case with heap-allocated output. test_scale(x::Real, y::AbstractVector{<:Real}) = x * y -function ChainRulesCore.rrule(::typeof(test_scale), x::Real, y::AbstractVector{<:Real}) - function test_scale_pb(dout::AbstractVector{<:Real}) - return ChainRulesCore.NoTangent(), dot(dout, y), dout * x - end +function CRC.frule((_, dx, dy), ::typeof(test_scale), x::Real, y::AbstractVector{<:Real}) + return x * y, dx * y + x * dy +end + +function CRC.rrule(::typeof(test_scale), x::Real, y::AbstractVector{<:Real}) + test_scale_pb(dout::AbstractVector{<:Real}) = CRC.NoTangent(), dot(dout, y), dout * x return x * y, test_scale_pb end -@from_rrule( +@from_chainrules( DefaultCtx, Tuple{typeof(test_scale),Base.IEEEFloat,Vector{<:Base.IEEEFloat}}, false ) @@ -56,12 +78,14 @@ end test_nothing() = nothing -function ChainRulesCore.rrule(::typeof(test_nothing)) - test_nothing_pb(::ChainRulesCore.NoTangent) = (ChainRulesCore.NoTangent(),) +CRC.frule(_, ::typeof(test_nothing)) = (nothing, CRC.NoTangent()) + +function CRC.rrule(::typeof(test_nothing)) + test_nothing_pb(::CRC.NoTangent) = (CRC.NoTangent(),) return nothing, test_nothing_pb end -@from_rrule DefaultCtx Tuple{typeof(test_nothing)} false +@from_chainrules DefaultCtx Tuple{typeof(test_nothing)} false # Test case in which ChainRulesCore returns a tangent which is of the "wrong" type from the # perspective of Mooncake.jl. 
In this instance, some kind of error should be thrown, rather @@ -69,8 +93,8 @@ end test_bad_rdata(x::Real) = 5x -function ChainRulesCore.rrule(::typeof(test_bad_rdata), x::Float64) - test_bad_rdata_pb(dy::Float64) = ChainRulesCore.NoTangent(), Float32(dy * 5) +function CRC.rrule(::typeof(test_bad_rdata), x::Float64) + test_bad_rdata_pb(dy::Float64) = CRC.NoTangent(), Float32(dy * 5) return 5x, test_bad_rdata_pb end @@ -78,8 +102,8 @@ end # Test case for rule with diagonal dispatch. test_add(x, y) = x + y -function ChainRulesCore.rrule(::typeof(test_add), x, y) - test_add_pb(dout) = ChainRulesCore.NoTangent(), dout, dout +function CRC.rrule(::typeof(test_add), x, y) + test_add_pb(dout) = CRC.NoTangent(), dout, dout return x + y, test_add_pb end @from_rrule DefaultCtx Tuple{typeof(test_add),T,T} where {T<:IEEEFloat} false @@ -87,22 +111,30 @@ end # Test case for rule with non-differentiable kwargs. test_kwargs(x; y::Bool=false) = y ? x : 2x -function ChainRulesCore.rrule(::typeof(test_kwargs), x::Float64; y::Bool=false) - test_kwargs_pb(dz::Float64) = ChainRulesCore.NoTangent(), y ? dz : 2dz +function CRC.frule((_, dx), ::typeof(test_kwargs), x::Float64; y::Bool=false) + return test_kwargs(x; y), y ? dx : 2dx +end + +function CRC.rrule(::typeof(test_kwargs), x::Float64; y::Bool=false) + test_kwargs_pb(dz::Float64) = CRC.NoTangent(), y ? dz : 2dz return y ? x : 2x, test_kwargs_pb end -@from_rrule(DefaultCtx, Tuple{typeof(test_kwargs),Float64}, true) +@from_chainrules(DefaultCtx, Tuple{typeof(test_kwargs),Float64}, true) # Test case for rule with differentiable types used in a non-differentiable way. test_kwargs_conditional(x; y::Float64=1.0) = y > 0 ? x : 2x -function ChainRulesCore.rrule(::typeof(test_kwargs_conditional), x::Float64; y::Float64=1.0) - test_kwargs_cond_pb(dz::Float64) = ChainRulesCore.NoTangent(), y > 0 ? 
dz : 2dz +function CRC.frule((_, dx), ::typeof(test_kwargs_conditional), x::Float64; y::Float64=1.0) + return test_kwargs_conditional(x; y), y > 0 ? dx : 2dx +end + +function CRC.rrule(::typeof(test_kwargs_conditional), x::Float64; y::Float64=1.0) + test_kwargs_cond_pb(dz::Float64) = CRC.NoTangent(), y > 0 ? dz : 2dz return y > 0 ? x : 2x, test_kwargs_cond_pb end -@from_rrule(DefaultCtx, Tuple{typeof(test_kwargs_conditional),Float64}, true) +@from_chainrules(DefaultCtx, Tuple{typeof(test_kwargs_conditional),Float64}, true) end @@ -112,8 +144,7 @@ end rule = Mooncake.build_rrule(Tuple{typeof(f),Float64}) @test value_and_gradient!!(rule, f, 5.0) == (15.0, (NoTangent(), 3.0)) end - @testset "zero_adjoint" begin - f_zero = ToolsForRulesResources + @testset "zero_derivative" begin test_rule( sr(123), ToolsForRulesResources.zero_tester, @@ -129,6 +160,22 @@ end is_primitive=true, perf_flag=:stability_and_allocs, ) + + perf_flag = :stability_and_allocs + @testset "forward mode only" begin + sig = Tuple{typeof(ToolsForRulesResources.zero_tester_forward_only),Float64} + @test is_primitive(MinimalCtx, ForwardMode, sig) + @test !is_primitive(MinimalCtx, ReverseMode, sig) + args = (ToolsForRulesResources.zero_tester_forward_only, 5.0) + test_rule(sr(123), args...; is_primitive=true, perf_flag, mode=ForwardMode) + end + @testset "reverse mode only" begin + sig = Tuple{typeof(ToolsForRulesResources.zero_tester_reverse_only),Float64} + @test !is_primitive(MinimalCtx, ForwardMode, sig) + @test is_primitive(MinimalCtx, ReverseMode, sig) + args = (ToolsForRulesResources.zero_tester_reverse_only, 5.0) + test_rule(sr(123), args...; is_primitive=true, perf_flag, mode=ReverseMode) + end end @testset "chain_rules_macro" begin @testset "to_cr_tangent" for (t, t_cr) in Any[ @@ -148,6 +195,15 @@ end ] @test Mooncake.to_cr_tangent(t) == t_cr end + @testset "mooncake_tangent($(typeof(p)), $(typeof(t)))" for (p, t) in Any[ + (5, ChainRulesCore.NoTangent()), + (5.0, 4.0), + (randn(5), 
randn(5)), + ([randn(5)], [randn(5)]), + ((5.0, 4), (4.0, ChainRulesCore.NoTangent())), + ] + @test Mooncake.mooncake_tangent(p, t) isa tangent_type(typeof(p)) + end @testset "rules: $(typeof(fargs))" for fargs in Any[ (ToolsForRulesResources.bleh, 5.0, 4), (ToolsForRulesResources.test_sum, ones(5)), @@ -160,7 +216,10 @@ end (Core.kwcall, (y=1.0,), ToolsForRulesResources.test_kwargs_conditional, 5.0), (ToolsForRulesResources.test_kwargs_conditional, 5.0), ] - test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) + test_rule(sr(1), fargs...; perf_flag=:none, is_primitive=true, mode=ForwardMode) + test_rule( + sr(1), fargs...; perf_flag=:stability, is_primitive=true, mode=ReverseMode + ) end @testset "bad rdata" begin f = ToolsForRulesResources.test_bad_rdata