diff --git a/Project.toml b/Project.toml index cdb08f1db8..e28c7e8c70 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -71,7 +70,6 @@ ExprTools = "0.1" Flux = "0.16.3" FunctionWrappers = "1.1.3" GPUArraysCore = "0.1, 0.2" -Graphs = "1" JET = "0.9, 0.10, 0.11" LinearAlgebra = "1" LogExpFunctions = "0.3" diff --git a/docs/make.jl b/docs/make.jl index 0700e66885..fea767b3fa 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -40,6 +40,7 @@ makedocs(; joinpath("understanding_mooncake", "introduction.md"), joinpath("understanding_mooncake", "algorithmic_differentiation.md"), joinpath("understanding_mooncake", "rule_system.md"), + joinpath("understanding_mooncake", "what_programme_are_you_differentiating.md"), ], "Utilities" => [ joinpath("utilities", "defining_rules.md"), diff --git a/docs/src/assets/computation_graph.png b/docs/src/assets/computation_graph.png new file mode 100644 index 0000000000..ceb449e8c1 Binary files /dev/null and b/docs/src/assets/computation_graph.png differ diff --git a/docs/src/developer_documentation/advanced_debugging.md b/docs/src/developer_documentation/advanced_debugging.md index 60805736da..3cf52bc940 100644 --- a/docs/src/developer_documentation/advanced_debugging.md +++ b/docs/src/developer_documentation/advanced_debugging.md @@ -44,15 +44,15 @@ show_world_info(ins) ### Reverse mode stages -`:raw` → `:normalized` → `:bbcode` → `:fwd_ir` / `:rvs_ir` → `:optimized_fwd` / `:optimized_rvs` +`:raw` → `:normalized` → `:cfg_blocks` → `:fwd_ir` / `:rvs_ir` → `:optimized_fwd` / `:optimized_rvs` ### Forward mode stages -`:raw` → `:normalized` → `:dual_ir` → `:optimized` +`:raw` → `:normalized` → `:cfg_blocks` → `:dual_ir` → `:optimized` !!! note - The inspection tool also shows a `:bbcode` stage for cross-mode comparison, - but forward mode does not use BBCode internally. + The inspection tool also shows a `:cfg_blocks` stage for cross-mode comparison, + but forward mode does not use `CFGBlock`s internally. !!! note Primitive signatures such as `sin` do not generate AD IR stages here. Mooncake diff --git a/docs/src/developer_documentation/forwards_mode_design.md b/docs/src/developer_documentation/forwards_mode_design.md index 75aad66f9a..9110594bad 100644 --- a/docs/src/developer_documentation/forwards_mode_design.md +++ b/docs/src/developer_documentation/forwards_mode_design.md @@ -334,7 +334,7 @@ Then, instead of propagating a "primal-tangent" pairs via `Dual`s, we propagate The implementation of forwards-mode AD is quite dramatically simpler than that of reverse-mode AD. Some notable technical differences include: 1. forwards-mode AD only makes use of the tangent system, whereas reverse-mode also makes use of the fdata / rdata system. -1. forwards-mode AD comprises only line-by-line transformations of the `IRCode`. In particular, it does not require the insertion of additional basic blocks, nor the modification of the successors / predecessors of any given basic block. Consequently, there is no need to make use of the `BBCode` infrastructure built up for reverse-mode AD -- everything can be straightforwardly done at the `Compiler.IRCode` level. +1. forwards-mode AD comprises only line-by-line transformations of the `IRCode`. In particular, it does not require the insertion of additional basic blocks, nor the modification of the successors / predecessors of any given basic block. Consequently, there is no need to make use of the builder-local CFG machinery used in reverse mode -- everything can be straightforwardly done at the `Compiler.IRCode` level. ## Comparison with ForwardDiff.jl diff --git a/docs/src/developer_documentation/ir_representation.md b/docs/src/developer_documentation/ir_representation.md index f69d6326ca..42d1516045 100644 --- a/docs/src/developer_documentation/ir_representation.md +++ b/docs/src/developer_documentation/ir_representation.md @@ -1,17 +1,25 @@ # IR Representations and Code Transformations -Mooncake.jl works by transforming Julia's SSA-form (static single assignment) Intermediate Representation (IR), so a good understanding of Julia's IR is needed to understand Mooncake. -Furthermore, Mooncake holds Julia's IR in a different data structure than the one usually used when producing code for reverse-mode AD. -We discuss both data structures below, and provide examples of the kinds of transformations which must be applied to Julia's IR in order to implement AD, contrasting the two different data structures. +Mooncake works by transforming Julia's SSA-form Intermediate Representation (IR), so a good +working model of that IR is useful when touching the interpreter. -Please note that Julia's SSA-form IR typically changes representation slightly between minor versions of Julia, as it's not part of the public interface of the language. -The information below is accurate on version 1.11.4, but you might well find that things are slightly different on different versions. +Please note that Julia's SSA-form IR changes slightly across minor versions, because it is +not a public language interface. The examples below are representative rather than +version-stable. -## Julia's SSA-form IR +Before looking at the printed IR, keep three ideas in mind: + +1. Each SSA statement produces one named value such as `%1` or `%2`. +2. Control flow is organized into basic blocks with branch or return terminators. +3. The compiler stores the statements and the control-flow graph separately, and Mooncake has + to keep those two views coherent when transforming code. + +## Julia's SSA-Form IR ### Straight-Line Code You can find the IR associated to a given signature using `Base.code_ircode_by_type`: + ```jldoctest julia> function foo(x) y = sin(x) @@ -29,25 +37,18 @@ julia> Base.code_ircode_by_type(signature)[1][1] 4 └── return %2 ``` -What you can see here is that the calls to `sin` and `cos` in the original function are associated to a number, denoted `%1` and `%2`. We refers to these as the "ssa"s associated to each statement. -Each statement is associated to a single ssa, and this association is determined by where it appears in the list of statements -- the first statement is associated to `%1`, the second to `%2`, and so on. -You will also notice that the argument `x` has been replaced with a `_2` in the first statement -- in general, all uses of the `n`th argument are indicated by `_n` (the first argument is the function itself). -The final statement requires no explanation. - -Note that this IR is obtained after both type inference and various Julia-level optimisation passes. -This means that the type information is available for each statement. -For example, the `::Float64` at the end of the first and second statements indicates that the type of `%1` and `%2` is always `Float64`. -The types are also displayed at uses -- the call to `sin` involves `_2::Float64`, not just `_2`. +The statements are associated to SSA names such as `%1` and `%2`. Each statement is +associated to a single SSA value, and uses of arguments are written as `_n`, where `_1` is +the function itself. -Additionally notice that the statements are `invoke` statements, rather than just call statements. -In Julia's IR, an `invoke` statement represents static dispatch to a particular `MethodInstance` -- i.e. running type inference + optimisation passes has determined enough about the argument types to make it possible to know exactly which `MethodInstance` of `sin` and `cos` to call. -This is a very common occurrence in type-stable code. +This IR is obtained after type inference and some optimisation passes, so each statement +already carries type information. In the example above, `%1` and `%2` are both known to be +`Float64`. ### Control Flow -The above is straight-line code -- it does not involve any control flow. -Julia has several statements which are involved in handling control flow. -For example +Control flow is expressed via basic blocks and terminators: + ```jldoctest bar julia> function bar(x) if x > 0 @@ -58,7 +59,7 @@ julia> function bar(x) end bar (generic function with 1 method) -julia> Base.code_ircode_by_type(Tuple{typeof(bar),Float64})[1][1] +julia> Base.code_ircode_by_type(Tuple{typeof(bar), Float64})[1][1] 2 1 ─ %1 = intrinsic Base.lt_float(0.0, _2)::Bool │ %2 = intrinsic Base.or_int(%1, false)::Bool └── goto #3 if not %2 @@ -66,46 +67,24 @@ julia> Base.code_ircode_by_type(Tuple{typeof(bar),Float64})[1][1] 5 3 ─ %5 = intrinsic Base.mul_float(5.0, _2)::Float64 └── return %5 ``` -In this example we see the statement `goto #3 if not %2`. -This should be read as "jump to basic block 3 if %2 is `false`". -The second half of that statement should be clear, but to understand the first half requires knowing what a basic block is: -```julia - 1 ─ - │ - └── - 2 ─ - 3 ─ - └── -``` -Here, everything is removed from the above example except for information about the basic block structure. -To first approximation, each basic block is a sequence of statements which _must_ always execute one after the other. -Once all statements in a basic block have run, we typically either jump to another basic block, or hit a `return` statement. -In this example, we have three basic blocks -- you can see this from the numbers `1`, `2`, and `3`. -The first basic block comprises three statements, the second only one statement, and the third two statements. -Another way to investigate this structure is to look at the control-flow graph associated to the IR: + +The corresponding control-flow graph is stored separately in the `cfg` field: + ```jldoctest bar -julia> Base.code_ircode_by_type(Tuple{typeof(bar),Float64})[1][1].cfg +julia> Base.code_ircode_by_type(Tuple{typeof(bar), Float64})[1][1].cfg CFG with 3 blocks: bb 1 (stmts 1:3) → bb 3, 2 bb 2 (stmt 4) bb 3 (stmts 5:6) ``` -For example, the above states that "bb" (basic block) 1 comprises statements 1 to 3, and has successor blocks 2 and 3 (ie. once the statements in basic block 1 have executed, we know for certain that either those in block 2 or block 3 will run next). -Blocks 2 and 3 have no successors, because they both end in a `return` statement. -The predecessors of each basic block (the blocks which could possibly have run immediately prior to a given block) are also stored in the blocks of the `CFG`, even though this is not printed -- you should have a play around with this data structure to see what is in there. -Additionally, note that `Base.lt_float` (used to check if one floating point number is less than another) and `Base.or_int` do not appear as `invoke` statements -- this is because they are not generic Julia functions. -Rather, they are Julia intrinsics: -```jldoctest -julia> Base.lt_float -lt_float (intrinsic function #33) -``` -These intrinsics have special handling in the compiler. -Either way, the overall point is to be aware that these kinds of low-level intrinsics exist, and appear regularly in Julia IR. +Each basic block is a straight-line region that ends either by falling through, branching, +or returning. + +### Simple Loops and Phi Nodes -### Simple Loops and Phi-Nodes +Loops introduce phi nodes: -Finally, we shall consider a simple loop: ```jldoctest my_factorial julia> function my_factorial(x::Int) n = 0 @@ -129,35 +108,25 @@ julia> ir = Base.code_ircode_by_type(Tuple{typeof(my_factorial), Int})[1][1] 7 └── goto #2 8 4 ─ return %2 ``` -There are a few new intrinsics that we have not seen previously (`Base.slt_int` (used to check whether one int is strictly less than another), `Base.add_int`, and `Base.mul_int`). -Additionally, there is the node `goto #2`, which simply states that control flow should jump to basic block 2 whenever it is hit. -The most interesting additional nodes, however, are the two `φ` (phi) nodes. -These are a defining feature of SSA-form IR. Consider the first `φ` node: +For example, + ```julia %2 = φ (#1 => 1, #3 => %7) ``` -means ssa `%2` takes value `1` if the previous basic block was `#1`, and whatever value is currently associated to ssa `%7` if the previous basic block was `#3`. -It is helpful to step through this code in your head: upon calling `my_factorial` we enter basic block `#1`, and proceed directly to basic block `#2`. -Therefore, on the first iteration, `%2` takes value `1`. We never return to basic block `#1`, so all subsequent visits to this `φ` node will result in `%2` taking the value associated to `%7`. -You should convince yourself that `%2` corresponds to the value of `s` at each iteration, and `%3` corresponds to the value of `n` at each iteration. +means `%2` takes value `1` when control arrives from block `#1`, and the value of `%7` +when control arrives from block `#3`. -### Summary +## Julia Compiler's IR Datastructure -Julia's SSA-form IR comprises a sequence of statements, which can be broken down into a collection of basic blocks. -Each basic block begins with a (potentially empty) collection of phi nodes, followed by a sequence of statements, and potentially finished by a _terminator_ (goto, goto-if-not, return). -Control flow is dictated by the terminators at the end of basic blocks -- if there is no terminator then we "fall through" to the next basic block. +The compiler represents inferred IR as `Core.Compiler.IRCode`. The statements live in the +`stmts` field, which is a `Core.Compiler.InstructionStream`. An `InstructionStream` is a +bundle of parallel vectors: the statement itself, its inferred type, call info, line data, +and flags. -## Julia Compiler's IR Datastructure +For example: -The Julia compiler represents the IR associated to a signature via a `struct` called `Core.Compiler.IRCode`. -The statements are given by the `stmts` field, which is a `Core.Compiler.InstructionStream`. -An `InstructionStream` is a collection of 5 `Vector`s, each of which have the same length. -The properties of the `n`th statement in the `IR` are given by the `n`th element of each of these vectors. -For example, the `stmt` field contains the statement itself, the `type` field contains the inferred type associated to the statement. -We'll skip the rest for now. -For example, the statements associated to the `my_factorial` function above can be retrieved as follows: ```jldoctest my_factorial julia> ir.stmts.stmt 9-element Vector{Any}: @@ -170,9 +139,7 @@ julia> ir.stmts.stmt :(Base.mul_int(%2, %6)) :(goto %2) :(return %2) -``` -The types can be accessed in a similar way: -```jldoctest my_factorial + julia> ir.stmts.type 9-element Vector{Any}: Nothing @@ -186,103 +153,21 @@ julia> ir.stmts.type Any ``` -As seen in [Control Flow](@ref), the control flow graph (CFG) is represented as a separate data structure, stored in the `cfg` field of the `IRCode`. -The argument types associated to the signature are stored in the `argtypes` field of the `IRCode`. - -## An Alternative IR Datastructure - -`IRCode` is a perfectly good way to represent Julia's IR the vast majority of the time. -For example, it suffices for the code transformations required for forwards-mode AD. -However, IR transformations involving multiple changes to the control flow structure of a programme are needed in reverse-mode, and are prohibitively awkward to undertake using `IRCode`. -Mooncake's implementation of reverse-mode AD instead makes use of a custom representation of Julia's IR, called `BBCode`. -We emphasise that `BBCode` represents the _same_ thing under the hood, it is just represented in memory in a slightly different way, such that certain kinds of transformations are straightforward to implement. - -You can construct a `BBCode` from an `IRCode`, and vice versa: -```jldoctest my_factorial -julia> using Mooncake: BBCode - -julia> bb_ir = BBCode(ir); - -julia> bb_ir isa BBCode -true - -julia> Core.Compiler.IRCode(bb_ir) - 1 ─ nothing::Nothing -4 2 ┄ %2 = φ (#1 => 1, #3 => %7)::Int64 - │ %3 = φ (#1 => 0, #3 => %6)::Int64 - │ %4 = intrinsic Base.slt_int(%3, _2)::Bool - └── goto #4 if not %4 -5 3 ─ %6 = intrinsic Base.add_int(%3, 1)::Int64 -6 │ %7 = intrinsic Base.mul_int(%2, %6)::Int64 -7 └── goto #2 -8 4 ─ return %2 -``` -At present, `BBCode` does not display itself nicely, so to look at it we must either inspect its fields, or convert it back to an `IRCode` (which _does_ print nicely). - -Instead of storing all of the statements in a single vector (and the types in their own vector, etc), `BBCode` stores all statements associated to a particular basic block in a `Mooncake.BBlock`, and stores these in a `Vector{Mooncake.BBlock}`. -```jldoctest my_factorial -julia> typeof(bb_ir.blocks) -Vector{BBlock} (alias for Array{Mooncake.BasicBlockCode.BBlock, 1}) -``` -Each `BBlock` has a field `insts`, containing the statements associated to that basic block. -This is stored as a `Vector{Core.Compiler.NewInstruction}`, because `Core.Compiler.NewInstruction` contains the 5 fields that define an instruction in `IRCode` (you should compare the fields of a `Core.Compiler.NewInstruction` with those of `Core.Compiler.InstructionStream` to see the correspondence). -For example, consider -```jldoctest my_factorial -julia> using Mooncake.BasicBlockCode: ID # to improve printing - -julia> bb_ir.blocks[3].insts[1] -Compiler.NewInstruction(:(Base.add_int(ID(2), 1)), Int64, Compiler.NoCallInfo(), (0, 0, 0), 0x00002478) -``` -This is the first instruction of the third basic block. -The first field is a call to `Base.add_int`, the second field is `Int64` (we promise that the other fields are just copies of the corresponding data from the `Core.Compiler.InstructionStream` in the original `IRCode` representation of this IR). - -The other structural difference is that `BBCode` has no field containing the control-flow graph. -Instead, the control-flow graph is represented implicitly as part of the `blocks` field. -The upside of this is that any transformations of `blocks` which modify the CFG are automatically reflected in the `blocks` -- there is no need to perform any book-keeping to ensure that the CFG is kept in sync with the instructions. -This saves both time and memory when inserting new basic blocks -- when basic block structure changes, a scan of the entire `IRCode` is required to modify any statements which refer to a given block, and yields code simplifications. -The downside is that the CFG must be computed whenever we need to know about it. -As a resut, neither `IRCode` nor `BBCode`'s representation of the CFG is strictly better than the other. -To extract CFG-related information from a `BBCode`, see [`Mooncake.BasicBlockCode.compute_all_successors`](@ref), [`Mooncake.BasicBlockCode.compute_all_predecessors`](@ref), and [`Mooncake.BasicBlockCode.control_flow_graph`](@ref). +The control-flow graph is stored separately in `ir.cfg`, and the argument types are stored +in `ir.argtypes`. +## Code Transformations -The final major difference between `IRCode` and `BBCode` is that all ssa values in an `IRCode` (`%1`, `%2`, `%n`, etc) are replaced with unique `ID`s. The `ID` associated to a statement is stored separately from the statement in the `inst_ids` field of a `BBlock`: -```jldoctest my_factorial -julia> bb_ir.blocks[3].inst_ids -3-element Vector{ID}: - ID(5) - ID(6) - ID(7) -``` -There is exactly one `ID` per instruction, and it is an error to have the same `ID` associated to multiple instructions. -Similarly, while the number associated to a basic block in `IRCode` is a function of the number of basic blocks which precede it, the `ID` of a basic block in `BBCode` is stored in its `id` field: -```jldoctest my_factorial -julia> bb_ir.blocks[3].id -ID(11) -``` -As a result of this, all references to ssa values and basic block numbers in `IRCode` are replaced with `ID`s in `BBCode`. -The purpose of this is to guarantee that the "name" of a basic block and an instruction does not change when you insert new basic blocks and new instructions. -We shall see how this is useful in the examples below. +Mooncake uses two broad styles of transformation: -## Code Transformations +1. Straight-line edits on `IRCode`, especially in forward mode. +2. Reverse-mode assembly through a builder-local CFG in `reverse_mode.jl`, followed by a + final lowering step back to coherent `IRCode`. -In what follows, we look at a few transformations of Julia's IR, and see how these can be undertaken using both `IRCode` and `BBCode`. -The purpose is two-fold: -1. to enable readers to understand the code used to implement Mooncake, and -2. to highlight the relative merits of `IRCode` vs `BBCode`. +### Replacing Instructions in `IRCode` -### Replacing Instructions +Replacing one statement with another is straightforward: -This is a very simple code transformation. -It is used in both forwards-mode and reverse-mode in Mooncake to replace calls of the form -```julia -f(x, y, z) -``` -with calls of the form -```julia -frule!!(f, x, y, z) -``` -This kind of transformation is performed in basically the same way for both `IRCode` and `BBCode`. -For example, the `mul_int` statement associated to ssa `%7` can be replaced with an `add_int` statement as follows: ```jldoctest my_factorial julia> using Core: SSAValue @@ -296,8 +181,7 @@ julia> old_stmt = new_ir.stmts.stmt[7] julia> new_stmt = Expr(:call, Base.add_int, old_stmt.args[2:end]...) :((Core.Intrinsics.add_int)(%2, %6)) -julia> # new_ir[SSAValue(7)][:stmt] = new_stmt - CC.setindex!(CC.getindex(new_ir, SSAValue(7)), new_stmt, :stmt); +julia> CC.setindex!(CC.getindex(new_ir, SSAValue(7)), new_stmt, :stmt); julia> new_ir 1 ─ nothing::Nothing @@ -310,44 +194,14 @@ julia> new_ir 7 └── goto #2 8 4 ─ return %2 ``` -Observe that ssa `7` has been replaced with the new `:call` to `add_int`. -Unfortunately, in order to avoid committing type-piracy against `Core.Compiler`, we cannot currently write `new_ir[SSAValue(7)][:stmt]`. (`CC.getindex` is a different function from `Base.getindex` -- the same is true for `CC.setindex!` vs `Base.setindex!`). -In general, I would recommend defining helper functions to improve the DRYness of your code. - -The same transformation can be performed on `BBCode`: -```jldoctest my_factorial -julia> bb_ir_copy = copy(bb_ir); - -julia> old_inst = bb_ir_copy.blocks[3].insts[2] -Compiler.NewInstruction(:(Base.mul_int(ID(1), ID(5))), Int64, Compiler.NoCallInfo(), (3, 0, 0), 0x00002478) - -julia> new_stmt = Expr(:call, Base.add_int, old_inst.stmt.args[2:end]...) -:((Core.Intrinsics.add_int)(ID(1), ID(5))) -julia> bb_ir_copy.blocks[3].insts[2] = CC.NewInstruction(old_inst; stmt=new_stmt); - -julia> CC.IRCode(bb_ir_copy) - 1 ─ nothing::Nothing -4 2 ┄ %2 = φ (#1 => 1, #3 => %7)::Int64 - │ %3 = φ (#1 => 0, #3 => %6)::Int64 - │ %4 = intrinsic Base.slt_int(%3, _2)::Bool - └── goto #4 if not %4 -5 3 ─ %6 = intrinsic Base.add_int(%3, 1)::Int64 -6 │ %7 = intrinsic (Core.Intrinsics.add_int)(%2, %6)::Int64 -7 └── goto #2 -8 4 ─ return %2 -``` -As you can see, in both cases we wind up with the same `IRCode` at the end. +This is the kind of local transformation that forward mode relies on heavily. -### Inserting New Instructions +### Inserting New Instructions in `IRCode` -Inserting entirely new instructions into the IR requires a little more thought, but is ultimately very straightforward using either `IRCode` or `BBCode`. +Insertion requires a little more care because later SSA names may need to shift. `IRCode` +handles this through `insert_node!` plus a later `compact!`: -First, `IRCode`. -Suppose that we wish to insert another instruction immediately before the first `add_int` instruction which multiplies `%3` by 2 before adding `1` to it in `#3`. -In `IRCode`, this kind of modification requires some care, because naively inserting an instruction between the 5th and 6th line changes the name of all instructions from the 6th onwards. -Consequently, we need to replace all existing uses of e.g. `%6` with uses of `%7`, etc. -Happily, `IRCode` has a mechanism to achieve just this. ```jldoctest my_factorial julia> ni = CC.NewInstruction(Expr(:call, Base.mul_int, SSAValue(3), 2), Int) Compiler.NewInstruction(:((Core.Intrinsics.mul_int)(%3, 2)), Int64, Compiler.NoCallInfo(), nothing, nothing) @@ -355,46 +209,11 @@ Compiler.NewInstruction(:((Core.Intrinsics.mul_int)(%3, 2)), Int64, Compiler.NoC julia> new_ssa = CC.insert_node!(new_ir, SSAValue(6), ni) :(%10) -julia> new_ir - 1 ─ nothing::Nothing -4 2 ┄ %2 = φ (#1 => 1, #3 => %7)::Int64 - │ %3 = φ (#1 => 0, #3 => %6)::Int64 - │ %4 = intrinsic Base.slt_int(%3, _2)::Bool - └── goto #4 if not %4 -5 3 ─ intrinsic (Core.Intrinsics.mul_int)(%3, 2)::Int64 - │ %6 = intrinsic Base.add_int(%3, 1)::Int64 -6 │ %7 = intrinsic (Core.Intrinsics.add_int)(%2, %6)::Int64 -7 └── goto #2 -8 4 ─ return %2 -``` -`CC.insert_node!(ir, ssa, new_inst)` inserts `new_inst` into `ir` immediately before `ssa`, and attaches it to the same basic block as `ssa` resides. -It returns an `SSAValue`, which is the "name" associated to the inserted instruction in the IR. -Here, we see it has inserted the instruction to multiply `%3` by `2` immediately before `%6`. -However, observe that the `IRCode` has not changed the name associated to the subsequent `add_int` instruction -- it still assigns to `%6`, despite not being the 6th statement in the IR anymore. -This is achieved via `IRCode`'s `new_nodes` field -- upon calling `CC.insert_node!`, rather than inserting the instruction directly into the `InstructionStream`, this list is appended to. -We can do this as many times as we like, and then call `CC.compact!` at the end to handle all of the book-keeping involved in inserting all of the statements, updating all ssa uses where required, and updating the `cfg` field of the IR. - -Also observe that the inserted statement is printed without a `%10 =` at the start of it -- this is because there are not (yet) any uses of `%10`, so `IRCode` does not print it out (presumably in order to reduce visual noise). - -To conclude this transformation, we replace the first argument of the `add_int` instruction with the new ssa returned by `insert_node!`, and then call `CC.compact!` to process all of the nodes currently in the `new_nodes` list, and produce a valid `IRCode`: -```jldoctest my_factorial julia> stmt = CC.getindex(CC.getindex(new_ir, SSAValue(6)), :stmt) :(Base.add_int(%3, 1)) julia> stmt.args[2] = new_ssa; -julia> new_ir - 1 ─ nothing::Nothing -4 2 ┄ %2 = φ (#1 => 1, #3 => %7)::Int64 - │ %3 = φ (#1 => 0, #3 => %6)::Int64 - │ %4 = intrinsic Base.slt_int(%3, _2)::Bool - └── goto #4 if not %4 -5 3 ─ %10 = intrinsic (Core.Intrinsics.mul_int)(%3, 2)::Int64 - │ %6 = intrinsic Base.add_int(%10, 1)::Int64 -6 │ %7 = intrinsic (Core.Intrinsics.add_int)(%2, %6)::Int64 -7 └── goto #2 -8 4 ─ return %2 - julia> new_ir = CC.compact!(new_ir) 1 ─ nothing::Nothing 4 2 ┄ %2 = φ (#1 => 1, #3 => %8)::Int64 @@ -407,164 +226,43 @@ julia> new_ir = CC.compact!(new_ir) 7 └── goto #2 8 4 ─ return %2 ``` -Observe that, before `compact!`-ing, the first instruction in basic block `#3` is still labelled as being `%10`. -After `compact!`-ing, we have standard sequentially-labelled IR again. -Note that the above is exactly the kind of thing that we do in our implementation of forwards-mode AD -- all insertions of nodes are performed in a single pass over the `IRCode`, and `CC.compact!` is called once at the end. - -Performing this transformation using `BBCode` is similarly straightforward. -Since the name associated to instructions does not change when you insert another instruction, you really just need to insert an instruction + its `ID`, update the next instruction (as before), and you're done: -```jldoctest my_factorial -julia> using Mooncake.BasicBlockCode: ID, new_inst - -julia> new_id = ID(); # this produces a new unique `ID`. - -julia> target_id = bb_ir_copy.blocks[3].insts[1].stmt.args[2]; # find `ID` of argument to add_int. - -julia> ni = new_inst(Expr(:call, Base.mul_int, target_id, 2), Int); - -julia> insert!(bb_ir_copy.blocks[3], 1, new_id, ni) - -julia> bb_ir_copy.blocks[3].insts[2].stmt.args[2] = new_id; - -julia> CC.IRCode(bb_ir_copy) - 1 ─ nothing::Nothing -4 2 ┄ %2 = φ (#1 => 1, #3 => %8)::Int64 - │ %3 = φ (#1 => 0, #3 => %7)::Int64 - │ %4 = intrinsic Base.slt_int(%3, _2)::Bool - └── goto #4 if not %4 -5 3 ─ %6 = intrinsic (Core.Intrinsics.mul_int)(%3, 2)::Int64 -6 │ %7 = intrinsic Base.add_int(%6, 1)::Int64 -7 │ %8 = intrinsic (Core.Intrinsics.add_int)(%2, %7)::Int64 -8 └── goto #2 - 4 ─ return %2 -``` -We see here that `IRCode` and `BBCode` involve similar levels of complexity to insert an instruction. - -### Inserting New Basic Blocks - -This is the situation in which the design of `BBCode` shines vs `IRCode`. -`IRCode` does not, at present, really have much to say about transformations which change control flow. -It is, however, straightforward using `BBCode`. -Suppose that we wish to modify the above to display the value of `%2` if it is even on any given iteration. -Since this involves control flow, it necessarily requires at least one additional basic block. - -We do this in two steps. -We first insert an additional basic block between blocks `#3` and `#4` which always prints out the value of `%2`, and then goes to block `#2`: -```jldoctest my_factorial -julia> using Mooncake.BasicBlockCode: BBlock, new_inst, IDGotoNode, IDGotoIfNot - -julia> block_2_id = bb_ir_copy.blocks[2].id; - -julia> new_bb_id = ID(); - -julia> new_bb = BBlock( - new_bb_id, - ID[ID(), ID()], - CC.NewInstruction[ - new_inst(Expr(:call, println, CC.SSAValue(2))), - new_inst(IDGotoNode(block_2_id)), - ], - ); - -julia> insert!(bb_ir_copy.blocks, 4, new_bb); - -julia> CC.IRCode(bb_ir_copy) - 1 ─ nothing::Nothing -4 2 ┄ %2 = φ (#1 => 1, #3 => %8)::Int64 - │ %3 = φ (#1 => 0, #3 => %7)::Int64 - │ %4 = intrinsic Base.slt_int(%3, _2)::Bool - └── goto #5 if not %4 -5 3 ─ %6 = intrinsic (Core.Intrinsics.mul_int)(%3, 2)::Int64 -6 │ %7 = intrinsic Base.add_int(%6, 1)::Int64 -7 │ %8 = intrinsic (Core.Intrinsics.add_int)(%2, %7)::Int64 -8 └── goto #2 - 4 ─ dynamic (println)(%2)::Any - └── goto #2 - 5 ─ return %2 -``` -Observe that, in this case, rather than creating `new_bb` and then inserting instructions into it, we simply create the block _with_ the instructions. -This programming style is often more convenient. -Additionally note that we create an `ID` for each statement in the new basic block. -These `ID`s are never actually used anywhere, but `BBCode` requires that each instruction be associated to an `ID`, so we must create them. -Additionally, note the usage of an [`Mooncake.BasicBlockCode.IDGotoNode`](@ref). -This is exactly the same thing as a `Core.Compiler.GotoNode`, except it contains an `ID` stating which basic block to jump to, rather than an `Int`. -Similarly, the [`Mooncake.BasicBlockCode.IDGotoIfNot`](@ref) is a direct translation of `Core.Compiler.GotoIfNot`, with the `dest` field being an `ID` rather than an `Int`. +This is the right tool when the transformation stays local to existing basic blocks. -Furthermore, note that the `goto if not` instruction at the end of basic block `#2` now (correctly) jumps to basic block `#5`, whereas before it jumped to block `#4`. -That is, by virtue of the fact that the `ID` associated to each basic block remains unchanged in `BBCode`, all pre-existing control flow relationships have remained the same. -Moreover, we did not have to write any book-keeping code to ensure that this update happened correctly. +## Reverse-Mode CFG Assembly -Now that we've created the new basic block, we modify block `#3` to fall-through to the new block if `%2` is even, and to jump straight back to block `#2` if not: -```jldoctest my_factorial -julia> bb = bb_ir_copy.blocks[3]; - -julia> cond_id = ID(); +Reverse mode needs more than local SSA insertion. It frequently has to: -julia> target_id = bb_ir_copy.blocks[2].inst_ids[1]; +1. create fresh blocks, +2. thread predecessor-sensitive phi handling, +3. insert reverse-only control flow, and +4. preserve a coherent CFG while doing so. -julia> insert!(bb, 4, cond_id, new_inst(Expr(:call, iseven, target_id))); +Mooncake now handles that in [`src/interpreter/reverse_mode.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/reverse_mode.jl) +using a builder-local `CFGBlock` representation. The reverse transform first translates the +normalized primal `IRCode` into CFG blocks with stable internal `ID`s, assembles the +forwards and pullback control flow in that representation, and finally lowers the result +back to `IRCode`. -julia> bb.insts[end] = new_inst(IDGotoIfNot(cond_id, block_2_id)); - -julia> new_ir = CC.IRCode(bb_ir_copy) - 1 ─ nothing::Nothing -4 2 ┄ %2 = φ (#1 => 1, #3 => %8)::Int64 - │ %3 = φ (#1 => 0, #3 => %7)::Int64 - │ %4 = intrinsic Base.slt_int(%3, _2)::Bool - └── goto #5 if not %4 -5 3 ─ %6 = intrinsic (Core.Intrinsics.mul_int)(%3, 2)::Int64 -6 │ %7 = intrinsic Base.add_int(%6, 1)::Int64 -7 │ %8 = intrinsic (Core.Intrinsics.add_int)(%2, %7)::Int64 -8 │ %9 = dynamic (iseven)(%2)::Any - └── goto #2 if not %9 - 4 ─ dynamic (println)(%2)::Any - └── goto #2 - 5 ─ return %2 -``` -Observe that in order to tie the conditional to the goto-if-not, we simply ensure that the `ID` associated to the instruction which computes the conditional appears in the `IDGotoIfNot` instruction. +That split is deliberate: -### Run the new code +1. `IRCode` remains the source of truth at the compiler boundary. +2. The builder provides a convenient place to manipulate reverse-mode control flow. +3. The final lowering step re-establishes standard compiler IR with a coherent CFG. -As ever, we can construct a `Core.OpaqueClosure` using `IRCode` in order to produce something runnable. -Since `new_ir` originated as method IR, its first argument type still corresponds to the -function object rather than the opaque-closure environment tuple. For a zero-capture -opaque closure, we therefore first rewrite `argtypes[1]` to `Tuple{}`: -```jldoctest my_factorial -julia> new_ir.argtypes[1] = Tuple{}; - -julia> oc = Core.OpaqueClosure(new_ir; do_compile=true) -(::Int64)->◌::Int64 - -julia> oc(1000) -2 -12 -58 -248 -1014 -2037 -``` -Exactly what `oc` is computing is neither here nor there. -The point is that we've successfully inserted a new basic block into Julia's IR, and produced a callable from it. +This page is mainly about the representations themselves. For the full reverse-mode pipeline, +including statement translation, control-flow replay, and forward-to-reverse communication, see +[`reverse_mode_design.md`](reverse_mode_design.md). ## Summary -We have reviewed the two representations of Julia IR used in Mooncake. -Where possible, we always use `IRCode` -- as discussed, forwards-mode AD exclusively uses `IRCode`. -`BBCode` is basically only needed when undertaking transformations which involve changes to basic block structure -- the insertion of new basic blocks, and the modification of terminators in a way which changes the predecessors / successors of a given block being the primary sources of these kinds of changes. -Reverse-mode AD makes extensive use of such transformations, so `BBCode` is currently important there. - -There are efforts such as [this PR](https://github.com/JuliaLang/julia/pull/45305) to augment `IRCode` with the capability to manipulate the CFG structure in a convenient manner. -Ideally these efforts will succeed, then we can do away with `BBCode`. +`IRCode` is the main compiler-facing representation throughout Mooncake. +Forward mode mostly performs local statement rewrites on that representation. +Reverse mode still starts from normalized `IRCode`, but assembles its extra control flow in +a builder-local CFG before lowering back to `IRCode`. -### Comparison with Alternative Approaches +If you are modifying interpreter internals, the most important invariants to preserve are: -It's worth noting that other automatic differentiation systems have taken different approaches to IR manipulation. For example, [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) uses an "Optics" approach for IR transformations. - -For readers interested in learning more about Julia's IR representation beyond what's covered here, the [Scientific Programming in Julia course materials](https://github.com/JuliaTeachingCTU/Scientific-Programming-in-Julia/blob/2023W/docs/src/lecture_09/ircode.md) provide additional valuable context. - -## Docstrings - -```@autodocs; canonical=true -Modules = [Mooncake.BasicBlockCode] -``` +1. SSA uses must stay coherent after insertion and compaction. +2. `ir.stmts` and `ir.cfg` must agree after lowering. +3. Phi-node edges and predecessor relationships must stay aligned when blocks are removed or reordered. diff --git a/docs/src/developer_documentation/reverse_mode_design.md b/docs/src/developer_documentation/reverse_mode_design.md index 04979822ff..f132c7eb46 100644 --- a/docs/src/developer_documentation/reverse_mode_design.md +++ b/docs/src/developer_documentation/reverse_mode_design.md @@ -1,90 +1,1003 @@ # Reverse-Mode Design -## Compilation process +Last checked: 04/04/2026, Julia v1.10 / v1.11 / v1.12. -Last checked: 09/02/2025, Julia v1.10.8 / v1.11.3, Mooncake 0.4.83. +This page gives a high-level map of how Mooncake's reverse-mode transform is structured. +It is aimed at readers who want to understand the main ideas before reading the implementation. -This brief informal note was largely written by Guillaume Dalle while learning how Mooncake's internals operate for reverse-mode, in order to be able to add forwards-mode AD. -It should help readers orient themselves when first trying to understand Mooncake's internals. +## High-Level Transform: `IRCode` to `IRCode` + +The reverse-mode pipeline is easiest to understand as a four-step transform: + +1. Start from normalized primal `IRCode`. +2. Convert that `IRCode` into Mooncake's builder-local CFG representation. +3. Assemble new forward and reverse CFGs. +4. Lower each assembled CFG back to compiler `IRCode`. + +So reverse mode is no longer "edit compiler IR in place until it works". The compiler `IRCode` +is the input and output format, but most of the reverse-mode assembly happens in the middle on +`CFGBlock`s with Mooncake `ID`s. + +### 1. Normalize and reinterpret the primal `IRCode` + +[`generate_ir`](@ref Mooncake.generate_ir) begins by looking up the inferred primal `IRCode`, +running [`normalise!`](@ref Mooncake.normalise!), and converting the result into CFG blocks. +At that point the source of truth is still the primal `IRCode`, but reverse mode stops working +directly with compiler block numbers and SSA names. + +The local CFG layer exists because reverse mode needs to introduce new control flow, not just +rewrite statements. In particular, the pullback has to reconstruct which predecessor edge was +taken through the primal CFG, including phi-sensitive cases. + +### 2. Translate each primal statement into AD fragments + +Each primal statement is translated by [`make_ad_stmts!`](@ref Mooncake.make_ad_stmts!) into an +`ADStmtInfo`. This is a small per-statement plan containing: + +- forward-pass instructions +- reverse-pass instructions +- an optional communication value that must survive from the forward run to the pullback + +`ADInfo` stores the global transform state used while assembling those +fragments: argument and SSA type tables, reverse-data references, shared-data bookkeeping, +block-stack information for control-flow reconstruction, and debug-mode configuration. + +The important design point is that statement translation does not directly build the final +pullback `IRCode`. It produces forward and reverse fragments first, and whole CFGs are assembled +afterward. + +### 3. Assemble forward and reverse CFGs + +`forwards_pass_ir` builds the forward closure CFG. This CFG: + +- loads shared captured state +- initializes lazy zero-rdata bookkeeping +- emits the translated forward statements for each primal block +- pushes communication values needed later by the pullback +- records block-stack information when the reverse pass must recover dynamic control flow + +`pullback_ir` then builds the pullback CFG separately. This CFG: + +- loads shared data and reverse-data references +- dispatches to the block that actually exited during the forward run +- walks primal statements in reverse order +- reconstructs predecessor-sensitive control flow +- handles phi nodes by routing through edge-specific reverse blocks when needed +- materializes the final cotangent tuple returned by the pullback closure + +This is why the pullback is more than "run the statements backwards". It also has to replay the +control-flow structure needed to send cotangents back along the correct incoming edges. + +### 4. Lower the builder CFG back to compiler `IRCode` + +Once the forward and reverse CFGs are assembled, `lower_cfg_blocks_to_ir` turns them back into +coherent compiler `IRCode`. This step +handles the mechanical reconstruction work in one place: + +- canonicalizing the local CFG +- pruning unreachable blocks +- lowering switch-style control flow into compiler-compatible terminators +- rebuilding SSA numbering and block numbering +- constructing a fresh `Core.Compiler.IRCode` + +After that, [`generate_ir`](@ref Mooncake.generate_ir) can run the usual optimization pass on +both results and wrap them into opaque closures for the final derived reverse rule. + +In short, the reverse-mode transform is: + +```text +primal IRCode + -> normalized primal IRCode + -> builder-local CFG + -> forward CFG + pullback CFG + -> forward IRCode + pullback IRCode +``` + +That middle CFG stage is what keeps the reverse-mode implementation manageable: it isolates the +hard control-flow surgery from the compiler's concrete `IRCode` datastructure, and only lowers +back to compiler IR once the new program structure is complete. + +## A Worked Mini-Example + +Here is the smallest useful mental model for the whole reverse-mode pipeline. Consider a primal +function with one active call followed by a return: + +```julia +function f(x) + y = sin(x) + return y +end +``` + +Very roughly, the normalized primal `IRCode` looks like: + +```text +bb1: + %1 = sin(_2) + return %1 +``` + +### Step 1: statement translation + +`make_ad_stmts!` translates the call and the return into forward and reverse fragments. + +For the call, the important effect is: + +```text +forward: + %rule_result = rule_for_sin(x_arg) + %pb = getfield(%rule_result, 2) + %1 = getfield(%rule_result, 1) + +reverse: + %d1 = rdata_ref_for_%1[] + rdata_ref_for_%1[] = zero(...) + %dx = %pb(%d1) + increment_ref!(rdata_ref_for_x, getfield(%dx, 1)) +``` + +For the return, the important effect is: + +```text +forward: + return %1 + +reverse: + increment_ref!(rdata_ref_for_%1, dy) +``` + +### Step 2: forward CFG assembly + +`forwards_pass_ir` wraps those fragments in a forward CFG with an extra entry block. + +Conceptually: + +```text +fwd_entry: + load shared captures + initialize lazy zero-rdata state + goto fwd_bb1 + +fwd_bb1: + %rule_result = ... + %pb = ... + %1 = ... + push!(comms_stack, tuple(%pb)) + return %1 +``` + +The forward closure therefore does two things at once: + +1. computes the primal/codual result +2. stores the information that the pullback will need later + +### Step 3: pullback CFG assembly + +`pullback_ir` builds a separate pullback CFG: + +```text +rvs_entry: + load shared captures + create reverse-data refs + goto rvs_bb1 + +rvs_bb1: + %pb = getfield(pop!(comms_stack), 1) + increment_ref!(rdata_ref_for_%1, dy) + %d1 = rdata_ref_for_%1[] + rdata_ref_for_%1[] = zero(...) + %dx = %pb(%d1) + increment_ref!(rdata_ref_for_x, getfield(%dx, 1)) + goto rvs_exit + +rvs_exit: + read argument rdata refs + instantiate any lazy zero-rdata placeholders + return argument cotangent tuple +``` + +This is the essential pattern repeated across larger examples: the pullback is a separate CFG +that consumes stored forward-pass data and sends cotangents backward through the primal +dependency structure. + +### Step 4: lower back to compiler IR + +Finally, `lower_cfg_blocks_to_ir` converts both CFGs back to ordinary compiler `IRCode`. + +So even in this tiny example the real flow is: + +```text +primal IRCode + -> translated statement fragments + -> forward CFG / pullback CFG + -> forward IRCode / pullback IRCode +``` + +These examples use schematic names such as `x_arg` rather than exact lowered argument slots. +In the real generated IR, argument indices can shift because the generated closures carry extra +state in addition to the primal arguments. + +## `CFGBlock`: The Reverse-Mode Working IR + +`CFGBlock` is Mooncake's reverse-mode-local basic-block representation. It is the format used +while assembling the forward closure CFG and the pullback CFG inside +[`reverse_mode.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/reverse_mode.jl). + +At a high level, a `CFGBlock` is just: + +- a stable internal block `ID` +- a vector of `(ID, NewInstruction)` pairs for the statements in that block + +Those IDs are Mooncake-local IDs, not compiler SSA numbers or compiler block numbers. That is +deliberate: reverse mode needs a representation that can survive block insertion, block +splitting, and control-flow rewrites without constantly renumbering compiler SSA values. + +### What `CFGBlock` is used for + +The reverse transform uses `CFGBlock`s to do the hard structural work before lowering back to +compiler IR. In particular, the builder-local CFG is where Mooncake: + +- translates primal SSA statements into forward and reverse fragments +- inserts extra entry and exit blocks +- groups and restores communication values for the pullback +- rewrites control flow into predecessor-sensitive reverse dispatch +- creates edge-specific reverse blocks for phi handling +- prunes unreachable blocks and canonicalizes the resulting CFG + +So `CFGBlock` is not an alternative public IR for Mooncake as a whole. It is a local assembly +format used to construct reverse-mode programs safely before converting them back into +`Core.Compiler.IRCode`. + +### Why `IRCode` is not sufficient as the working format + +`IRCode` is sufficient as the source and target format. It is not a good format for the middle +of reverse-mode assembly. + +The core problem is that reverse mode does more than local statement replacement. It often has +to: + +- create fresh blocks +- insert control-flow that did not exist in the primal +- split one logical reverse step across several blocks +- route cotangents along predecessor-specific edges +- keep phi handling consistent with those edges +- rebuild SSA numbering and CFG numbering coherently at the end + +Compiler `IRCode` stores statements and CFG structure in tightly related forms. If you edit it +mid-assembly, you have to keep the instruction stream, block numbering, terminators, phi edges, +and CFG metadata coherent at every intermediate step. That is possible for local edits, but it +becomes brittle once reverse mode starts inserting whole blocks and rethreading control flow. + +`CFGBlock` avoids that problem by giving reverse mode a looser construction format: + +- block identity is stable while the transform is running +- instructions can be inserted without immediate SSA renumbering +- predecessor and successor rewrites can happen before final lowering +- phi-edge handling can be expressed directly in terms of predecessor block IDs + +Only once the new forward and reverse CFGs are structurally complete does Mooncake lower them +back to `IRCode`, rebuild block/SSA numbering, and hand the result back to the compiler. + +### The right way to think about it + +`IRCode` is the compiler-facing representation. +`CFGBlock` is the reverse-mode assembly representation. + +Mooncake still begins with normalized primal `IRCode` and ends with compiler `IRCode`, but it +does the difficult control-flow surgery in `CFGBlock` form because that is the point where the +transform needs flexibility more than compiler-format exactness. + +## Data Structures That Matter Most + +Two data structures carry most of the transform's state: + +- `ADInfo`: global state shared across the whole derivation +- `ADStmtInfo`: the per-statement translation result + +### `ADInfo` + +When reading the implementation, the most important `ADInfo` fields are: + +- `shared_data_pairs`: the values captured by both generated closures +- `block_stack_id` and `block_stack`: the control-flow replay channel +- `arg_rdata_ref_ids` and `ssa_rdata_ref_ids`: where reverse data is accumulated +- `ssa_insts` and `arg_types`: the primal type information used during translation +- `lazy_zero_rdata_ref_id`: the placeholder-zero mechanism used at pullback exit + +The remaining fields mostly support those jobs rather than introducing separate ideas. + +### `ADStmtInfo` + +`ADStmtInfo` is simpler. Its central fields are: + +- `fwds`: forward-pass instructions for the primal statement +- `rvs`: reverse-pass instructions for that statement +- `comms_id`: the optional value that must survive from forward execution to pullback execution + +If you understand those three fields, you understand the role of `ADStmtInfo`. + +## Statement-Level MWEs + +The reverse-mode transform is easiest to follow if you read `make_ad_stmts!` as a translator +from one primal SSA statement into: + +- the forward statements that execute the primal computation and save anything the pullback + will need later +- the reverse statements that consume cotangents and propagate them to arguments and earlier + SSA values + +These are only sketches, but they match the current implementation strategy closely. +They use pseudocode-style helper names such as `rule_for_sin`, `increment_ref!`, and +`switch_to_reverse_phi_edge(...)` to show the dataflow. They are not literal emitted APIs or +exact compiler IR. + +### MWE 1: Constant literal + +Primal statement: -Rule building is done statically, based on types. Some methods accept values, e.g. ```julia -build_rrule(args...; debug_mode=false) +%5 = 3.0 ``` -but these simply extract the types of all the arguments and call the main method (non Helper) for [`build_rrule`](@ref Mooncake.build_rrule). -The action happens in [`s2s_reverse_mode_ad.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/s2s_reverse_mode_ad.jl), in particular the following method: +Forward fragment: + ```julia -build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) +%5 = uninit_fcodual(3.0) +``` + +Reverse fragment: + +```julia +nothing +``` + +Constants produce a `CoDual` on the forward pass, but there is nothing to propagate on the +reverse pass because Mooncake does not track cotangents for literals. + +### MWE 2: Active branch condition + +Primal statement: + +```julia +goto #3 if not %2 +``` + +If `%2` is active, then `%2` is a `CoDual` in the forward IR, so the condition must branch on +its primal: + +Forward fragment: + +```julia +%cond = primal(%2) +goto #3 if not %cond +``` + +Reverse fragment: + +```julia +nothing +``` + +The branch itself carries no cotangent information. Its effect on the reverse pass is indirect: +the forward pass records enough control-flow information for the pullback to recover which edge +was taken. + +### MWE 3: Active call + +Primal statement: + +```julia +%7 = sin(%4) +``` + +Very roughly, the forward fragment becomes: + +```julia +%rule_result = rule_for_sin(%4) +%pb = getfield(%rule_result, 2) +%raw_y = getfield(%rule_result, 1) +%7 = typeassert(%raw_y, CoDual{Float64, ...}) +``` + +and the reverse fragment becomes: + +```julia +%dy = rdata_ref_for_%7[] +rdata_ref_for_%7[] = zero_like_rdata_from_type(typeof(%7)) +%dx_tuple = %pb(%dy) +increment_ref!(rdata_ref_for_%4, getfield(%dx_tuple, 1)) +``` + +The exact generated code is more explicit than this because Mooncake writes out the `getfield`, +`setfield!`, and `increment!!` operations directly. That keeps the reverse-data references +visible to the optimizer. + +The important point is that the forward fragment computes both the output codual and a pullback +object, and the reverse fragment later consumes the stored pullback plus the output cotangent. + +### MWE 4: Return of an active SSA value + +Primal statement: + +```julia +return %9 +``` + +Forward fragment: + +```julia +%checked = typeassert(%9, fwd_ret_type) +return %checked +``` + +Reverse fragment: + +```julia +increment_ref!(rdata_ref_for_%9, dy) +``` + +The forward closure returns the codual value. The pullback closure receives the cotangent of +that return value and accumulates it into the reverse-data reference associated with `%9`. + +### MWE 5: Phi nodes + +Consider a primal block that starts with a phi node: + +```julia +3 ┄ %6 = φ (#1 => arg_x, #2 => %5) +``` + +On the forward pass, this is mostly structural. The phi node is rebuilt so that its incoming +values are coduals rather than raw primal values: + +```julia +3 ┄ %6 = φ (#1 => codual_arg_x, #2 => %5) +``` + +Schematically, the incoming primal argument is replaced by the corresponding codual argument in +the generated forward closure. In the real lowered IR, the exact argument slot can shift because +the generated closure carries extra leading state. + +The important work happens on the reverse side, not at the phi statement itself. Suppose the +cotangent for `%6` is stored in `r%6`. When the pullback reaches the reverse counterpart of +this block, it first extracts and zeros the cotangent for `%6`, then dispatches to an +edge-specific reverse block: + +```julia +%d6 = r%6[] +r%6[] = zero_like_rdata_from_type(typeof(%6)) +switch_to_reverse_phi_edge(...) +``` + +If the primal arrived from block `#1`, the reverse phi-edge block behaves roughly like: + +```julia +increment_ref!(r_arg_x, %d6) +goto reverse_block_for_#1 ``` -`sig_or_mi` is either a signature, such as `Tuple{typeof(foo), Float64}`, or a `Core.MethodInstance`. -Signatures are extracted from `Core.MethodInstance`s as necessary. -If a signature has a custom rule ([`Mooncake.is_primitive`](@ref) returns `true`), we take it, otherwise we generate the IR and differentiate it. +If the primal arrived from block `#2`, it behaves roughly like: -The forward and reverse pass IRs are created by the [`generate_ir`](@ref Mooncake.generate_ir) method. -The `OpaqueClosure` allows going back from the IR to a callable object. More precisely we use `MistyClosure` to store the associated IR. +```julia +increment_ref!(r%5, %d6) +goto reverse_block_for_#2 +``` + +So the key point is that phi handling is split across two places: + +- the forward pass rebuilds the phi with codual inputs +- the reverse pass routes cotangents back through predecessor-specific phi-edge blocks + +This is why the pullback needs predecessor reconstruction, rather than just local statement +reversal. + +## How Control Flow Is Handled + +Reverse mode cannot treat control flow as a local statement rewrite problem. The forward pass +must preserve enough information for the pullback to know: + +- which block actually executed next +- which predecessor edge reached a join block +- which reverse block should run after the current one finishes + +That is why reverse mode assembles whole CFGs and uses the block stack when the predecessor is +not statically determined. + +### MWE 1: Straight-line fallthrough + +Suppose the primal CFG is just: + +```text +bb1 -> bb2 -> return +``` + +and `bb2` has only one predecessor, namely `bb1`. + +In this case there is no ambiguity. The forward pass does not need to log any predecessor ID +for `bb2`, and the reverse pass can jump directly from the reverse counterpart of `bb2` to the +reverse counterpart of `bb1`. + +Very roughly: + +```text +forward: fwd_bb1 ; fwd_bb2 ; return +reverse: rvs_bb2 ; goto rvs_bb1 +``` + +This is the cheap path: no block-stack push, no block-stack pop, no predecessor switch. + +### MWE 2: Simple conditional branch + +Suppose the primal CFG is: + +```text +bb1: + if cond goto bb2 else bb3 + +bb2: + ... + goto bb4 + +bb3: + ... + goto bb4 + +bb4: + return y +``` + +On the forward pass, execution may reach `bb4` from either `bb2` or `bb3`. That predecessor is +not statically known, so the forward pass records enough control-flow information to recover it +later. + +Conceptually the forward pass does something like: + +```text +fwd_bb1: + branch on primal(cond) + +fwd_bb2: + push!(block_stack, bb2) + ... + goto fwd_bb4 + +fwd_bb3: + push!(block_stack, bb3) + ... + goto fwd_bb4 +``` + +Then, once the pullback reaches the reverse counterpart of `bb4`, it uses `make_switch_stmts` +to recover which predecessor actually led there: + +```text +rvs_bb4: + prev = pop!(block_stack) + if prev == bb2 + goto rvs_phi_or_pred_for_bb2 + else + goto rvs_phi_or_pred_for_bb3 +``` + +So the pullback does not "re-run the branch condition". It replays the realized control-flow +path from the forward run. + +### MWE 3: Join block with phi node + +Now combine branching with a phi node: + +```julia +bb2: + %5 = ... + goto bb4 + +bb3: + goto bb4 + +bb4: + %6 = φ (#2 => %5, #3 => arg_x) + return %6 +``` + +The forward pass computes `%6` in the usual SSA sense. The reverse pass has to do two things +at the join: + +1. determine whether control came from `bb2` or `bb3` +2. send the cotangent of `%6` back to `%5` or `arg_x` accordingly + +So the reverse CFG for `bb4` behaves roughly like: + +```text +rvs_bb4: + d6 = r%6[] + r%6[] = zero(...) + prev = pop!(block_stack) + if prev == bb2 + goto rvs_phi_edge_bb2 + else + goto rvs_phi_edge_bb3 + +rvs_phi_edge_bb2: + increment_ref!(r%5, d6) + goto rvs_bb2 + +rvs_phi_edge_bb3: + increment_ref!(r_arg_x, d6) + goto rvs_bb3 +``` + +This is the precise place where control-flow replay and phi handling meet. + +### MWE 4: Unique-predecessor optimization + +Mooncake does not always push block IDs. If a block's predecessor is uniquely determined, the +reverse pass can hard-code that predecessor instead of reading the block stack. + +For example: + +```text +bb1 -> bb2 +bb2 -> bb3 +bb3 -> return +``` + +If `bb3` can only ever be reached from `bb2`, then the reverse pass for `bb3` can go straight +to `bb2`. No stack traffic is needed for that edge. + +This optimization matters because dynamic control-flow logging is only needed where the forward +execution path loses information that the pullback must recover later. + +## Entry and Exit Blocks + +Both generated closures contain extra structural blocks that do not correspond directly to a +single primal block. + +### Forward entry block + +The forward entry block is responsible for: + +- loading shared captured data +- initializing lazy zero-rdata placeholders for arguments +- optionally logging the synthetic entry block for later reverse dispatch + +### Pullback entry block + +The pullback entry block is responsible for: + +- loading the same shared captured data +- creating reverse-data references for arguments and SSA values +- dispatching to the reverse block associated with the primal block that actually returned + +### Pullback exit block + +The pullback exit block is responsible for: + +- reading argument reverse-data references +- materializing true zeros from lazy zero-rdata placeholders where needed +- packaging the final argument cotangent tuple +- returning that tuple with the expected type + +These blocks are worth calling out explicitly because they explain why the generated forward and +reverse CFGs do not look like simple block-for-block reversals of the primal CFG. + +## Forward-to-Reverse Communication + +Forward and reverse code do not communicate through a single channel. The current design uses +three distinct mechanisms, each for a different kind of information. + +### 1. Shared captured data + +Some values are known statically when the derived rule is built and are needed by both the +forward closure and the pullback closure. These go through `SharedDataPairs`. + +Examples include: + +- rule objects that are not singleton values +- constant coduals that are not safe to interpolate directly into IR +- the per-transform block stack object +- the lazy-zero-rdata reference used to materialize correctly typed zeros at pullback exit + +Both generated closures receive the same captures tuple. At the start of each closure, +`shared_data_stmts` extracts those captures back into local IDs. + +This is the "static shared state" channel: build-time data, not per-execution data. + +### 2. Per-block communication stacks + +Some values are only known during the forward execution and must be handed to the pullback for +the matching block. This is what the `comms_id` field of `ADStmtInfo` is for. + +For a statement like an active call, the typical communicated value is the pullback object: + +```julia +%rule_result = rule(...) +%pb = getfield(%rule_result, 2) +``` + +`%pb` is marked as the statement's `comms_id`. Later, `create_comms_insts!` groups the +`comms_id`s for a primal block, builds a tuple of them on the forward pass, and pushes that +tuple onto a block-local stack: -The `Pullback` and `DerivedRule` structs are convenience wrappers for `MistyClosure`s with some bookkeeping. +```julia +%tuple = tuple(%pb1, %pb2, ...) +push!(comms_stack, %tuple) +``` + +At the start of the corresponding reverse block, the pullback pops that tuple and restores the +saved values to the same IDs: + +```julia +%tuple = pop!(comms_stack) +%pb1 = getfield(%tuple, 1) +%pb2 = getfield(%tuple, 2) +``` + +This is the "dynamic value" channel: values produced during the forward execution and consumed +later by the reverse execution. + +### 3. The block stack for dynamic control flow + +The pullback has to know which predecessor edge the primal actually took. When that predecessor +is not statically determined, the forward pass records block IDs on a dedicated `BlockStack`. + +On the forward pass, selected blocks push their ID: + +```julia +push!(block_stack, current_block_id) +``` -Diving one level deeper, in the following method: +On the reverse pass, `make_switch_stmts` pops the predecessor ID and dispatches accordingly: ```julia +%prev = pop!(block_stack) +switch(%prev == pred1 ? ..., %prev == pred2 ? ..., ...) +``` + +This is what allows the pullback to send cotangents along the correct incoming edge and to +handle phi nodes via predecessor-specific reverse blocks. + +In short, the three communication mechanisms are: + +1. captures shared by both closures for static build-time data +2. per-block comms stacks for dynamic forward values such as pullback objects +3. the block stack for replaying dynamic control flow on the reverse pass + +## Concept-to-Helper Map + +If you want to jump from the conceptual description in this page to the implementation, these +are the main landmarks: + +- statement translation: `make_ad_stmts!` +- shared captured data: `SharedDataPairs`, `shared_data_tuple`, `shared_data_stmts` +- per-block forward-to-reverse values: `create_comms_insts!` +- forward CFG assembly: `forwards_pass_ir` +- pullback CFG assembly: `pullback_ir` +- control-flow replay: `__push_blk_stack!`, `__pop_blk_stack!`, `make_switch_stmts` +- phi-edge reverse routing: `conclude_rvs_block`, `rvs_phi_block` +- local CFG representation: `CFGBlock` +- lowering back to compiler IR: `lower_cfg_blocks_to_ir` + +## Where This Lives in Code + +If you want to connect the conceptual story above to the implementation, the main entry points +are: + +```julia +build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) + generate_ir( interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true ) ``` -The function [`lookup_ir`](@ref Mooncake.lookup_ir) calls `Core.Compiler.typeinf_ircode` on a method instance, which is a lower-level version of `Base.code_ircode`. +Here `sig_or_mi` is either a signature such as `Tuple{typeof(foo), Float64}` or a +`Core.MethodInstance`. + +If the signature has a custom rule ([`Mooncake.is_primitive`](@ref) returns `true`), Mooncake +uses that rule. Otherwise it looks up the primal IR and differentiates it. + +[`lookup_ir`](@ref Mooncake.lookup_ir) calls `Core.Compiler.typeinf_ircode` on a method +instance, which is a lower-level version of `Base.code_ircode`. + +The transform works on `Core.Compiler.IRCode`, not the `CodeInfo` shown by `@code_typed`. +[`normalise!`](@ref Mooncake.normalise!) rewrites some `IRCode` expressions into forms that are +easier for the AD transform to handle, after which reverse mode assembles through the local CFG +builder in +[`reverse_mode.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/reverse_mode.jl) +and lowers back to `IRCode`. + +## Captures and Closure Construction + +The generated forward and reverse `IRCode`s do not communicate by calling each other directly. +They communicate partly through ordinary runtime values and partly through a shared captures +tuple that is embedded into the generated closures. + +### What gets captured + +The shared captures tuple contains the values recorded in `SharedDataPairs`, such as: + +- static rule objects that are not safe or convenient to inline directly +- stacks used for per-block communication +- the block stack used for control-flow replay +- the lazy-zero-rdata reference used to finish the pullback result + +`generate_ir` builds this tuple with: + +```julia +shared_data = shared_data_tuple(info.shared_data_pairs) +``` + +and both generated closures are later constructed with that same tuple: + +```julia +fwd_oc = misty_closure(dri.fwd_ret_type, dri.fwd_ir, dri.shared_data...) +rvs_oc = misty_closure(dri.rvs_ret_type, dri.rvs_ir, dri.shared_data...) +``` + +Both generated closures therefore receive access to the same logical captured data. At the +start of each closure, `shared_data_stmts` lowers that tuple back into local IDs. + +So the resulting `IRCode` does not contain the captured values directly as ordinary SSA +definitions. Instead, it contains loads from the closure's captures field near the entry block. + +### MWE: one captured tuple shared by both closures + +Suppose `SharedDataPairs` conceptually contains: + +```text +[(id_rule, some_rule), (id_blk, block_stack), (id_zero, lazy_zero_ref)] +``` + +Then: + +```text +shared_data_tuple(...) == (some_rule, block_stack, lazy_zero_ref) +``` + +and the entry blocks in the lowered `IRCode`s begin by reconstructing those bindings: + +```text +fwd_entry: + id_rule = getfield(_1, 1) + id_blk = getfield(_1, 2) + id_zero = getfield(_1, 3) + ... + +rvs_entry: + id_rule = getfield(_1, 1) + id_blk = getfield(_1, 2) + id_zero = getfield(_1, 3) + ... +``` + +Here `_1` is the generated closure object. The `getfield` operations shown above are the loads +that recover values from that closure's captured state. The important point is that the two +generated closures do not each get separate logical bindings. They are built against the same +captured state assembled during derivation. + +### How forward-to-reverse sharing appears in the final IR + +After lowering, the final forward `IRCode` still contains the machinery that writes dynamic data +needed by the pullback: + +- pushes onto comms stacks for values such as pullback objects +- pushes onto the block stack when control-flow replay is needed +- initialization of lazy zero-rdata placeholders + +The final pullback `IRCode` contains the matching reads: + +- pops from comms stacks +- pops from the block stack +- loads from reverse-data references +- loads from the lazy-zero-rdata capture when finishing the returned cotangent tuple + +So although the builder-local CFG disappears after lowering, the final `IRCode` still encodes +the forward-to-reverse contract explicitly through stack operations, ref operations, and capture +loads. + +### MWE: dynamic value sharing in final `IRCode` + +If an active call produces a pullback object `%pb`, the final forward `IRCode` contains code +equivalent to: + +```text +%tuple = tuple(%pb) +push!(comms_stack, %tuple) +``` + +and the final pullback `IRCode` contains the matching restore: + +```text +%tuple = pop!(comms_stack) +%pb = getfield(%tuple, 1) +``` + +That pair of writes and reads is how a value produced only during the forward run becomes +available later in the pullback. + +### How the closures are built + +`generate_ir` produces two compiler `IRCode`s: + +- one for the forward closure +- one for the pullback closure + +`build_derived_rrule` then packages those into closure objects. At a high level: + +1. `generate_ir` returns `DerivedRuleInfo`, containing `fwd_ir`, `rvs_ir`, and `shared_data` +2. `build_derived_rrule` turns each `IRCode` plus `shared_data` into a `MistyClosure` +3. those closures are placed into a `DerivedRule` +4. when the derived rule is called, it returns the forward result plus a `Pullback` wrapper + around the reverse closure + +The end result is that users do not interact with raw `IRCode` directly. They call a derived +rule, which runs the generated forward closure and receives a callable pullback object whose +captures point at the same shared state prepared during derivation. + +### MWE: final wrapper structure + +Conceptually, the final derived rule looks like: + +```text +DerivedRule( + fwds_oc = MistyClosure(fwd_ir, shared_data), + pb_oc_ref = Ref(MistyClosure(rvs_ir, shared_data)), + ... +) +``` + +Calling that derived rule: + +1. runs `fwds_oc` +2. returns the forward `CoDual` +3. returns a `Pullback` object that knows how to call `pb_oc_ref[]` + +So the forward closure and pullback closure are separate pieces of generated code, but they are +stitched together by the shared captures tuple and the `DerivedRule` / `Pullback` wrappers. + +## SSA Nodes Not Covered -The IR considered is of type [`IRCode`](@ref Mooncake.CC.IRCode), which is different from the `CodeInfo` returned by `@code_typed`. -This format is obtained from `CodeInfo`, used to perform most optimizations in the Julia IR in the [evaluation pipeline](https://docs.julialang.org/en/v1/devdocs/eval/), then converted back to `CodeInfo`. +Most common SSA-form node kinds are handled by `make_ad_stmts!`, but a few are explicitly out +of scope. -The function [`normalise!`](@ref Mooncake.normalise!) is a custom pass to modify `IRCode` and make some expressions nicer to work with. -The possible expressions one can encountered in lowered ASTs are documented [here](https://docs.julialang.org/en/v1/devdocs/ast/#Lowered-form). +### `PhiCNode` -Reverse-mode specific stuff: return type retrieval, `ADInfo`, `bbcode.jl`, `zero_like_rdata.jl`. The `BBCode` structure was a convenience for IR transformation. +`PhiCNode`s are not currently supported. If reverse mode encounters one, it throws an +`unhandled_feature` error immediately. -Beyond the [`interpreter`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/) folder, check out [`tangents.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/tangents.jl) for forward mode. +These nodes are associated with exception-flow joins rather than ordinary CFG joins, so they +need more than the current phi-node machinery. -[`Tangent`](@ref Mooncake.Tangent) is the correct representation required for Forward mode AD. `FData` and `RData` are not representations needed directly. +### `UpsilonNode` -For testing, all the tests got via the `generate_test_functions` method (defined in [`test_resources.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/1894b2f23916091d5022134db0af61a75c1035ee/src/test_resources.jl#L655)) must pass. -Recycle the functionality from reverse mode test utils. +`UpsilonNode`s are also not currently supported. Encountering one raises an `unhandled_feature` +error with guidance to avoid the construct or write a manual rule. -To manipulate `IRCode`, check out the fields: +In practice these nodes arise from some `try` / `catch` / `finally` lowering patterns. The +current reverse-mode transform does not attempt to differentiate through that exception-state +machinery. -- `ir.argtypes` is the signature. Some are annotated with `Core.Const` to facilitate constant propagation for instance. Other annotations are `PartialStruct`, `Conditional`, `PartialTypeVar`. `Core.Compiler.widenconst` is used to extract types from these. -- `ir.stmts` is a `Core.Compiler.InstructionStream`. This represents a sequence of instructions via 5 vectors of the same length: - - `stmts.stmt` is a vector of expressions (or other IR node types), see [AST docs](https://docs.julialang.org/en/v1/devdocs/ast/#Lowered-form) - - `stmts.type` is a vector of types for the left-hand side of the assignment - - three others -- `ir.cfg` is the Control Flow Graph of type `Core.Compiler.CFG` -- `ir.meta` is metadata, not important -- `ir.new_nodes` is an optimization buffer, not important -- `ir.sptypes` is for type parameters of the called function +### Practical boundary -We must maintain coherence between the various components of `IRCode` (especially `ir.stmts` and `ir.cfg`). That is the reason behind `BBCode`, to make coherence easier. -We can deduce the CFG from the statements but not the other way around: it's only composed of blocks of statement indices. -In forward mode we shouldn't have to modify anything but `ir.stmts`. -Do line by line transformation of the statements and then possibly refresh the CFG. +Ordinary `PhiNode`s are supported and lowered through predecessor-sensitive reverse CFG logic. +`PhiCNode`s and `UpsilonNode`s are not. So "control flow with standard SSA joins" is in scope, +while "exception SSA machinery" is still out of scope. -Examples of how line-by-line transformations can be done, are defined in [`Mooncake.make_ad_stmts!`](@ref). -The `IRCode` nodes are not explicitly documented in or . Might need completion of official docs, but Mooncake docs in the meantime. +## Known Boundaries -For additional information about `IRCode` and `BBCode` data structures and transformation examples, see [IR Representations and Code Transformations](@ref). +Beyond unsupported `PhiCNode` and `UpsilonNode` handling, a few broader boundaries are useful to +keep in mind when reading or extending reverse mode: -Inlining pass can prevent us from using high-level rules by inlining the function (e.g. unrolling a loop). -The contexts in [`interpreter/contexts.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/src/interpreter/contexts.jl) are `MinimalCtx` (necessary for AD to work) and `DefaultCtx` (ensure that we hit all of the rules). -Distinction between rules is not well maintained in Mooncake at the moment. -The function `is_primitive` defines whether we should recurse into the function during AD and break it into parts, or look for a rule. -If we define a rule we should set `is_primitive` to `true` for the corresponding function. +- many operations are only as good as the primitive or derived rules available for them +- debug mode wraps rule calls and changes the generated code shape slightly for diagnostics +- the transform depends on compiler IR conventions that can shift across Julia minor versions +- reverse mode assumes ordinary SSA/control-flow lowering patterns, not the full exception-state + machinery generated for every possible language feature -In [`interpreter/abstract_interpretation.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/src/interpreter/abstract_interpretation.jl) we interact with the Julia compiler. -The most important part is preventing the compiler from inlining. +## Further Reading -The `MooncakeInterpreter` subtypes `Core.Compiler.AbstractInterpreter` to interpret Julia code. -There are also Cthulhu, Enzyme, JET interpreters. -Tells you how things get run. +If you want the supporting background after this page: -For second order we will need to adapt IR lookup to misty closures. +- [`ir_representation.md`](ir_representation.md) explains the compiler-facing `IRCode` representation. +- [`forwards_mode_design.md`](forwards_mode_design.md) covers the forward-mode side. +- [`src/interpreter/reverse_mode.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/reverse_mode.jl) is the implementation described here. diff --git a/docs/src/understanding_mooncake/rule_system.md b/docs/src/understanding_mooncake/rule_system.md index 4b43c7e896..a611c3d355 100644 --- a/docs/src/understanding_mooncake/rule_system.md +++ b/docs/src/understanding_mooncake/rule_system.md @@ -368,7 +368,7 @@ end ``` Consider the `function` -```jldoctest foo-doctest +```julia-repl julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2]) foo (generic function with 1 method) ``` @@ -396,7 +396,7 @@ where ``\mathbf{1}`` is the vector of length ``N`` in which each element is equa Now that we know what the adjoint is, we'll write down the `rrule!!`, and then explain what is going on in terms of the adjoint. This hand-written implementation is to aid your understanding -- Mooncake.jl should be relied upon to generate this code automatically in practice. -```jldoctest foo-doctest +```julia-repl julia> function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{Float64}}}) dx_fdata = x.dx function dfoo_adjoint(dy::Float64) @@ -412,7 +412,7 @@ julia> function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{F ``` where `dy` is the rdata for the output to `foo`. The `rrule!!` can be called with the appropriate `CoDual`s: -```jldoctest foo-doctest +```julia-repl julia> codual_foo = CoDual(foo, NoFData()); julia> codual_x = CoDual((5.0, [1.0, 2.0]), (NoFData(), [0.0, 0.0])); @@ -427,14 +427,14 @@ true ``` and the pullback with appropriate rdata: -```jldoctest foo-doctest +```julia-repl julia> pb!!(1.0) (NoRData(), (1.0, NoRData())) ``` which will update the fdata for the `Vector{Float64}` component in-place: -```jldoctest foo-doctest +```julia-repl julia> codual_x CoDual{Tuple{Float64, Vector{Float64}}, Tuple{NoFData, Vector{Float64}}}((5.0, [1.0, 2.0]), (NoFData(), [1.0, 1.0])) ``` @@ -503,7 +503,7 @@ This topic, in particular what goes wrong with permissive tangent type systems l First consider why closures are straightforward to support. Look at the type of the closure produced by `foo`: -```jldoctest +```julia-repl function foo(x) function bar(y) x .+= y @@ -523,7 +523,7 @@ Since the function itself is an argument to its rule, everything enters the rule On the other hand, globals do not appear in the functions that they are a part of. For example, -```jldoctest +```julia-repl const a = randn(10) function g(x) diff --git a/docs/src/understanding_mooncake/what_programme_are_you_differentiating.md b/docs/src/understanding_mooncake/what_programme_are_you_differentiating.md new file mode 100644 index 0000000000..ed09c4835e --- /dev/null +++ b/docs/src/understanding_mooncake/what_programme_are_you_differentiating.md @@ -0,0 +1,515 @@ +# What Programme Are You Differentiating? + +Prerequisites: [Algorithmic Differentiation](@ref). + +In [Mooncake.jl's Rule System](@ref) we discuss a generic mathematical model for a Julia `function`, and state what a rule to differentiate it in reverse-mode must do. +Our goal, however, is to implement an algorithm which produces rules for functions which we do not already have rules for. +This section explains the mathematical model required to know how to do this. + +## A Motivating Example + +By the end of this section we will understand why, for the following function: +```julia +function f(x, y) + a = g(x) + b = h(a, y) + return b +end +``` +the following rule is a correct implementation of reverse-mode AD for it: +```julia +function rr(f, x, y) + a, adj_g = rr(g, x) + b, adj_h = rr(h, a, y) + function adj_f(db) + _, da, dy = adj_h(db) + _, dx = adj_g(da) + return NoRData(), dx, dy + end + return b, adj_f +end +``` +Observe that the above rule essentially does the following: +1. forwards-pass: replace calls to rules. +2. reverse-pass: run adjoints in reverse order, adding together rdata when a variable is used multiple times. + +This opening example uses a Mooncake-like rule shape, including `NoRData()`, because it is meant +to preview the final target. In the next section we temporarily switch to a simpler pedagogical +rule interface that strips away some bookkeeping and focuses only on the adjoint structure. + +This way of writing rules is the essence of the "A" in "AD". +This page is therefore dedicated to building up to this example via a sequence of increasingly general examples. +Once we have this, extending it to a _very_ general class of Julia functions is comparatively straightforward. + +We shall adopt the following approach to each problem: +1. specify class of `function`s, +2. specify class of differentiable functions used to model these `function`s, +3. specify how to find the adjoints of this differentiable model, and +4. describe a rule system which implements these adjoints. + +At a high level, you can think of this approach as first "mathematising" the problem, applying the techniques developed in [Algorithmic Differentiation](@ref) to determine what it is that AD must do, and then providing an outline for implementing this model as a computer programme. + +## Part 1: Simple Compositions of Pure Functions + +For this class of `function`s, the translation between the Julia `function` and the +differentiable function used to model it is almost trivial. + +### `function` Class + +To start with, let us consider only `function`s which are pure (free of externally-visible side effects, such as the modification of their arguments of global variables), unary (single-argument), and don't contain any data themselves (e.g. no closures or callable `struct`s). +For example, consider: +#### `g`: +`g(x::Vector{Float64}) = sin.(x)`. + +#### `h`: +`h(x::Matrix{Float64}) = sum(x)`. + +#### Composition +Let `f` be the composition of `f_1, ..., f_N`, a collection of `N` Julia `function`s which are pure and unary. +This might be implemented as `f(x) := f_N ∘ ... ∘ f_1`, or perhaps +```julia +function f(x) + x_1 = x + x_2 = f_1(x_1) + ... + return f_N(x_N) +end +``` +There are many ways to implement this function. + +### Differentiable Model + +We propose to represent any `function` `f` in this class by a differentiable function ``f : \mathcal{X} \to \mathcal{Y}``. + +#### `g`: +Let ``\mathcal{X} = \mathcal{Y} =: \mathbb{R}^D`` where ``D`` is `length(x)`, and ``f(x) := \sin(x)`` applied elementwise. + +#### `h`: +Let ``\mathcal{X} := \mathbb{R}^{P \times Q}``, and ``\mathcal{Y} := \mathbb{R}``, where ``P`` and ``Q`` are the number of rows and columns in `x`, and ``f(x) := \sum_{p,q} x_{p,q}``. + +#### Composition: +Let ``f_n : \mathcal{X}_n \to \mathcal{X}_{n+1}`` be the differentiable model for `f_n`. +Then the differentiable model ``f : \mathcal{X} \to \mathcal{Y}`` for `f` is ``f := f_N \circ \dots \circ f_1``, with ``\mathcal{X} := \mathcal{X}_1`` and ``\mathcal{Y} := \mathcal{X}_{N+1}``. + +### Adjoints of Model + +You can apply the tools developed in [Algorithmic Differentiation](@ref) to figure out the adjoints of `g` and `h`. +The adjoint of `f` is also given there. +Let ``D f_n [x_n]^\ast`` be the adjoint of the derivative of ``f_n`` at ``x_n``, then the adjoint of ``f`` at ``x`` is just +```math +D f [x]^\ast = D f_1 [x_1]^\ast \circ \dots \circ D f_N [x_N]^\ast. +``` + +### Rules + +For this simple class of functions, a simple rule system will do. +We require that a rule for a `function` with mathematical model ``f`` accepts the same argument as the original `function`, and returns a 2-tuple containing +1. the result of applying the `function` to its input, and +2. another function which implements[^implementing_mathematics_on_a_computer] the adjoint, i.e. ``D f [x]^\ast``. + +Given a rule for a `function` of interest, we simply run the rule, and can then apply the adjoint to any gradient vector of interest. + +#### `g`: +```julia +function rrule(::typeof(g), x::Vector{Float64}) + g_adjoint(ȳ::Vector{Float64}) = cos.(x) .* ȳ + return g(x), g_adjoint +end +``` + +#### `h`: +```julia +function rrule(::typeof(h), x::Matrix{Float64}) + h_adjoint(ȳ::Float64) = fill(ȳ, size(x)) + return h(x), h_adjoint +end +``` + +#### Composition: + +One possible implementation for a rule for the composition of `f_1, ..., f_N` is +```julia +function rrule(::typeof(f), x) + x_1 = x + x_2, f_1_adjoint = rrule(f_1, x_1) + ... + y, f_N_adjoint = rrule(f_N, x_N) + function f_adjoint(ȳ) + x̄_N = f_N_adjoint(ȳ) + ... + x̄_1 = f_1_adjoint(x̄_2) + x̄ = x̄_1 + return x̄ + end + return y, f_adjoint +end +``` +You should convince yourself that this does indeed return a 2-tuple satisfying the specification above. + + + +## Part 2: Functions of Pure Functions + +The previous example demonstrated how we might treat a composition of pure `function`s of a single argument. +Here, we extend this to pure `function`s of multiple arguments. + +### Class of `function`s + +To see an example of this, consider the following computation graph: + +![linear_regression](../assets/computation_graph.png) + +It describes the loss function associated to linear regression, and might be written as Julia code in the following way: +```julia +function linear_regression_loss(W, X, Y) + Y_hat = X * W + eps = Y - Y_hat + l = dot(eps, eps) + return l +end +``` +As before, in order to produce a precise mathematical model for this Julia `function`, we reduce it to the composition of elementary functions. +However, in order to do so, we will have to be a little more creative in how we choose these functions. + +### Differentiable Mathematical Model + +Before writing the equations, it helps to keep one simple picture in mind: we model a Julia +program as a sequence of calls that gradually extends the list of values currently in scope. +Each call reads some of the existing values, computes a new one, and appends that new value to +the running state. The tuple-based model below is just a precise way of writing down that idea. + +We model this `function` as a function ``f`` defined as follows: + +For example, ``\varphi_1`` takes the current tuple ``(W, X, Y)``, computes ``XW``, and appends +that new value, producing ``(W, X, Y, XW)``. + +```math +\begin{align} + f :=&\, \operatorname{ret} \circ \varphi_3 \circ \varphi_2 \circ \varphi_1 \textrm{ where } \nonumber \\ + \varphi_1(W, X, Y) :=&\, (W, X, Y, XW) \nonumber \\ + \varphi_2(W, X, Y, \hat{Y}) :=&\, (W, X, Y, \hat{Y}, Y - \hat{Y}) \nonumber \\ + \varphi_3(W, X, Y, \hat{Y}, \varepsilon) :=&\, (W, X, Y, \hat{Y}, \varepsilon, \|\varepsilon\|_2^2) \nonumber \\ + \operatorname{ret}(W, X, Y, \hat{Y}, \varepsilon, l) :=&\, l \nonumber +\end{align} +``` +In words, our mathematical model for `linear_regression_loss` is the composition of four differentiable functions. The first three map from a tuple containing all variables seen so far, to a tuple containing the same variables _and_ the value returned by the operation being modeled, and the fourth simply reads off the elements of the final tuple which were passed in as arguments, and the return value. + +In general, we model the ``n``th Julia `function` _call_ with a function ``\varphi_n`` mapping from a tuple of ``D`` elements to a tuple of ``D + 1`` elements, of the form +```math +\varphi_n(x) := (x_1, \dots, x_D, g_n (a_n(x))) +``` +for some differentiable function ``g_n``, and "argument selector" function ``a_n``. +In words: each function call involves +1. preparing the arguments to be passed to the function call, (``a_n``) +2. calling the function (``g_n``), and +3. adding a new variable to the list of in-scope variables (new tuple is of length ``D + 1``). + +For example, in the case of our example above, +```math +\begin{align} + &\varphi_1:\quad a_1(x) := (x_2, x_1) &\text{ and } \quad &g_1(A,B) := AB \nonumber \\ + &\varphi_2:\quad a_2(x) := (x_3, x_4) &\text{ and } \quad &g_2(A,B) := A - B \nonumber \\ + &\varphi_3:\quad a_3(x) := x_5 &\text{ and } \quad &g_3(A) := \|A\|_2^2 \nonumber +\end{align} +``` +Note that the argument to ``a_1`` is a 3-tuple, to ``a_2`` a 4-tuple, and to ``a_3`` a 5-tuple. +Crucially, observe that ``f`` has exactly the same structure as ``g_1``, ``g_2``, and ``g_3`` -- it maps from the tuple containing its arguments to its `return` value. +This gives us a recursive structure which is essential for making AD work. + +``\operatorname{ret}`` always just maps from a tuple to the last element of that tuple. + +### Differentiating the Mathematical Model + +Dropping the subscript ``n``, functions such as ``\varphi`` have derivative +```math +D [\varphi, x] (\dot{x}) = (\dot{x}_1, \dots, \dot{x}_D, D [g \circ a, x] (\dot{x})). +``` +Letting ``\bar{y} := (\bar{y}_1, \dots \bar{y}_{D+1})``, we can perform the usual manipulations to find the adjoint of ``D[\varphi, x]``: +```math +\begin{align} + \langle \bar{y}, D[\varphi, x](\dot{x}) \rangle &= \langle (\bar{y}_1, \dots, \bar{y}_{D+1}), (\dot{x}_1, \dots, \dot{x}_D, D[g \circ a, x](\dot{x})) \rangle \nonumber \\ + &= \sum_{d=1}^D \langle \bar{y}_d, \dot{x}_d \rangle + \langle D[g \circ a, x]^\ast (\bar{y}_{D+1}), \dot{x} \rangle \nonumber \\ + &= \langle (\bar{y}_1, \dots, \bar{y}_D), \dot{x} \rangle + \langle D[g \circ a, x]^\ast (\bar{y}_{D+1}), \dot{x} \rangle \nonumber \\ + &= \langle (\bar{y}_1, \dots, \bar{y}_D) + D[g \circ a, x]^\ast (\bar{y}_{D+1}), \dot{x} \rangle. \nonumber +\end{align} +``` +So, what is ``D[g \circ a, x]^\ast (\bar{y}_{D+1})``? +First let ``z := a(x)`` and observe that ``D[a, x] = a`` since ``a`` is linear. +It follows that +```math +D[g \circ a, x]^\ast = (D[g, z] \circ D[a, x])^\ast = (D[g, z] \circ a)^\ast = a^\ast \circ D[g, z]^\ast . +``` + +Since ``g`` has the same form as ``f``, assume that we know ``D[g, z]^\ast`` by induction. +An expression for ``a^\ast``, on the other hand, we obtain from its definition directly. + +As discussed ``a`` maps from a tuple containing all variables passed in to ``f``, to the tuple which is to be passed to ``g``. +For example, suppose variables `x_1, x_2, x_3` are in-scope, then a call of the form +1. `g(x_2, x_1)` has argument selector ``a(x) := (x_2, x_1)``, and +2. `g(x_3, x_3)` has argument selector ``a(x) := (x_3, x_3)``, +where ``x`` is a 3-tuple in both examples. +The general form of an argument selector ``a`` mapping from a ``D``-tuple to an ``N``-tuple is +```math +a(x) = (x_{i_1}, \dots, x_{i_N}) +``` +for some set of integers ``i_1, \dots, i_N \in \{1, \dots, D\}``. +Let ``z = (z_1, \dots, z_N)``, and ``c_n(z)`` the ``D``-tuple which is equal to ``z_n`` at element ``i_n``, and zero everywhere else. Then adjoint of ``a`` is obtained in the usual manner: +```math +\begin{align} +\langle z, a^\ast (x) \rangle &= \langle (z_1, \dots, z_N), (x_{i_1}, \dots, x_{i_N}) \rangle = \sum_{n=1}^N \langle z_n, x_{i_n} \rangle = \sum_{n=1}^N \langle c_n(z), x \rangle = \langle \sum_{n=1}^N c_n(z), x \rangle, \nonumber +\end{align} +``` +from which we conclude +```math +a^\ast(z) = \sum_{n=1}^N c_n(z). \nonumber +``` +Applying this result to our previous examples, we see that +1. when ``a(x) := (x_2, x_1)``, ``a^\ast (z) := (z_2, z_1, 0)``, and that +2. when ``a(x) := (x_3, x_3)``, ``a^\ast (z) := (0, 0, z_1 + z_2)``. + +Combining this result with the adjoint of the derivative of ``\varphi`` yields +```math +D[\varphi, x]^\ast (\bar{y}) = (\bar{y}_1, \dots, \bar{y}_D) + \sum_{n=1}^N c_n(D [g, z]^\ast (\bar{y}_{D+1})). +``` + +We must also find the derivative of the adjoint of ``\operatorname{ret}``. +It is linear, so it is its own derivative. +Assume ``f`` is of the form +```math +f = \operatorname{ret} \circ \varphi_P \circ \dots \circ \varphi_1 +``` +and that its arguments are ``N``-tuples. +Then ``\operatorname{ret}`` maps from ``(N + P)``-tuples to a single value, and its adjoint is a +tuple of length ``N + P`` given by +```math +\operatorname{ret}^\ast(\bar{y}) = (0, \dots, 0, \bar{y}). +``` +Finally, the adjoint of the derivative of ``f`` is +```math +D [f,x]^\ast = D[\varphi_1, x_1]^\ast \circ \dots \circ D [\varphi_P, x_P]^\ast \circ \operatorname{ret}^\ast +``` +where ``x_p := \varphi_{p-1}(x_{p-1})`` and ``x_1 := x``. +Since we now have expressions for all of the terms in this, we consider how to produce a programme which implements this adjoint. + +### Implementation + +We start by revisiting our Julia `function` which computes the loss associated to linear regression, as it is easiest to start with a concrete example: +```julia +function f(W, X, Y) + Y_hat = X * W + eps = Y - Y_hat + l = dot(eps, eps) + return l +end +``` + +Let the function `rule` map from a `function` and its arguments to a 2-tuple comprising the return value of that function, and a closure which computes the adjoint. +The closure maps from a gradient vector, associated to the return value of the function, to a tuple containing the gradient vectors associated to each of the arguments to the function. + +Assume that we have methods of `rule` for `*`, `-`, and `dot`, then a possible implementation of `rule` for `f` is +```julia +function rule(f, W, X, Y) + Y_hat, adjoint_mul = rule(*, X, W) + eps, adjoint_minus = rule(-, Y, Y_hat) + l, adjoint_dot = rule(dot, eps, eps) + function adjoint_f(dout) + + # Seed the reverse pass with the output cotangent, and initialize zero gradients. + dl = dout + deps = zero_gradient(eps) + dY_hat = zero_gradient(Y_hat) + dY = zero_gradient(Y) + dX = zero_gradient(X) + dW = zero_gradient(W) + + # Run adjoint of `dot`. + (deps_inc_1, deps_inc_2) = adjoint_dot(dl) + + # Run adjoint of argument selector for the call to `dot`. + # Observe that the gradient w.r.t. `eps` gets incremented twice, because `eps` + # appears twice in the argument list to `dot`. This is consistent with the + # argument selector adjoint. + deps = deps + deps_inc_1 + deps = deps + deps_inc_2 + + # Run adjoint of `-`. + (dY_inc, dY_hat_inc) = adjoint_minus(deps) + + # Run adjoint of argument selector for the call to `-`. + dY = dY + dY_inc + dY_hat = dY_hat + dY_hat_inc + + # Run adjoint of `*`. + (dX_inc, dW_inc) = adjoint_mul(dY_hat) + + # Run adjoint of argument selector for the call to `*`. + dX = dX + dX_inc + dW = dW + dW_inc + + # Return the gradients w.r.t. the arguments. + return dW, dX, dY + end + return l, adjoint_f +end +``` + +At a high-level, rule derivation in the general case proceeds as follows. +First, replace all calls in the original function with calls to `rule`s. +Then define a closure which accepts a single argument. +This closure first implements the adjoint of ``r``, by assigning the argument to the closure to be the gradient of the primal `return` value, and setting the gradient of all other variables to zero. +Subsequently, for each call in the primal function, in reverse order, apply the adjoint returned by the associated rule, and apply the adjoint of the argument selector associated to the call by incrementing the value of the gradients of its arguments. +Finally, return a tuple containing the gradients w.r.t. each of the arguments to the primal function. + +There are other equivalent ways to implement the above -- see e.g. [Zygote Implementation](@ref) below. +If nothing else, the derivations in this section provide a clear route to explain what goes on there. + +At this point, we have enough to understand the bulk of many AD systems. +If you squint, the above provides a good starting point from which to understand what [Zygote.jl](https://github.com/FluxML/Zygote.jl), [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl), [PyTorch](https://pytorch.org), and [JAX](https://github.com/jax-ml/jax) do, albeit each of these systems has their own particular way of achieving the above (JAX in particular). +The defining property of each of these systems is that they (for the most part) involve pure-functions. + +## Part 3: Applying Sequences of Mutating Functions + +Julia `function` can modify their inputs. +For example, consider +```julia +function square!(x::Vector{Float64}) + x .= x .^ 2 + return nothing +end +``` +This `function` mutates (modifies / changes) the values stored in each element of `x`. +In order to model this kind of behaviour, we introduce the notion of state, as discussed in [Mooncake.jl's Rule System](@ref). +In general, we associate to a Julia `function` `f!` a differentiable function ``f : \mathcal{X} \to \mathcal{X}``, defined such that if `x` is associated to value ``x`` prior to running `f!`, it has value ``f(x)`` after running `f!`. +We call ``f`` the _transition_ _function_ associated to `f!`. +In the case of `square!`, the transition function is ``f(x) = x \odot x``. + +We first study `function`s of the following form: +```julia +function f!(x::Vector{Float64}) + f_1!(x) + f_2!(x) + ... + f_N!(x) + return nothing +end +``` +We associate to each `f_n!` its transition function ``f_n : \mathcal{X} \to \mathcal{X}``, where ``\mathcal{X} := \mathbb{R}^D`` (assume for now that the `length(x)` is not modified by any of the operations). +The transition function ``f : \mathcal{X} \to \mathcal{X}`` associated to `f!` is simply +```math +f := f_N \circ \dots \circ f_1 +``` +Therefore the adjoint of the derivative is +```math +D[f, x]^\ast = D[f_1, x_1]^\ast \circ \dots \circ D[f_N, x_N]^\ast, +``` +where ``x_{n+1} := f_n(x_n)`` and ``x_1 := x`` as usual. +How might we implement this? +The crucial difference from before is that we lose access to some (or perhaps all) of e.g. ``x_n`` when we apply `f_n!` to put `x` into the state with value ``x_{n+1}``. +This is important in the case of `square!`. +To see this, first note that the derivative of its transition function is ``D[f, x](\dot{x}) = 2 x \odot \dot{x}``, with corresponding adjoint ``D[f, x]^\ast(\bar{y}) = 2 \bar{y} \odot x``. +The fact that ``x`` appears in the adjoint expression means that we need access to the value that the programme variable `x` takes _before_ running `square!`, but this is overwritten when we run `square!`. +The same consideration applies to each `f_n` in `f!` -- if the adjoints of their transition functions involve ``x``, we need access to the value that `x` took at the point just before running `f_n!`. + +We can solve this problem by insisting that the reverse-pass of a rule for a `function` return its arguments to the state that they were in prior to running the `function`, thus granting rules license to safely assume that all state is "as they left it" from the forwards-pass. + +Under this framework, a rule for `square!` might be something like +```julia +function rule(::typeof(square!), x::CoDual{Vector{Float64}}) + # Save current value of `x` for use in the reverse-pass. + x_copy = copy(primal(x)) + + # Run primal operation. + square!(x) + + # On entry, x̄ is the gradient associated to the value of `x` after running `square!`. + function square!_reverse() + + # Reset `x` to have the value it took before running the forwards pass. + primal(x) .= x_copy + + # Modify gradient to correspond to result of adjoint of transition function. + tangent(x) .= 2 .* primal(x) + + return nothing + end + + return nothing, square!_reverse +end +``` +where `CoDual{Vector{Float64}}` is a `struct` containing the value associated to `x`, retrieved using `primal(x)`, and memory into which its gradients can be written, retrieved using `tangent(x)`. +It has value equal to the input to the transition function before running the reverse-pass, and value equal to the application of the adjoint to this input after running the reverse-pass. + +So to use the above `rule` to compute the adjoint of the transition function associated to `square!`, you might do the following: +1. initialise `x = CoDual(x, zeros(x))`, +2. compute `_, square!_reverse = rule(square!, x)`, +3. set `tangent(x)` equal to the gradient you wish to use as the input to the adjoint of the transition function associated to `square!`, +4. run `square!_reverse`, and +5. retrieve the gradient `tangent(x)`. + +Notice that this has a distinctly different style to before. +The forwards-pass happens through mutation of the argument `x`, and the reverse-pass happens entirely as a side-effect of running the closure `square!_reverse`. + +Similarly, we can implement a rule for `f!` as follows: +```julia +function rule(::typeof(f!), x::CoDual{Vector{Float64}}) + + # Run forwards-pass. + _, f_1!_reverse = rule(f_1!, x) + _, f_2!_reverse = rule(f_2!, x) + ... + _, f_N!_reverse = rule(f_N!, x) + + # Define reverse-pass -- just run all rules in reverse. + function f!_reverse() + f_N!_reverse() + ... + f_2!_reverse() + f_1!_reverse() + return nothing + end + + # Return nothing, and the new rule. + return nothing, f!_reverse +end +``` +This rule satisfies our requirement that all modifications to `x` are un-done by `f!_reverse` inductively -- we assume that each `f_n!_reverse` satisfies this require. + + +## Part 4: Computational Graphs of Mutating Functions + +## Part 5: Computational Graphs of Mutating Functions with Aliasing + + +[^implementing_mathematics_on_a_computer]: put differently, suppose that someone wrote down some equations in a paper or textbook, and gave you a piece of code which they claim is an implementation of these equations (e.g. a neural network, a probabilistic model, an ODE, etc). Under what conditions would you be satisfied that the implementation was correct? We all do this in informal ways all of the time. I propose that you apply the same set of standards here: we have written down some equations for the adjoints, and are claiming that our rule system is an implementation of these. The fact that we arrived at this set of equations by modelling a computer programme is neither here nor there for this step of the process. + +### Zygote Implementation + +For example, rather than incrementing gradients immediately after calls to adjoints, [Zygote.jl](https://github.com/FluxML/Zygote.jl) adds together the gradients of a variable immediately before the gradient is used. +The implementation of adjoint produced by Zygote is, roughly speaking, something like +```julia +function rule(f, W, X, Y) + Y_hat, adjoint_mul = rule(*, X, W) + eps, adjoint_minus = rule(-, Y, Y_hat) + l, adjoint_dot = rule(dot, eps, eps) + function adjoint_f(dout) + + # Implement adjoint of r. Assume that we have a way to produce zero gradients. + dl = dout + + # Run adjoint of `dot`. + (deps_1, deps_2) = adjoint_dot(dl) + + # Run adjoint of `-`. + deps = deps_1 + deps_2 + (dY, dY_hat) = adjoint_minus(deps) + + # Run adjoint of `*`. + (dX, dW) = adjoint_mul(dY_hat) + + # Return the gradients w.r.t. the arguments. + return dW, dX, dY + end + return l, adjoint_f +end +``` +It computes the same thing as discussed previously, but avoids redundant calls to `zero_gradient` and `+`. diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 910e5ea444..2a6ee8607a 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -179,9 +179,6 @@ include("debug_mode.jl") include("stack.jl") @unstable begin -include(joinpath("interpreter", "bbcode.jl")) -using .BasicBlockCode - include(joinpath("interpreter", "contexts.jl")) include(joinpath("interpreter", "abstract_interpretation.jl")) include(joinpath("interpreter", "patch_for_319.jl")) diff --git a/src/interpreter/bbcode.jl b/src/interpreter/bbcode.jl deleted file mode 100644 index 374cdadee9..0000000000 --- a/src/interpreter/bbcode.jl +++ /dev/null @@ -1,1061 +0,0 @@ -""" - module BasicBlockCode - -See the docstring for the `BBCode` `struct` for info on this file. -""" -module BasicBlockCode - -using Graphs - -using Core.Compiler: - ReturnNode, - PhiNode, - GotoIfNot, - GotoNode, - NewInstruction, - IRCode, - SSAValue, - PiNode, - Argument -const CC = Core.Compiler - -export ID, - seed_id!, - IDPhiNode, - IDGotoNode, - IDGotoIfNot, - Switch, - BBlock, - phi_nodes, - terminator, - insert_before_terminator!, - collect_stmts, - compute_all_predecessors, - BBCode, - remove_unreachable_blocks!, - characterise_used_ids, - characterise_unique_predecessor_blocks, - sort_blocks!, - InstVector, - IDInstPair, - __line_numbers_to_block_numbers!, - is_reachable_return_node, - new_inst - -const _id_count::Dict{Int,Int32} = Dict{Int,Int32}() - -""" - new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction - -Create a `NewInstruction` with fields: -- `stmt` = `stmt` -- `type` = `type` -- `info` = `CC.NoCallInfo()` -- `line` = `Int32(1)` -- `flag` = `flag` -""" -function new_inst(@nospecialize(stmt), @nospecialize(type)=Any, flag=CC.IR_FLAG_REFINED) - return NewInstruction(stmt, type, CC.NoCallInfo(), Int32(1), flag) -end - -""" - const InstVector = Vector{NewInstruction} - -Note: the `CC.NewInstruction` type is used to represent instructions because it has the -correct fields. While it is only used to represent new instrucdtions in `Core.Compiler`, it -is used to represent all instructions in `BBCode`. -""" -const InstVector = Vector{NewInstruction} - -""" - ID() - -An `ID` (read: unique name) is just a wrapper around an `Int32`. Uniqueness is ensured via a -global counter, which is incremented each time that an `ID` is created. - -This counter can be reset using `seed_id!` if you need to ensure deterministic `ID`s are -produced, in the same way that seed for random number generators can be set. -""" -struct ID - id::Int32 - function ID() - current_thread_id = Threads.threadid() - id_count = get(_id_count, current_thread_id, Int32(0)) - _id_count[current_thread_id] = id_count + Int32(1) - return new(id_count) - end -end - -Base.copy(id::ID) = id - -""" - seed_id!() - -Set the global counter used to ensure ID uniqueness to 0. This is useful when you want to -ensure determinism between two runs of the same function which makes use of `ID`s. - -This is akin to setting the random seed associated to a random number generator globally. -""" -function seed_id!() - return global _id_count[Threads.threadid()] = 0 -end - -""" - IDPhiNode(edges::Vector{ID}, values::Vector{Any}) - -Like a `PhiNode`, but `edges` are `ID`s rather than `Int32`s. -""" -struct IDPhiNode - edges::Vector{ID} - values::Vector{Any} -end - -Base.:(==)(x::IDPhiNode, y::IDPhiNode) = x.edges == y.edges && x.values == y.values - -Base.copy(node::IDPhiNode) = IDPhiNode(copy(node.edges), copy(node.values)) - -""" - IDGotoNode(label::ID) - -Like a `GotoNode`, but `label` is an `ID` rather than an `Int64`. -""" -struct IDGotoNode - label::ID -end - -Base.copy(node::IDGotoNode) = IDGotoNode(copy(node.label)) - -""" - IDGotoIfNot(cond::Any, dest::ID) - -Like a `GotoIfNot`, but `dest` is an `ID` rather than an `Int64`. -""" -struct IDGotoIfNot - cond::Any - dest::ID -end - -Base.copy(node::IDGotoIfNot) = IDGotoIfNot(copy(node.cond), copy(node.dest)) - -""" - Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID) - -A switch-statement node. These can be inserted in the `BBCode` representation of Julia IR. -`Switch` has the following semantics: -```julia -goto dests[1] if not conds[1] -goto dests[2] if not conds[2] -... -goto dests[N] if not conds[N] -goto fallthrough_dest -``` -where the value associated to each element of `conds` is a `Bool`, and `dests` indicate -which block to jump to. If none of the conditions are met, then we go to whichever block is -specified by `fallthrough_dest`. - -`Switch` statements are lowered into the above sequence of `GotoIfNot`s and `GotoNode`s -when converting `BBCode` back into `IRCode`, because `Switch` statements are not valid -nodes in regular Julia IR. -""" -struct Switch - conds::Vector{Any} - dests::Vector{ID} - fallthrough_dest::ID - function Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID) - @assert length(conds) == length(dests) - return new(conds, dests, fallthrough_dest) - end -end - -""" - Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode} - -A Union of the possible types of a terminator node. -""" -const Terminator = Union{Switch,IDGotoIfNot,IDGotoNode,ReturnNode} - -""" - BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector) - -A basic block data structure (not called `BasicBlock` to avoid accidental confusion with -`CC.BasicBlock`). Forms a single basic block. - -Each `BBlock` has an `ID` (a unique name). This makes it possible to refer to blocks in a -way that does not change when additional `BBlocks` are inserted into a `BBCode`. -This differs from the positional block numbering found in `IRCode`, in which the number -associated to a basic block changes when new blocks are inserted. - -The `n`th line of code in a `BBlock` is associated to `ID` `stmt_ids[n]`, and the `n`th -instruction from `stmts`. - -Note that `PhiNode`s, `GotoIfNot`s, and `GotoNode`s should not appear in a `BBlock` -- -instead an `IDPhiNode`, `IDGotoIfNot`, or `IDGotoNode` should be used. -""" -mutable struct BBlock - id::ID - inst_ids::Vector{ID} - insts::InstVector - function BBlock(id::ID, inst_ids::Vector{ID}, insts::InstVector) - @assert length(inst_ids) == length(insts) - return new(id, inst_ids, insts) - end -end - -""" - const IDInstPair = Tuple{ID, NewInstruction} -""" -const IDInstPair = Tuple{ID,NewInstruction} - -""" - BBlock(id::ID, inst_pairs::Vector{IDInstPair}) - -Convenience constructor -- splits `inst_pairs` into a `Vector{ID}` and `InstVector` in order -to build a `BBlock`. -""" -function BBlock(id::ID, inst_pairs::Vector{IDInstPair}) - return BBlock(id, first.(inst_pairs), last.(inst_pairs)) -end - -Base.length(bb::BBlock) = length(bb.inst_ids) - -Base.copy(bb::BBlock) = BBlock(bb.id, copy(bb.inst_ids), copy(bb.insts)) - -""" - phi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}} - -Returns all of the `IDPhiNode`s at the start of `bb`, along with their `ID`s. If there are -no `IDPhiNode`s at the start of `bb`, then both vectors will be empty. -""" -function phi_nodes(bb::BBlock) - n_phi_nodes = findlast(x -> x.stmt isa IDPhiNode, bb.insts) - if n_phi_nodes === nothing - n_phi_nodes = 0 - end - return bb.inst_ids[1:n_phi_nodes], bb.insts[1:n_phi_nodes] -end - -""" - Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing - -Inserts `stmt` and `id` into `bb` immediately before the `n`th instruction. -""" -function Base.insert!(bb::BBlock, n::Int, id::ID, inst::NewInstruction)::Nothing - insert!(bb.inst_ids, n, id) - insert!(bb.insts, n, inst) - return nothing -end - -""" - terminator(bb::BBlock) - -Returns the terminator associated to `bb`. If the last instruction in `bb` isa -`Terminator` then that is returned, otherwise `nothing` is returned. -""" -terminator(bb::BBlock) = isa(bb.insts[end].stmt, Terminator) ? bb.insts[end].stmt : nothing - -""" - insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing - -If the final instruction in `bb` is a `Terminator`, insert `inst` immediately before it. -Otherwise, insert `inst` at the end of the block. -""" -function insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing - insert!(bb, length(bb.insts) + (terminator(bb) === nothing ? 1 : 0), id, inst) - return nothing -end - -""" - collect_stmts(bb::BBlock)::Vector{IDInstPair} - -Returns a `Vector` containing the `ID`s and instructions associated to each line in `bb`. -These should be assumed to be ordered. -""" -collect_stmts(bb::BBlock)::Vector{IDInstPair} = collect(zip(bb.inst_ids, bb.insts)) - -@eval begin - """ - BBCode( - blocks::Vector{BBlock} - argtypes::Vector{Any} - sptypes::Vector{CC.VarState} - linetable::Vector{Core.LineInfoNode} (v1.11 and lower) - debuginfo::CC.DebugInfoStream (v1.12+) - meta::Vector{Expr} - valid_worlds::CC.WorldRange (v1.12+) - ) - - A `BBCode` is a data structure which is similar to `IRCode`, but adds additional structure. - - In particular, a `BBCode` comprises a sequence of basic blocks (`BBlock`s), each of which - comprises a sequence of statements. Moreover, each `BBlock` has its own unique `ID`, as does - each statement. - - The consequence of this is that new basic blocks can be inserted into a `BBCode`. This is - distinct from `IRCode`, in which to create a new basic block, one must insert additional - statments which you know will create a new basic block -- this is generally quite an - unreliable process, while inserting a new `BBlock` into `BBCode` is entirely predictable. - Furthermore, inserting a new `BBlock` does not change the `ID` associated to the other - blocks, meaning that you can safely assume that references from existing basic block - terminators / phi nodes to other blocks will not be modified by inserting a new basic block. - - Additionally, since each statement in each basic block has its own unique `ID`, new - statements can be inserted without changing references between other blocks. `IRCode` also - has some support for this via its `new_nodes` field, but eventually all statements will be - renamed upon `compact!`ing the `IRCode`, meaning that the name of any given statement will - eventually change. - - Finally, note that the basic blocks in a `BBCode` support the custom `Switch` statement. - This statement is not valid in `IRCode`, and is therefore lowered into a collection of - `GotoIfNot`s and `GotoNode`s when a `BBCode` is converted back into an `IRCode`. - """ - struct BBCode - blocks::Vector{BBlock} - argtypes::Vector{Any} - sptypes::Vector{CC.VarState} - $( - if VERSION > v"1.12-" - :(debuginfo::CC.DebugInfoStream) - else - :(linetable::Vector{Core.LineInfoNode}) - end - ) - meta::Vector{Expr} - $(VERSION > v"1.12-" ? :(valid_worlds::CC.WorldRange) : nothing) - end - export BBCode -end - -""" - BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{Block}) - -Make a new `BBCode` whose `blocks` is given by `new_blocks`, and fresh copies are made of -all other fields from `ir`. -""" -@static if VERSION > v"1.12-" - function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock}) - return BBCode( - new_blocks, - CC.copy(ir.argtypes), - CC.copy(ir.sptypes), - CC.copy(ir.debuginfo), - CC.copy(ir.meta), - ir.valid_worlds, - ) - end -else - function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock}) - return BBCode( - new_blocks, - CC.copy(ir.argtypes), - CC.copy(ir.sptypes), - CC.copy(ir.linetable), - CC.copy(ir.meta), - ) - end -end - -# Makes use of the above outer constructor for `BBCode`. -Base.copy(ir::BBCode) = BBCode(ir, copy(ir.blocks)) - -""" - compute_all_successors(ir::BBCode)::Dict{ID, Vector{ID}} - -Compute a map from the `ID` of each `BBlock` in `ir` to its possible successors. -""" -compute_all_successors(ir::BBCode)::Dict{ID,Vector{ID}} = _compute_all_successors(ir.blocks) - -""" - _compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} - -Internal method implementing [`compute_all_successors`](@ref). This method is easier to -construct test cases for because it only requires the collection of `BBlocks`, not all of -the other stuff that goes into a `BBCode`. -""" -@noinline function _compute_all_successors(blks::Vector{BBlock})::Dict{ID,Vector{ID}} - succs = map(enumerate(blks)) do (n, blk) - is_final_block = n == length(blks) - t = terminator(blk) - if t === nothing - return is_final_block ? ID[] : ID[blks[n + 1].id] - elseif t isa IDGotoNode - return [t.label] - elseif t isa IDGotoIfNot - return is_final_block ? ID[t.dest] : ID[t.dest, blks[n + 1].id] - elseif t isa ReturnNode - return ID[] - elseif t isa Switch - return vcat(t.dests, t.fallthrough_dest) - else - error("Unhandled terminator $t") - end - end - return Dict{ID,Vector{ID}}((b.id, succ) for (b, succ) in zip(blks, succs)) -end - -""" - compute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}} - -Compute a map from the `ID` of each `BBlock` in `ir` to its possible predecessors. -""" -function compute_all_predecessors(ir::BBCode)::Dict{ID,Vector{ID}} - return _compute_all_predecessors(ir.blocks) -end - -""" - _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} - -Internal method implementing [`compute_all_predecessors`](@ref). This method is easier to -construct test cases for because it only requires the collection of `BBlocks`, not all of -the other stuff that goes into a `BBCode`. -""" -function _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID,Vector{ID}} - successor_map = _compute_all_successors(blks) - - # Initialise predecessor map to be empty. - ks = collect(keys(successor_map)) - predecessor_map = Dict{ID,Vector{ID}}(zip(ks, map(_ -> ID[], ks))) - - # Find all predecessors by iterating through the successor map. - for (k, succs) in successor_map - for succ in succs - push!(predecessor_map[succ], k) - end - end - - return predecessor_map -end - -""" - collect_stmts(ir::BBCode)::Vector{IDInstPair} - -Produce a `Vector` containing all of the statements in `ir`. These are returned in -order, so it is safe to assume that element `n` refers to the `nth` element of the `IRCode` -associated to `ir`. -""" -collect_stmts(ir::BBCode)::Vector{IDInstPair} = reduce(vcat, map(collect_stmts, ir.blocks)) - -""" - id_to_line_map(ir::BBCode) - -Produces a `Dict` mapping from each `ID` associated with a line in `ir` to its line number. -This is isomorphic to mapping to its `SSAValue` in `IRCode`. Terminators do not have `ID`s -associated to them, so not every line in the original `IRCode` is mapped to. -""" -function id_to_line_map(ir::BBCode) - lines = collect_stmts(ir) - lines_and_line_numbers = collect(zip(lines, eachindex(lines))) - ids_and_line_numbers = map(x -> (x[1][1], x[2]), lines_and_line_numbers) - return Dict(ids_and_line_numbers) -end - -concatenate_ids(bb_code::BBCode) = reduce(vcat, map(b -> b.inst_ids, bb_code.blocks)) -concatenate_stmts(bb_code::BBCode) = reduce(vcat, map(b -> b.insts, bb_code.blocks)) - -""" - control_flow_graph(bb_code::BBCode)::Core.Compiler.CFG - -Computes the `Core.Compiler.CFG` object associated to this `bb_code`. -""" -control_flow_graph(bb_code::BBCode)::Core.Compiler.CFG = _control_flow_graph(bb_code.blocks) - -""" - _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG - -Internal function, used to implement [`control_flow_graph`](@ref). Easier to write test -cases for because there is no need to construct an ensure BBCode object, just the `BBlock`s. -""" -function _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG - - # Get IDs of predecessors and successors. - preds_ids = _compute_all_predecessors(blks) - succs_ids = _compute_all_successors(blks) - - # Construct map from block ID to block number. - block_ids = map(b -> b.id, blks) - id_to_num = Dict{ID,Int}(zip(block_ids, collect(eachindex(block_ids)))) - - # Convert predecessor and successor IDs to numbers. - preds = map(id -> sort(map(p -> id_to_num[p], preds_ids[id])), block_ids) - succs = map(id -> sort(map(s -> id_to_num[s], succs_ids[id])), block_ids) - - # Predecessor of entry block is `0`. This needs to be added in manually. - @static if VERSION >= v"1.11" - push!(preds[1], 0) - end - - # Compute the statement numbers associated to each basic block. - index = vcat(0, cumsum(map(length, blks))) .+ 1 - basic_blocks = map(eachindex(blks)) do n - stmt_range = Core.Compiler.StmtRange(index[n], index[n + 1] - 1) - return Core.Compiler.BasicBlock(stmt_range, preds[n], succs[n]) - end - return Core.Compiler.CFG(basic_blocks, index[2:(end - 1)]) -end - -""" - _instructions_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector - -Pulls out the instructions from `insts`, and calls `__line_numbers_to_block_numbers!`. -""" -function _lines_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector - stmts = __line_numbers_to_block_numbers!(Any[x.stmt for x in insts], cfg) - return map((inst, stmt) -> NewInstruction(inst; stmt), insts, stmts) -end - -""" - __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) - -Converts any edges in `GotoNode`s, `GotoIfNot`s, `PhiNode`s, and `:enter` expressions which -refer to line numbers into references to block numbers. The `cfg` provides the information -required to perform this conversion. - -For context, `CodeInfo` objects have references to line numbers, while `IRCode` uses -block numbers. - -This code is copied over directly from the body of `Core.Compiler.inflate_ir!`. -""" -function __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) - for i in eachindex(insts) - stmt = insts[i] - if isa(stmt, GotoNode) - insts[i] = GotoNode(CC.block_for_inst(cfg, stmt.label)) - elseif isa(stmt, GotoIfNot) - insts[i] = GotoIfNot(stmt.cond, CC.block_for_inst(cfg, stmt.dest)) - elseif isa(stmt, PhiNode) - insts[i] = PhiNode( - Int32[CC.block_for_inst(cfg, Int(edge)) for edge in stmt.edges], stmt.values - ) - elseif Meta.isexpr(stmt, :enter) - stmt.args[1] = CC.block_for_inst(cfg, stmt.args[1]::Int) - insts[i] = stmt - end - end - return insts -end - -# -# Converting from IRCode to BBCode -# - -""" - BBCode(ir::IRCode) - -Convert an `ir` into a `BBCode`. Creates a completely independent data structure, so -mutating the `BBCode` returned will not mutate `ir`. - -All `PhiNode`s, `GotoIfNot`s, and `GotoNode`s will be replaced with the `IDPhiNode`s, -`IDGotoIfNot`s, and `IDGotoNode`s respectively. - -See `IRCode` for conversion back to `IRCode`. - -Note that `IRCode(BBCode(ir))` should be equal to the identity function. -""" -function BBCode(ir::IRCode) - - # Produce a new set of statements with `IDs` rather than `SSAValues` and block numbers. - insts = new_inst_vec(ir.stmts) - ssa_ids, stmts = _ssas_to_ids(insts) - block_ids, stmts = _block_nums_to_ids(stmts, ir.cfg) - - # Chop up the new statements into `BBlocks`, according to the `CFG` in `ir`. - blocks = map(zip(ir.cfg.blocks, block_ids)) do (bb, id) - return BBlock(id, ssa_ids[bb.stmts], stmts[bb.stmts]) - end - return BBCode(ir, blocks) -end - -""" - new_inst_vec(x::CC.InstructionStream) - -Convert an `Compiler.InstructionStream` into a list of `Compiler.NewInstruction`s. -""" -function new_inst_vec(x::CC.InstructionStream) - stmt = @static VERSION < v"1.11.0-rc4" ? x.inst : x.stmt - return map((v...,) -> NewInstruction(v...), stmt, x.type, x.info, x.line, x.flag) -end - -# Maps from positional names (SSAValues for nodes, Integers for basic blocks) to IDs. -const SSAToIdDict = Dict{SSAValue,ID} -const BlockNumToIdDict = Dict{Integer,ID} - -""" - _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector} - -Assigns an ID to each line in `stmts`, and replaces each instance of an `SSAValue` in each -line with the corresponding `ID`. For example, a call statement of the form -`Expr(:call, :f, %4)` is be replaced with `Expr(:call, :f, id_assigned_to_%4)`. -""" -function _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID},InstVector} - ids = map(_ -> ID(), insts) - val_id_map = SSAToIdDict(zip(SSAValue.(eachindex(insts)), ids)) - return ids, map(Base.Fix1(_ssa_to_ids, val_id_map), insts) -end - -""" - _ssa_to_ids(d::SSAToIdDict, inst::NewInstruction) - -Produce a new instance of `inst` in which all instances of `SSAValue`s are replaced with -the `ID`s prescribed by `d`, all basic block numbers are replaced with the `ID`s -prescribed by `d`, and `GotoIfNot`, `GotoNode`, and `PhiNode` instances are replaced with -the corresponding `ID` versions. -""" -function _ssa_to_ids(d::SSAToIdDict, inst::NewInstruction) - return NewInstruction(inst; stmt=_ssa_to_ids(d, inst.stmt)) -end -function _ssa_to_ids(d::SSAToIdDict, x::ReturnNode) - return isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x -end -_ssa_to_ids(d::SSAToIdDict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) -_ssa_to_ids(d::SSAToIdDict, x::PiNode) = PiNode(get(d, x.val, x.val), get(d, x.typ, x.typ)) -_ssa_to_ids(d::SSAToIdDict, x::QuoteNode) = x -_ssa_to_ids(d::SSAToIdDict, x) = x -function _ssa_to_ids(d::SSAToIdDict, x::PhiNode) - new_values = Vector{Any}(undef, length(x.values)) - for n in eachindex(x.values) - if isassigned(x.values, n) - new_values[n] = get(d, x.values[n], x.values[n]) - end - end - return PhiNode(x.edges, new_values) -end -_ssa_to_ids(d::SSAToIdDict, x::GotoNode) = x -_ssa_to_ids(d::SSAToIdDict, x::GotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), x.dest) - -""" - _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector} - -Assign to each basic block in `cfg` an `ID`. Replace all integers referencing block numbers -in `insts` with the corresponding `ID`. Return the `ID`s and the updated instructions. -""" -function _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID},InstVector} - ids = map(_ -> ID(), cfg.blocks) - block_num_id_map = BlockNumToIdDict(zip(eachindex(cfg.blocks), ids)) - return ids, map(Base.Fix1(_block_num_to_ids, block_num_id_map), insts) -end - -function _block_num_to_ids(d::BlockNumToIdDict, x::NewInstruction) - return NewInstruction(x; stmt=_block_num_to_ids(d, x.stmt)) -end -function _block_num_to_ids(d::BlockNumToIdDict, x::PhiNode) - return IDPhiNode(ID[d[e] for e in x.edges], x.values) -end -_block_num_to_ids(d::BlockNumToIdDict, x::GotoNode) = IDGotoNode(d[x.label]) -_block_num_to_ids(d::BlockNumToIdDict, x::GotoIfNot) = IDGotoIfNot(x.cond, d[x.dest]) -_block_num_to_ids(d::BlockNumToIdDict, x) = x - -# -# Converting from BBCode to IRCode -# - -""" - IRCode(bb_code::BBCode) - -Produce an `IRCode` instance which is equivalent to `bb_code`. The resulting `IRCode` -shares no memory with `bb_code`, so can be safely mutated without modifying `bb_code`. - -All `IDPhiNode`s, `IDGotoIfNot`s, and `IDGotoNode`s are converted into `PhiNode`s, -`GotoIfNot`s, and `GotoNode`s respectively. - -In the resulting `bb_code`, any `Switch` nodes are lowered into a semantically-equivalent -collection of `GotoIfNot` nodes. -""" -function CC.IRCode(bb_code::BBCode) - bb_code = _lower_switch_statements(bb_code) - bb_code = _remove_double_edges(bb_code) - insts = _ids_to_line_numbers(bb_code) - cfg = control_flow_graph(bb_code) - insts = _lines_to_blocks(insts, cfg) - @static if VERSION > v"1.12-" - lines = CC.copy(bb_code.debuginfo.codelocs) - n = length(insts) - if length(lines) > 3n - resize!(lines, 3n) - elseif length(lines) < 3n - for _ in (length(lines) + 1):3n - push!(lines, 0) - end - end - return IRCode( - CC.InstructionStream( - Any[x.stmt for x in insts], - Any[x.type for x in insts], - CC.CallInfo[x.info for x in insts], - lines, - UInt32[x.flag for x in insts], - ), - cfg, - CC.copy(bb_code.debuginfo), - CC.copy(bb_code.argtypes), - CC.copy(bb_code.meta), - CC.copy(bb_code.sptypes), - bb_code.valid_worlds, - ) - else - return IRCode( - CC.InstructionStream( - Any[x.stmt for x in insts], - Any[x.type for x in insts], - CC.CallInfo[x.info for x in insts], - Int32[x.line for x in insts], - UInt32[x.flag for x in insts], - ), - cfg, - CC.copy(bb_code.linetable), - CC.copy(bb_code.argtypes), - CC.copy(bb_code.meta), - CC.copy(bb_code.sptypes), - ) - end -end - -""" - _lower_switch_statements(bb_code::BBCode) - -Converts all `Switch`s into a semantically-equivalent collection of `GotoIfNot`s. See the -`Switch` docstring for an explanation of what is going on here. -""" -function _lower_switch_statements(bb_code::BBCode) - new_blocks = Vector{BBlock}(undef, 0) - for block in bb_code.blocks - t = terminator(block) - if t isa Switch - - # Create new block without the `Switch`. - bb = BBlock(block.id, block.inst_ids[1:(end - 1)], block.insts[1:(end - 1)]) - push!(new_blocks, bb) - - # Create new blocks for each `GotoIfNot` from the `Switch`. - foreach(t.conds, t.dests) do cond, dest - blk = BBlock(ID(), [ID()], [new_inst(IDGotoIfNot(cond, dest), Any)]) - push!(new_blocks, blk) - end - - # Create a new block for the fallthrough dest. - fallthrough_inst = new_inst(IDGotoNode(t.fallthrough_dest), Any) - push!(new_blocks, BBlock(ID(), [ID()], [fallthrough_inst])) - else - push!(new_blocks, block) - end - end - return BBCode(bb_code, new_blocks) -end - -""" - _ids_to_line_numbers(bb_code::BBCode)::InstVector - -For each statement in `bb_code`, returns a `NewInstruction` in which every `ID` is replaced -by either an `SSAValue`, or an `Int64` / `Int32` which refers to an `SSAValue`. -""" -function _ids_to_line_numbers(bb_code::BBCode)::InstVector - - # Construct map from `ID`s to `SSAValue`s. - block_ids = [b.id for b in bb_code.blocks] - block_lengths = map(length, bb_code.blocks) - block_start_ssas = SSAValue.(vcat(1, cumsum(block_lengths)[1:(end - 1)] .+ 1)) - line_ids = concatenate_ids(bb_code) - line_ssas = SSAValue.(eachindex(line_ids)) - id_to_ssa_map = Dict(zip(vcat(block_ids, line_ids), vcat(block_start_ssas, line_ssas))) - - # Apply map. - return [_to_ssas(id_to_ssa_map, stmt) for stmt in concatenate_stmts(bb_code)] -end - -""" - _to_ssas(d::Dict, inst::NewInstruction) - -Like `_ssas_to_ids`, but in reverse. Converts IDs to SSAValues / (integers corresponding -to ssas). -""" -_to_ssas(d::Dict, inst::NewInstruction) = NewInstruction(inst; stmt=_to_ssas(d, inst.stmt)) -_to_ssas(d::Dict, x::ReturnNode) = isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x -_to_ssas(d::Dict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) -_to_ssas(d::Dict, x::PiNode) = PiNode(get(d, x.val, x.val), get(d, x.typ, x.typ)) -_to_ssas(d::Dict, x::QuoteNode) = x -_to_ssas(d::Dict, x) = x -function _to_ssas(d::Dict, x::IDPhiNode) - new_values = Vector{Any}(undef, length(x.values)) - for n in eachindex(x.values) - if isassigned(x.values, n) - new_values[n] = get(d, x.values[n], x.values[n]) - end - end - return PhiNode(map(e -> Int32(getindex(d, e).id), x.edges), new_values) -end -_to_ssas(d::Dict, x::IDGotoNode) = GotoNode(d[x.label].id) -_to_ssas(d::Dict, x::IDGotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), d[x.dest].id) - -""" - _remove_double_edges(ir::BBCode)::BBCode - -If the `dest` field of an `IDGotoIfNot` node in block `n` of `ir` points towards the `n+1`th -block then we have two edges from block `n` to block `n+1`. This transformation replaces all -such `IDGotoIfNot` nodes with unconditional `IDGotoNode`s pointing towards the `n+1`th block -in `ir`. -""" -function _remove_double_edges(ir::BBCode) - new_blks = map(enumerate(ir.blocks)) do (n, blk) - t = terminator(blk) - if t isa IDGotoIfNot && t.dest == ir.blocks[n + 1].id - new_insts = vcat(blk.insts[1:(end - 1)], NewInstruction(t; stmt=IDGotoNode(t.dest))) - return BBlock(blk.id, blk.inst_ids, new_insts) - else - return blk - end - end - return BBCode(ir, new_blks) -end - -""" - _build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}} - -Builds a `SimpleDiGraph`, `g`, representing of the CFG associated to `blks`, where `blks` -comprises the collection of basic blocks associated to a `BBCode`. -This is a type from Graphs.jl, so constructing `g` makes it straightforward to analyse the -control flow structure of `ir` using algorithms from Graphs.jl. - -Returns a 2-tuple, whose first element is `g`, and whose second element is a map from -the `ID` associated to each basic block in `ir`, to the `Int` corresponding to its node -index in `g`. -""" -function _build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph,Dict{ID,Int}} - node_ints = collect(eachindex(blks)) - id_to_int = Dict(zip(map(blk -> blk.id, blks), node_ints)) - successors = _compute_all_successors(blks) - g = SimpleDiGraph(length(blks)) - for blk in blks, successor in successors[blk.id] - add_edge!(g, id_to_int[blk.id], id_to_int[successor]) - end - return g, id_to_int -end - -""" - _distance_to_entry(blks::Vector{BBlock})::Vector{Int} - -For each basic block in `blks`, compute the distance from it to the entry point (the first -block. The distance is `typemax(Int)` if no path from the entry point to a given node. -""" -function _distance_to_entry(blks::Vector{BBlock})::Vector{Int} - g, id_to_int = _build_graph_of_cfg(blks) - return dijkstra_shortest_paths(g, id_to_int[blks[1].id]).dists -end - -""" - sort_blocks!(ir::BBCode)::BBCode - -Ensure that blocks appear in order of distance-from-entry-point, where distance the -distance from block b to the entry point is defined to be the minimum number of basic -blocks that must be passed through in order to reach b. - -For reasons unknown (to me, Will), the compiler / optimiser needs this for inference to -succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem -there. - -WARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic -blocks in `ir` is valid. Notably, this does not hold if you have any `IDGotoIfNot` nodes in -`ir`. -""" -function sort_blocks!(ir::BBCode)::BBCode - I = sortperm(_distance_to_entry(ir.blocks)) - ir.blocks .= ir.blocks[I] - return ir -end - -""" - characterise_unique_predecessor_blocks(blks::Vector{BBlock}) -> - Tuple{Dict{ID, Bool}, Dict{ID, Bool}} - -We call a block `b` a _unique_ _predecessor_ in the control flow graph associated to `blks` -if it is the only predecessor to all of its successors. Put differently we call `b` a unique -predecessor if, whenever control flow arrives in any of the successors of `b`, we know for -certain that the previous block must have been `b`. - -Returns two `Dict`s. A value in the first `Dict` is `true` if the block associated to its -key is a unique precessor, and is `false` if not. A value in the second `Dict` is `true` if -it has a single predecessor, and that predecessor is a unique predecessor. - -*Context*: - -This information is important for optimising AD because knowing that `b` is a unique -predecessor means that -1. on the forwards-pass, there is no need to push the ID of `b` to the block stack when - passing through it, and -2. on the reverse-pass, there is no need to pop the block stack when passing through one of - the successors to `b`. - -Utilising this reduces the overhead associated to doing AD. It is quite important when -working with cheap loops -- loops where the operations performed at each iteration -are inexpensive -- for which minimising memory pressure is critical to performance. It is -also important for single-block functions, because it can be used to entirely avoid using a -block stack at all. -""" -function characterise_unique_predecessor_blocks( - blks::Vector{BBlock} -)::Tuple{Dict{ID,Bool},Dict{ID,Bool}} - - # Obtain the block IDs in order -- this ensures that we get the entry block first. - blk_ids = ID[b.id for b in blks] - preds = _compute_all_predecessors(blks) - succs = _compute_all_successors(blks) - - # The bulk of blocks can be hanled by this general loop. - is_unique_pred = Dict{ID,Bool}() - for id in blk_ids - ss = succs[id] - is_unique_pred[id] = !isempty(ss) && all(s -> length(preds[s]) == 1, ss) - end - - # If there is a single reachable return node, then that block is treated as a unique - # pred, since control can only pass "out" of the function via this block. Conversely, - # if there are multiple reachable return nodes, then execution can return to the calling - # function via any of them, so they are not unique predecessors. - # Note that the previous block sets is_unique_pred[id] to false for all nodes which - # end with a reachable return node, so the value only needs changing if there is a - # unique reachable return node. - reachable_return_blocks = filter(blks) do blk - is_reachable_return_node(terminator(blk)) - end - if length(reachable_return_blocks) == 1 - is_unique_pred[only(reachable_return_blocks).id] = true - end - - # pred_is_unique_pred is true if the unique predecessor to a block is a unique pred. - pred_is_unique_pred = Dict{ID,Bool}() - for id in blk_ids - pred_is_unique_pred[id] = length(preds[id]) == 1 && is_unique_pred[only(preds[id])] - end - - # If the entry block has no predecessors, then it can only be entered once, when the - # function is first entered. In this case, we treat it as having a unique predecessor. - entry_id = blk_ids[1] - pred_is_unique_pred[entry_id] = isempty(preds[entry_id]) - - return is_unique_pred, pred_is_unique_pred -end - -""" - is_reachable_return_node(x::ReturnNode) - -Determine whether `x` is a `ReturnNode`, and if it is, if it is also reachable. This is -purely a function of whether or not its `val` field is defined or not. -""" -is_reachable_return_node(x::ReturnNode) = isdefined(x, :val) -is_reachable_return_node(x) = false - -""" - characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool} - -For each line in `stmts`, determine whether it is referenced anywhere else in the code. -Returns a dictionary containing the results. An element is `false` if the corresponding -`ID` is unused, and `true` if is used. -""" -function characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID,Bool} - ids = first.(stmts) - insts = last.(stmts) - - # Initialise to false. - is_used = Dict{ID,Bool}(zip(ids, fill(false, length(ids)))) - - # Hunt through the instructions, flipping a value in is_used to true whenever an ID - # is encountered which corresponds to an SSA. - for inst in insts - _find_id_uses!(is_used, inst.stmt) - end - return is_used -end - -""" - _find_id_uses!(d::Dict{ID, Bool}, x) - -Helper function used in [`characterise_used_ids`](@ref). For all uses of `ID`s in `x`, set -the corresponding value of `d` to `true`. - -For example, if `x = ReturnNode(ID(5))`, then this function sets `d[ID(5)] = true`. -""" -function _find_id_uses!(d::Dict{ID,Bool}, x::Expr) - for arg in x.args - in(arg, keys(d)) && setindex!(d, true, arg) - end -end -function _find_id_uses!(d::Dict{ID,Bool}, x::IDGotoIfNot) - return in(x.cond, keys(d)) && setindex!(d, true, x.cond) -end -_find_id_uses!(::Dict{ID,Bool}, ::IDGotoNode) = nothing -function _find_id_uses!(d::Dict{ID,Bool}, x::PiNode) - return in(x.val, keys(d)) && setindex!(d, true, x.val) -end -function _find_id_uses!(d::Dict{ID,Bool}, x::IDPhiNode) - v = x.values - for n in eachindex(v) - isassigned(v, n) && in(v[n], keys(d)) && setindex!(d, true, v[n]) - end -end -function _find_id_uses!(d::Dict{ID,Bool}, x::ReturnNode) - return isdefined(x, :val) && in(x.val, keys(d)) && setindex!(d, true, x.val) -end -_find_id_uses!(d::Dict{ID,Bool}, x::QuoteNode) = nothing -_find_id_uses!(d::Dict{ID,Bool}, x) = nothing - -""" - _is_reachable(blks::Vector{BBlock})::Vector{Bool} - -Computes a `Vector` whose length is `length(blks)`. The `n`th element is `true` iff it is -possible for control flow to reach the `n`th block. -""" -_is_reachable(blks::Vector{BBlock})::Vector{Bool} = _distance_to_entry(blks) .< typemax(Int) - -""" - remove_unreachable_blocks!(ir::BBCode)::BBCode - -If a basic block in `ir` cannot possibly be reached during execution, then it can be safely -removed from `ir` without changing its functionality. -A block is unreachable if either: -1. it has no predecessors _and_ it is not the first block, or -2. all of its predecessors are themselves unreachable. - -For example, consider the following IR: -```jldoctest remove_unreachable_blocks; setup = :(using Mooncake) -julia> ir = Mooncake.ircode( - Any[Core.ReturnNode(nothing), Expr(:call, sin, 5), Core.ReturnNode(Core.SSAValue(2))], - Any[Any, Any, Any], - ); -``` -There is no possible way to reach the second basic block (lines 2 and 3). Applying this -function will therefore remove it, yielding the following: -```jldoctest remove_unreachable_blocks; setup = :(using Mooncake) -julia> Mooncake.IRCode(Mooncake.remove_unreachable_blocks!(Mooncake.BBCode(ir))) - 1 ─ return nothing -``` - -In the blocks which have not been removed, there may be references to blocks which have been -removed. For example, the `edge`s in a `PhiNode` may contain a reference to a removed block. -These references are removed in-place from these remaining blocks, so this function will (in -general) modify `ir`. -""" -remove_unreachable_blocks!(ir::BBCode) = BBCode(ir, _remove_unreachable_blocks!(ir.blocks)) - -function _remove_unreachable_blocks!(blks::Vector{BBlock}) - - # Figure out which blocks are reachable. - is_reachable = _is_reachable(blks) - - # Collect all blocks which are reachable. - remaining_blks = blks[is_reachable] - - # For each reachable block, remove any references to removed blocks. These can appear in - # `PhiNode`s with edges that come from remove blocks. - removed_block_ids = map(idx -> blks[idx].id, findall(!, is_reachable)) - for blk in remaining_blks, inst in blk.insts - stmt = inst.stmt - stmt isa IDPhiNode || continue - for n in reverse(1:length(stmt.edges)) - if stmt.edges[n] in removed_block_ids - deleteat!(stmt.edges, n) - deleteat!(stmt.values, n) - end - end - end - - return remaining_blks -end - -end diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index 7127a01e98..61fdc3aeae 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -15,6 +15,20 @@ Get the statement from `x`. This field changed name in 1.11 from `inst` to `stmt """ stmt(x::CC.Instruction) = CC.getindex(x, stmt_field_name()) +""" + new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction + +Create a `NewInstruction` with fields: +- `stmt` = `stmt` +- `type` = `type` +- `info` = `CC.NoCallInfo()` +- `line` = `Int32(1)` +- `flag` = `flag` +""" +function new_inst(@nospecialize(stmt), @nospecialize(type)=Any, flag=CC.IR_FLAG_REFINED) + return NewInstruction(stmt, type, CC.NoCallInfo(), Int32(1), flag) +end + set_stmt!(ir::IRCode, ssa::SSAValue, a) = set_ir!(ir, ssa, stmt_field_name(), a) get_ir(ir::IRCode, idx::SSAValue) = CC.getindex(ir, idx) @@ -100,6 +114,33 @@ function __insts_to_instruction_stream(insts::Vector{Any}) ) end +""" + __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) + +Convert any edges in `GotoNode`s, `GotoIfNot`s, `PhiNode`s, and `:enter` expressions that +refer to line numbers into references to basic-block numbers. + +For context, `CodeInfo` uses line-number edges while `IRCode` uses block-number edges. +""" +function __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) + for i in eachindex(insts) + stmt = insts[i] + if isa(stmt, GotoNode) + insts[i] = GotoNode(CC.block_for_inst(cfg, stmt.label)) + elseif isa(stmt, GotoIfNot) + insts[i] = GotoIfNot(stmt.cond, CC.block_for_inst(cfg, stmt.dest)) + elseif isa(stmt, PhiNode) + insts[i] = PhiNode( + Int32[CC.block_for_inst(cfg, Int(edge)) for edge in stmt.edges], stmt.values + ) + elseif Meta.isexpr(stmt, :enter) + stmt.args[1] = CC.block_for_inst(cfg, stmt.args[1]::Int) + insts[i] = stmt + end + end + return insts +end + """ infer_ir!(ir::IRCode) -> IRCode @@ -333,6 +374,8 @@ purely a function of whether or not its `val` field is defined or not. """ is_unreachable_return_node(x::ReturnNode) = !isdefined(x, :val) is_unreachable_return_node(x) = false +is_reachable_return_node(x::ReturnNode) = isdefined(x, :val) +is_reachable_return_node(x) = false """ UnhandledLanguageFeatureException(message::String) diff --git a/src/interpreter/reverse_mode.jl b/src/interpreter/reverse_mode.jl index 21974a94c0..249d0dbfdf 100644 --- a/src/interpreter/reverse_mode.jl +++ b/src/interpreter/reverse_mode.jl @@ -1,3 +1,133 @@ +# +# Reverse-mode source-to-source transform roadmap +# +# This file is in roughly the following order: +# 1. Local reverse-mode IR types and ID utilities. +# 2. Shared closure-capture management and global AD state. +# 3. Statement translation from primal IR to forward/reverse fragments. +# 4. Callable wrapper types used by derived rules at runtime. +# 5. Deferred rule wrappers for dynamic dispatch and recursive :invoke handling. +# 6. Rule derivation entry points and IR generation. +# 7. Forward-pass assembly, CFGBlock-based lowering, and pullback assembly. +# +# The implementation starts with low-level types because later sections share them heavily, +# but the main transform entry points are `build_rrule`, `build_derived_rrule`, and +# `generate_ir`. +# + +# +# Reverse-mode local IR: IDs and CFG-local node types +# + +const _id_count::Dict{Int,Int32} = Dict{Int,Int32}() +# `seed_id!` resets per-thread counters for deterministic IR generation, so updates to the +# shared thread-id map must be serialized when rules are derived concurrently. +const _id_count_lock = ReentrantLock() + +struct ID + id::Int32 + function ID() + lock(_id_count_lock) + try + current_thread_id = Threads.threadid() + id_count = get(_id_count, current_thread_id, Int32(0)) + _id_count[current_thread_id] = id_count + Int32(1) + return new(id_count) + finally + unlock(_id_count_lock) + end + end +end + +Base.copy(id::ID) = id + +function seed_id!() + lock(_id_count_lock) + try + return global _id_count[Threads.threadid()] = 0 + finally + unlock(_id_count_lock) + end +end + +struct IDPhiNode + edges::Vector{ID} + values::Vector{Any} +end + +Base.:(==)(x::IDPhiNode, y::IDPhiNode) = x.edges == y.edges && x.values == y.values +Base.copy(node::IDPhiNode) = IDPhiNode(copy(node.edges), copy(node.values)) + +struct IDGotoNode + label::ID +end + +Base.copy(node::IDGotoNode) = IDGotoNode(copy(node.label)) + +struct IDGotoIfNot + cond::Any + dest::ID +end + +Base.copy(node::IDGotoIfNot) = IDGotoIfNot(copy(node.cond), copy(node.dest)) + +struct Switch + conds::Vector{Any} + dests::Vector{ID} + fallthrough_dest::ID + function Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID) + @assert length(conds) == length(dests) + return new(conds, dests, fallthrough_dest) + end +end + +const IDInstPair = Tuple{ID,NewInstruction} +const InstVector = Vector{NewInstruction} +const SSAToIdDict = Dict{SSAValue,ID} +const BlockNumToIdDict = Dict{Integer,ID} + +function characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID,Bool} + is_used = Dict{ID,Bool}() + for (id, _) in stmts + @assert !haskey(is_used, id) + is_used[id] = false + end + for (_, inst) in stmts + _find_id_uses!(is_used, inst.stmt) + end + return is_used +end + +function _find_id_uses!(d::Dict{ID,Bool}, x::Expr) + foreach(a -> _find_id_uses!(d, a), x.args) + return nothing +end +function _find_id_uses!(d::Dict{ID,Bool}, x::IDGotoIfNot) + return _find_id_uses!(d, x.cond) +end +_find_id_uses!(::Dict{ID,Bool}, ::IDGotoNode) = nothing +function _find_id_uses!(d::Dict{ID,Bool}, x::PiNode) + return _find_id_uses!(d, x.val) +end +function _find_id_uses!(d::Dict{ID,Bool}, x::IDPhiNode) + for n in eachindex(x.values) + # Normalized compiler phi nodes can leave incoming values undefined on dead edges. + isassigned(x.values, n) || continue + _find_id_uses!(d, x.values[n]) + end + return nothing +end +function _find_id_uses!(d::Dict{ID,Bool}, x::ReturnNode) + return isdefined(x, :val) ? _find_id_uses!(d, x.val) : nothing +end +_find_id_uses!(::Dict{ID,Bool}, ::QuoteNode) = nothing +_find_id_uses!(d::Dict{ID,Bool}, x::ID) = d[x] = true +_find_id_uses!(::Dict{ID,Bool}, x) = nothing + +# +# Shared closure captures and reverse-mode global state +# + """ SharedDataPairs() @@ -87,6 +217,10 @@ This data structure is used to hold "global" information associated to a particu `build_rrule`. It is used as a means of communication between `make_ad_stmts!` and the codegen which produces the forwards- and reverse-passes. +At a high level, the most important fields are the shared captures, the block-stack state used +to replay control flow, the reverse-data refs for arguments and SSA values, and the static +primal type information used while translating statements. + - `interp`: a `MooncakeInterpreter`. - `block_stack_id`: the ID associated to the block stack -- the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass @@ -137,7 +271,6 @@ struct ADInfo rvs_ret_type::Type end -# The constructor that you should use for ADInfo if you don't have a BBCode lying around. # See the definition of the ADInfo struct for info on the arguments. function ADInfo( interp::MooncakeInterpreter, @@ -169,35 +302,6 @@ function ADInfo( ) end -# The constructor you should use for ADInfo if you _do_ have a BBCode lying around. See the -# ADInfo struct for information regarding `interp` and `debug_mode`. -function ADInfo( - interp::MooncakeInterpreter, - ir::BBCode, - debug_mode::Bool, - fwd_ret_type::Type, - rvs_ret_type::Type, -) - arg_types = Dict{Argument,Any}( - map(((n, t),) -> (Argument(n) => CC.widenconst(t)), enumerate(ir.argtypes)) - ) - stmts = collect_stmts(ir) - ssa_insts = Dict{ID,NewInstruction}(stmts) - is_used_dict = characterise_used_ids(stmts) - Tlazy_rdata_ref = Tuple{map(lazy_zero_rdata_type ∘ CC.widenconst, ir.argtypes)...} - zero_lazy_rdata_ref = Ref{Tlazy_rdata_ref}() - return ADInfo( - interp, - arg_types, - ssa_insts, - is_used_dict, - debug_mode, - zero_lazy_rdata_ref, - fwd_ret_type, - rvs_ret_type, - ) -end - """ add_data!(info::ADInfo, data)::ID @@ -329,6 +433,10 @@ end return y::CoDual, (pb!! isa NoPullback ? pb!! : RRuleWrapperPb(pb!!, l)) end +# +# Statement translation bookkeeping +# + """ ADStmtInfo @@ -349,6 +457,18 @@ struct ADStmtInfo rvs::Vector{IDInstPair} end +struct RuleSelection + args::Tuple + rule_ref + T_pb!!::Type + output_type::Type +end + +struct BlockCommsInsts + fwds_suffix::Vector{IDInstPair} + rvs_prefix::Vector{IDInstPair} +end + """ ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs) @@ -362,15 +482,66 @@ function ad_stmt_info(line::ID, comms_id::Union{ID,Nothing}, fwds, rvs) return ADStmtInfo(line, comms_id, __vec(line, fwds), __vec(line, rvs)) end +function _select_rule(stmt::Expr, line::ID, info::ADInfo, is_invoke::Bool) + args = ((is_invoke ? stmt.args[2:end] : stmt.args)...,) + arg_types = map(arg -> get_primal_type(info, arg), args) + + sig = Tuple{arg_types...} + interp = info.interp + raw_rule = if is_primitive(context_type(interp), ReverseMode, sig, interp.world) + build_primitive_rrule(sig) + elseif is_invoke + LazyDerivedRule(get_mi(stmt.args[1]), info.debug_mode) + else + DynamicDerivedRule(info.debug_mode) + end + + output_type = get_primal_type(info, line) + is_no_pullback = pullback_type(_typeof(raw_rule), arg_types) <: NoPullback + strip_zero_rdata = can_produce_zero_rdata_from_type(output_type) || is_no_pullback + wrapped_rule = strip_zero_rdata ? raw_rule : RRuleZeroWrapper(raw_rule) + rule = info.debug_mode ? DebugRRule(wrapped_rule) : wrapped_rule + + return RuleSelection( + args, + add_data_if_not_singleton!(info, rule), + pullback_type(_typeof(rule), arg_types), + output_type, + ) +end + +function _pullback_increment_stmts( + info::ADInfo, args, call_pullback_id::ID +)::Vector{IDInstPair} + increments = IDInstPair[] + for (n, arg) in enumerate(args) + rev_data_id = get_rev_data_id(info, arg) + rev_data_id === nothing && continue + + rdata_inc_id = ID() + push!( + increments, (rdata_inc_id, new_inst(Expr(:call, getfield, call_pullback_id, n))) + ) + append!(increments, increment_ref_stmts(rev_data_id, rdata_inc_id)) + end + return increments +end + __vec(line::ID, x::Any) = __vec(line, new_inst(x)) __vec(line::ID, x::NewInstruction) = IDInstPair[(line, x)] -__vec(line::ID, x::Vector{Tuple{ID,Any}}) = throw(error("boooo")) +function __vec(::ID, x::Vector{Tuple{ID,Any}}) + throw( + ArgumentError( + "Expected `Vector{IDInstPair}` but found a plain `Vector{Tuple{ID,Any}}`." + ), + ) +end __vec(line::ID, x::Vector{IDInstPair}) = x """ comms_channel(info::ADStmtInfo) -Return the element of `fwds` whose `ID` is the communcation `ID`. Returns `Nothing` if +Return the element of `fwds` whose `ID` is the communication `ID`. Returns `Nothing` if `comms_id` is `nothing`. """ function comms_channel(info::ADStmtInfo) @@ -675,15 +846,17 @@ function make_ad_stmts!(stmt::Core.UpsilonNode, ::ID, ::ADInfo) ) end -# There are quite a number of possible `Expr`s that can be encountered. Each case has its -# own comment, explaining what is going on. +# There are quite a number of possible `Expr`s that can be encountered. This `:call` / +# `:invoke` path stays mostly linear on purpose: keeping rule selection, forward emission, +# and reverse emission together makes the translated dataflow easier to inspect and avoids +# introducing helper boundaries that can interfere with inlining in unstable cases. function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) is_invoke = Meta.isexpr(stmt, :invoke) if Meta.isexpr(stmt, :call) || is_invoke - # Find the types of all arguments to this call / invoke. - args = ((is_invoke ? stmt.args[2:end] : stmt.args)...,) - arg_types = map(arg -> get_primal_type(info, arg), args) + # + # Step 1: classify the call site and choose the rule object. + # # Special case: if the result of a call to getfield is un-used, then leave the # primal statement alone (just increment arguments as usual). This was causing @@ -695,50 +868,24 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) # # This might need to be generalised to more things than just `getfield`, but at the # time of writing this comment, it's unclear whether or not this is the case. + args = ((is_invoke ? stmt.args[2:end] : stmt.args)...,) if !is_used(info, line) && get_const_primal_value(args[1]) == getfield fwds = new_inst(Expr(:call, __fwds_pass_no_ad!, map(__inc, args)...)) return ad_stmt_info(line, nothing, fwds, nothing) end - # Construct signature, and determine how the rrule is to be computed. - sig = Tuple{arg_types...} - interp = info.interp - raw_rule = if is_primitive(context_type(interp), ReverseMode, sig, interp.world) - build_primitive_rrule(sig) # intrinsic / builtin / thing we provably have rule for - elseif is_invoke - mi = get_mi(stmt.args[1]) - LazyDerivedRule(mi, info.debug_mode) # Static dispatch - else - DynamicDerivedRule(info.debug_mode) # Dynamic dispatch - end - - # Wrap the raw rule in a struct which ensures that any `ZeroRData`s are stripped - # away before the raw_rule is called. Only do this if we cannot prove that the - # output of `can_produce_zero_rdata_from_type(P)`, where `P` is the type of the - # value returned by this line. - is_no_pullback = pullback_type(_typeof(raw_rule), arg_types) <: NoPullback - tmp = can_produce_zero_rdata_from_type(get_primal_type(info, line)) - zero_wrapped_rule = (tmp || is_no_pullback) ? raw_rule : RRuleZeroWrapper(raw_rule) - - # If debug mode has been requested, use a debug rule. - rule = info.debug_mode ? DebugRRule(zero_wrapped_rule) : zero_wrapped_rule - - # If the primitive rule is a singleton, then don't bother putting it into shared - # data because it's safe to put it directly into the code. - rule_ref = add_data_if_not_singleton!(info, rule) - - # If the type of the pullback is a singleton type, then there is no need to store it - # in the shared data, it can be interpolated directly into the generated IR. - T_pb!! = pullback_type(_typeof(rule), arg_types) + selection = _select_rule(stmt, line, info, is_invoke) # - # Write forwards-pass. These statements are written out manually, as writing them - # out in a function would prevent inlining in some (all?) type-unstable situations. + # Step 2: write the forward fragment. + # + # These statements are written out manually because routing them through a helper can + # prevent inlining in type-unstable situations. # # Make arguments to rrule call. Things which are not already CoDual must be made so. codual_args = IDInstPair[] - codual_arg_ids = map(args) do arg + codual_arg_ids = map(selection.args) do arg is_active(arg) && return __inc(arg) id = ID() push!(codual_args, (id, new_inst(inc_or_const_stmt(arg, info)))) @@ -747,7 +894,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) # Make call to rule. rule_call_id = ID() - rule_call = Expr(:call, rule_ref, codual_arg_ids...) + rule_call = Expr(:call, selection.rule_ref, codual_arg_ids...) # Extract the output-codual from the returned tuple. raw_output_id = ID() @@ -755,13 +902,15 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) # Extract the pullback from the returned tuple. Specialise on the case that the # pullback is provably a singleton type. - if Base.issingletontype(T_pb!!) - pb = T_pb!!.instance + if Base.issingletontype(selection.T_pb!!) + pb = selection.T_pb!!.instance pb_stmt = (ID(), new_inst(nothing)) comms_id = nothing else pb = ID() - pb_stmt = (pb, new_inst(Expr(:call, getfield, rule_call_id, 2), T_pb!!)) + pb_stmt = ( + pb, new_inst(Expr(:call, getfield, rule_call_id, 2), selection.T_pb!!) + ) comms_id = pb end @@ -771,7 +920,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) # be optimised away in situations where the compiler is able to successfully infer # the type, so performance in performance-critical situations is unaffected. output_id = line - F = fcodual_type(get_primal_type(info, line)) + F = fcodual_type(selection.output_type) output = Expr(:call, Core.typeassert, raw_output_id, F) # Create statements associated to forwards-pass. @@ -785,9 +934,11 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) ], ) - # Make statement associated to reverse-pass. If the reverse-pass is provably a - # NoPullback, then don't bother doing anything at all. - rvs_pass = if T_pb!! <: NoPullback + # + # Step 3: write the reverse fragment. + # + # If the reverse pass is provably `NoPullback`, there is nothing to emit. + rvs_pass = if selection.T_pb!! <: NoPullback nothing else # Get the rdata which we pass into the pullback from its rdata ref. @@ -799,41 +950,19 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) # Zero out the value stored in this rdata ref now that we have its current # value. The new value is rdata, so must be an instance of a bits type, so is # safe to interpolate straight into instruction. - zero_val = zero_like_rdata_from_type(get_primal_type(info, line)) + zero_val = zero_like_rdata_from_type(selection.output_type) zero_rdata_expr = Expr(:call, setfield!, rdata_ref_id, QuoteNode(:x), zero_val) zero_rdata_ref = (ID(), new_inst(zero_rdata_expr)) # Run the pullback. The result is a tuple comprising `length(args)` elements. call_pullback_id = ID() call_pullback = (call_pullback_id, new_inst(Expr(:call, pb, rdata_output_id))) - - # For each element of the tuple returned by call_pullback, if the corresponding - # value in the primal IR is an Argument / SSA (if `get_rev_data_id` does not - # return nothing), increment the value in its rdata ref. This is equivalent to - # rdata_ref[] = increment!!(rdata_ref[], rdata_inc_resulting_from_pullback), - # but written out manually to ensure nothing fails to inline. - # If the corresponding value in the primal IR is not an Argument / SSA (e.g. it - # is a literal, a `QuoteNode`, or a `GlobalRef`), do nothing as we do not track - # gradients w.r.t. it. - tmp = map(enumerate(args)) do (n, arg) - rev_data_id = get_rev_data_id(info, arg) - - # If arg is not an SSA / Argument, then no rdata ref to inc. - rev_data_id === nothing && return nothing - - # Extract rdata from result of calling pullback. - rdata_inc_id = ID() - rdata_inc_expr = Expr(:call, getfield, call_pullback_id, n) - rdata_inc = (rdata_inc_id, new_inst(rdata_inc_expr)) - - # Construct statements to increment ref. - return vcat(rdata_inc, increment_ref_stmts(rev_data_id, rdata_inc_id)) - end - - # Concatenate all statements, and return them. + pullback_increments = _pullback_increment_stmts( + info, selection.args, call_pullback_id + ) vcat( IDInstPair[rdata_output, zero_rdata_ref, call_pullback], - reduce(vcat, filter(x -> !(x === nothing), tmp); init=IDInstPair[]), + pullback_increments, ) end return ad_stmt_info(line, comms_id, fwds, rvs_pass) @@ -936,8 +1065,10 @@ __get_primal(x) = x const RuleMC{A,R} = MistyClosure{OpaqueClosure{A,R}} # -# Runners for generated code. The main job of these functions is to handle the translation -# between differing varargs conventions. +# Runtime wrapper types for generated rules. +# +# These wrappers sit on the hot path once a rule has already been derived. Their main job is +# to hide closure/capture details and translate between differing varargs conventions. # struct Pullback{Tprimal,Tpb_args,Tpb_ret,isva,nargs} @@ -1032,137 +1163,295 @@ function __unflatten_codual_varargs(isva::Bool, args, ::Val{nargs}) where {nargs end # -# Rule derivation. +# Deferred runtime rule wrappers for dynamic dispatch and recursive `:invoke` +# +# These wrappers live next to the other callable rule wrappers above because they are also +# part of the runtime surface seen by generated reverse-mode code. Their constructors depend +# on compilation helpers such as `build_rrule` and `rule_type`, which are defined later. # -_get_sig(sig::Type) = sig -_get_sig(mi::Core.MethodInstance) = mi.specTypes -_get_sig(mc::MistyClosure) = Tuple{map(CC.widenconst, mc.ir[].argtypes)...} - -""" -Flatten the signature of a vararg method to group the -possibly multiple vararg arguments (what users pass to the function) -into a single tuple argument matching `ir.argtypes`. """ -function flatten_va_sig(sig, isva, nargs) - @nospecialize sig - return if isva - Tuple{sig.parameters[1:(nargs - 1)]...,Tuple{sig.parameters[nargs:end]...}} - else - sig - end -end + DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool) -function forwards_ret_type(primal_ir::IRCode) - return fcodual_type(compute_ir_rettype(primal_ir)) -end +For internal use only. -function pullback_ret_type(primal_ir::IRCode) - return Tuple{map(rdata_type ∘ tangent_type ∘ CC.widenconst, primal_ir.argtypes)...} -end +A callable data structure which, when invoked, calls an rrule specific to the dynamic types +of its arguments. Stores rules in an internal cache to avoid re-deriving. -struct MooncakeRuleCompilationError <: Exception - interp::MooncakeInterpreter - sig +This is used to implement dynamic dispatch. +""" +struct DynamicDerivedRule{V} + cache::V debug_mode::Bool - cause::Exception end -function Base.showerror(io::IO, err::MooncakeRuleCompilationError) - msg_lines = ( - "MooncakeRuleCompilationError: an error occurred while Mooncake was compiling a", - "rule to differentiate something. If the `caused by` error message below does", - "not make it clear to you how the problem can be fixed, please open an issue", - "at github.com/chalk-lab/Mooncake.jl describing your problem.", - ) - cause_width = min(_boxed_message_width(io, "│ "), 78) - cause_lines = let lines = if hasfield(typeof(err.cause), :msg) - msg = getfield(err.cause, :msg) - if msg isa AbstractString - split(msg, '\n') - else - split(sprint(showerror, err.cause), '\n') - end - else - split(sprint(showerror, err.cause), '\n') - end - while !isempty(lines) && isempty(last(lines)) - pop!(lines) - end - wrapped_lines = String[] - for line in lines - append!(wrapped_lines, _wrap_boxed_line(line, cause_width)) - end - wrapped_lines - end - detail_lines = ("Caused by:", cause_lines..., "", msg_lines...) +DynamicDerivedRule(debug_mode::Bool) = DynamicDerivedRule(Dict{Any,Any}(), debug_mode) - # Print the source location of the method being differentiated, if available. - try - m = lookup_method(err.sig) - if m !== nothing - mstr = sprint(show, m) - header, location = let parts = split(mstr, " @ "; limit=2) - length(parts) == 2 ? (parts[1], parts[2]) : (mstr, nothing) - end - _print_boxed_error( - io, - ( - "Mooncake failed to differentiate the following method:", - header, - "", - detail_lines..., - ); - footer=isnothing(location) ? nothing : "@ $location", - ) - println(io) # blank line before the main error body - else - _print_boxed_error(io, detail_lines) - println(io) - end - catch e - # If method lookup fails for any reason, skip gracefully. - @debug "MooncakeRuleCompilationError: method lookup failed" exception = e - _print_boxed_error(io, detail_lines) - println(io) - end - println(io, "To replicate this error run the following:\n") - println( - io, - "Mooncake.build_rrule(Mooncake.$(err.interp), $(err.sig); debug_mode=$(err.debug_mode))", - ) - return println( - io, - "\nNote that you may need to `using` some additional packages if not all of the " * - "names printed in the above signature are available currently in your environment.", - ) -end +# Create new dynamic rule with empty cache and same debug mode +_copy(x::P) where {P<:DynamicDerivedRule} = P(Dict{Any,Any}(), x.debug_mode) -""" - build_rrule(args...; kwargs...) +function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any,N}) where {N} -Helper method: equivalent to extracting the signature from `args` and calling -`build_rrule(sig; kwargs...)`. -""" -function build_rrule(args...; kwargs...) - interp = get_interpreter(ReverseMode) - return build_rrule(interp, _typeof(TestUtils.__get_primals(args)); kwargs...) + # `Base._stable_typeof` is used here, rather than `typeof` or `Mooncake._typeof`. Its + # precise behaviour (equivalent to `typeof` for everything except `Type`s, for which it + # returns `Type{P}` rather than `typeof(P)`) is needed to ensure that this signature + # matches the types that `rule` sees when `rule(args...)` is called below. If you get + # this wrong, an assertion is violated, causing a hard-to-debug error (see issue 660). + sig = Tuple{map(Base._stable_typeof ∘ primal, args)...} + + rule = get(dynamic_rule.cache, sig, nothing) + if rule === nothing + interp = get_interpreter(ReverseMode) + rule = build_rrule(interp, sig; debug_mode=dynamic_rule.debug_mode) + dynamic_rule.cache[sig] = rule + end + return __call_rule(rule, args) end """ - build_rrule(sig::Type{<:Tuple}; kwargs...) + LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool) -Helper method: Equivalent to -`build_rrule(Mooncake.get_interpreter(ReverseMode), sig; kwargs...)`. -""" -function build_rrule(sig::Type{<:Tuple}; kwargs...) - return build_rrule(get_interpreter(ReverseMode), sig; kwargs...) -end +For internal use only. -const MOONCAKE_INFERENCE_LOCK = ReentrantLock() +A type-stable wrapper around a `DerivedRule`, which only instantiates the `DerivedRule` +when it is first called. This is useful, as it means that if a rule does not get run, it +does not have to be derived. -struct DerivedRuleInfo - primal_ir::IRCode +If `debug_mode` is `true`, then the rule constructed will be a `DebugRRule`. This is useful +when debugging, but should usually be switched off for production code as it (in general) +incurs some runtime overhead. + +Note: the signature of the primal for which this is a rule is stored in the type. The only +reason to keep this around is for debugging -- it is very helpful to have this type visible +in the stack trace when something goes wrong, as it allows you to trivially determine which +bit of your code is the culprit. + +# Extended Help + +There are two main reasons why deferring the construction of a `DerivedRule` until we need +to use it is crucial. + +The first is to do with recursion. Consider the following function: +```julia +f(x) = x > 0 ? f(x - 1) : x +``` +If we generate the `IRCode` for this function, we will see something like the following: +```julia +julia> Base.code_ircode_by_type(Tuple{typeof(f), Float64})[1][1] +1 1 ─ %1 = Base.lt_float(0.0, _2)::Bool + │ %2 = Base.or_int(%1, false)::Bool + └── goto #6 if not %2 + 2 ─ %4 = Base.sub_float(_2, 1.0)::Float64 + │ %5 = Base.lt_float(0.0, %4)::Bool + │ %6 = Base.or_int(%5, false)::Bool + └── goto #4 if not %6 + 3 ─ %8 = Base.sub_float(%4, 1.0)::Float64 + │ %9 = invoke Main.f(%8::Float64)::Float64 + └── goto #5 + 4 ─ goto #5 + 5 ┄ %12 = φ (#3 => %9, #4 => %4)::Float64 + └── return %12 + 6 ─ return _2 +``` +Suppose that we decide to construct a `DerivedRule` immediately whenever we find an +`:invoke` statement in a rule that we're currently building a `DerivedRule` for. +In the above example, we produce an infinite recursion when we attempt to produce a +`DerivedRule` for %9, because it has the same signature as the call which generates this IR. +By instead adopting a policy of constructing a `LazyDerivedRule` whenever we encounter an +`:invoke` statement, we avoid this problem. + +The second reason that delaying the construction of a `DerivedRule`, is essential is that it +ensures that we don't derive rules for method instances which aren't run. Suppose that +function B contains code for which we can't derive a rule -- perhaps it contains an +unsupported language feature like a `PhiCNode` or an `UpsilonNode`. Suppose that function A +contains an `:invoke` which refers to function `B`, but that this call is on a branch which +deals with error handling, and doesn't get run run unless something goes wrong. By deferring +the derivation of the rule for B, we only ever attempt to derive it if we land on this +error handling branch. Conversely, if we attempted to derive the rule for B when we derive +the rule for A, we would be unable to complete the derivation of the rule for A. +""" +mutable struct LazyDerivedRule{primal_sig,Trule} + debug_mode::Bool + mi::Core.MethodInstance + rule::Trule + function LazyDerivedRule(mi::Core.MethodInstance, debug_mode::Bool) + interp = get_interpreter(ReverseMode) + return new{mi.specTypes,rule_type(interp, mi;debug_mode)}(debug_mode, mi) + end + function LazyDerivedRule{Tprimal_sig,Trule}( + mi::Core.MethodInstance, debug_mode::Bool + ) where {Tprimal_sig,Trule} + return new{Tprimal_sig,Trule}(debug_mode, mi) + end +end + +# Create new lazy rule with same method instance and debug mode +_copy(x::P) where {P<:LazyDerivedRule} = P(x.mi, x.debug_mode) + +# On Julia 1.10, the generic __call_rule fallback is @stable-checked and returns Any for +# LazyDerivedRule, triggering TypeInstabilityError when dispatch_doctor_mode = "error". +# Add type-asserting specialisations so callers in @stable contexts see a concrete type. +# LazyDerivedRule doesn't contain an OpaqueClosure directly, so no inferencebarrier needed. +@static if VERSION < v"1.11-" + @inline function __call_rule( + rule::LazyDerivedRule{sig,DerivedRule{Tp,FA,FR,RA,RR,isva,Val{pnargs}}}, args::A + ) where {sig,Tp,FA,FR,RA,RR,isva,pnargs,A<:Tuple} + return rule(args...)::Tuple{FR,Pullback{Tp,RA,RR,isva,fieldcount(A)}} + end + @inline function __call_rule( + rule::LazyDerivedRule{ + sig,DebugRRule{DerivedRule{Tp,FA,CoDual{P,FD},RA,RR,isva,Val{pnargs}}} + }, + args::A, + ) where {sig,Tp,FA,P,FD,RA,RR,isva,pnargs,A<:Tuple} + return rule( + args... + )::Tuple{CoDual{P,FD},DebugPullback{Pullback{Tp,RA,RR,isva,fieldcount(A)},P}} + end +end + +@inline function (rule::LazyDerivedRule)(args::Vararg{Any,N}) where {N} + return isdefined(rule, :rule) ? __call_rule(rule.rule, args) : _build_rule!(rule, args) +end + +@noinline function _build_rule!(rule::LazyDerivedRule{sig,Trule}, args) where {sig,Trule} + interp = get_interpreter(ReverseMode) + rule.rule = build_rrule(interp, rule.mi; debug_mode=rule.debug_mode) + return __call_rule(rule.rule, args) +end + +# +# Rule derivation entry points and compile-time helpers +# + +_get_sig(sig::Type) = sig +_get_sig(mi::Core.MethodInstance) = mi.specTypes +_get_sig(mc::MistyClosure) = Tuple{map(CC.widenconst, mc.ir[].argtypes)...} + +""" +Flatten the signature of a vararg method to group the +possibly multiple vararg arguments (what users pass to the function) +into a single tuple argument matching `ir.argtypes`. +""" +function flatten_va_sig(sig, isva, nargs) + @nospecialize sig + return if isva + Tuple{sig.parameters[1:(nargs - 1)]...,Tuple{sig.parameters[nargs:end]...}} + else + sig + end +end + +function forwards_ret_type(primal_ir::IRCode) + return fcodual_type(compute_ir_rettype(primal_ir)) +end + +function pullback_ret_type(primal_ir::IRCode) + return Tuple{map(rdata_type ∘ tangent_type ∘ CC.widenconst, primal_ir.argtypes)...} +end + +struct MooncakeRuleCompilationError <: Exception + interp::MooncakeInterpreter + sig + debug_mode::Bool + cause::Exception +end + +function Base.showerror(io::IO, err::MooncakeRuleCompilationError) + msg_lines = ( + "MooncakeRuleCompilationError: an error occurred while Mooncake was compiling a", + "rule to differentiate something. If the `caused by` error message below does", + "not make it clear to you how the problem can be fixed, please open an issue", + "at github.com/chalk-lab/Mooncake.jl describing your problem.", + ) + cause_width = min(_boxed_message_width(io, "│ "), 78) + cause_lines = let lines = if hasfield(typeof(err.cause), :msg) + msg = getfield(err.cause, :msg) + if msg isa AbstractString + split(msg, '\n') + else + split(sprint(showerror, err.cause), '\n') + end + else + split(sprint(showerror, err.cause), '\n') + end + while !isempty(lines) && isempty(last(lines)) + pop!(lines) + end + wrapped_lines = String[] + for line in lines + append!(wrapped_lines, _wrap_boxed_line(line, cause_width)) + end + wrapped_lines + end + detail_lines = ("Caused by:", cause_lines..., "", msg_lines...) + + # Print the source location of the method being differentiated, if available. + try + m = lookup_method(err.sig) + if m !== nothing + mstr = sprint(show, m) + header, location = let parts = split(mstr, " @ "; limit=2) + length(parts) == 2 ? (parts[1], parts[2]) : (mstr, nothing) + end + _print_boxed_error( + io, + ( + "Mooncake failed to differentiate the following method:", + header, + "", + detail_lines..., + ); + footer=isnothing(location) ? nothing : "@ $location", + ) + println(io) # blank line before the main error body + else + _print_boxed_error(io, detail_lines) + println(io) + end + catch e + # If method lookup fails for any reason, skip gracefully. + @debug "MooncakeRuleCompilationError: method lookup failed" exception = e + _print_boxed_error(io, detail_lines) + println(io) + end + println(io, "To replicate this error run the following:\n") + println( + io, + "Mooncake.build_rrule(Mooncake.$(err.interp), $(err.sig); debug_mode=$(err.debug_mode))", + ) + return println( + io, + "\nNote that you may need to `using` some additional packages if not all of the " * + "names printed in the above signature are available currently in your environment.", + ) +end + +""" + build_rrule(args...; kwargs...) + +Helper method: equivalent to extracting the signature from `args` and calling +`build_rrule(sig; kwargs...)`. +""" +function build_rrule(args...; kwargs...) + interp = get_interpreter(ReverseMode) + return build_rrule(interp, _typeof(TestUtils.__get_primals(args)); kwargs...) +end + +""" + build_rrule(sig::Type{<:Tuple}; kwargs...) + +Helper method: Equivalent to +`build_rrule(Mooncake.get_interpreter(ReverseMode), sig; kwargs...)`. +""" +function build_rrule(sig::Type{<:Tuple}; kwargs...) + return build_rrule(get_interpreter(ReverseMode), sig; kwargs...) +end + +const MOONCAKE_INFERENCE_LOCK = ReentrantLock() + +struct DerivedRuleInfo + primal_ir::IRCode fwd_ir::IRCode fwd_ret_type::Type rvs_ir::IRCode @@ -1209,7 +1498,7 @@ function build_rrule_checks( throw( ArgumentError( "World age associated to interp is behind current world age. Please " * - "a new interpreter for the current world age.", + "create a new interpreter for the current world age.", ), ) end @@ -1225,7 +1514,7 @@ function build_derived_rrule( ) where {C} @nospecialize sig_or_mi sig - # We don't have a hand-coded rule, so derived one. + # No hand-coded rule exists, so derive one from compiler IR. lock(MOONCAKE_INFERENCE_LOCK) try # If we've already derived the OpaqueClosures and info, do not re-derive, just @@ -1234,7 +1523,7 @@ function build_derived_rrule( if haskey(interp.oc_cache, oc_cache_key) return _copy(interp.oc_cache[oc_cache_key]) else - # Derive forwards- and reverse-pass IR, and shove in `MistyClosure`s. + # Derive the forward and reverse IR, then package them into `MistyClosure`s. dri = try generate_ir(interp, sig_or_mi; debug_mode) catch err @@ -1276,7 +1565,7 @@ end """ generate_ir( interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true - ) +) Used by `build_rrule`, and the various debugging tools: primal_ir, fwds_ir, adjoint_ir. """ function generate_ir( @@ -1290,7 +1579,7 @@ function generate_ir( # function runs. seed_id!() - # Grab code associated to the primal. + # Look up the inferred primal IR. ir, _ = lookup_ir(interp, sig_or_mi) @static if VERSION > v"1.12-" ir = set_valid_world!(ir, interp.world) @@ -1317,34 +1606,51 @@ function generate_ir( end end - # Normalise the IR, and generated BBCode version of it. + # Reverse mode now starts from normalized IRCode and uses the local CFG builder directly. isva, spnames = is_vararg_and_sparam_names(sig_or_mi) ir = normalise!(ir, spnames) - primal_ir = remove_unreachable_blocks!(BBCode(ir)) + primal_blocks = _remove_unreachable_cfg_blocks!(_ircode_to_cfg_blocks(ir)) # Compute global info. - info = ADInfo(interp, primal_ir, debug_mode, fwd_ret_type, rvs_ret_type) + arg_types = Dict{Argument,Any}( + map(((n, t),) -> (Argument(n) => CC.widenconst(t)), enumerate(ir.argtypes)) + ) + primal_stmts = IDInstPair[inst for block in primal_blocks for inst in block.insts] + ssa_insts = Dict{ID,NewInstruction}(primal_stmts) + is_used_dict = characterise_used_ids(primal_stmts) + Tlazy_rdata_ref = Tuple{map(lazy_zero_rdata_type ∘ CC.widenconst, ir.argtypes)...} + zero_lazy_rdata_ref = Ref{Tlazy_rdata_ref}() + info = ADInfo( + interp, + arg_types, + ssa_insts, + is_used_dict, + debug_mode, + zero_lazy_rdata_ref, + fwd_ret_type, + rvs_ret_type, + ) - # For each block in the fwds and pullback BBCode, translate all statements. Running this - # will, in general, push items to `info.shared_data_pairs`. - ad_stmts_blocks = map(primal_ir.blocks) do primal_blk - ids = primal_blk.inst_ids - primal_stmts = map(x -> x.stmt, primal_blk.insts) - return (primal_blk.id, make_ad_stmts!.(primal_stmts, ids, Ref(info))) + # For each block in the primal CFG, translate all statements. Running this will, in + # general, push items to `info.shared_data_pairs`. + ad_stmts_blocks = map(primal_blocks) do primal_blk + ids = first.(primal_blk.insts) + stmts = map(x -> x[2].stmt, primal_blk.insts) + return (primal_blk.id, make_ad_stmts!.(stmts, ids, Ref(info))) end - # Make shared data, and construct BBCode for forwards-pass and pullback. - fwds_comms_insts, pb_comms_insts = create_comms_insts!(ad_stmts_blocks, info) + # Make shared data, and construct IR for the forwards-pass and pullback. + block_comms = create_comms_insts!(ad_stmts_blocks, info) shared_data = shared_data_tuple(info.shared_data_pairs) fwd_ir = forwards_pass_ir( - primal_ir, ad_stmts_blocks, fwds_comms_insts, info, _typeof(shared_data) + ir, primal_blocks, ad_stmts_blocks, block_comms, info, _typeof(shared_data) ) rvs_ir = pullback_ir( - primal_ir, Treturn, ad_stmts_blocks, pb_comms_insts, info, _typeof(shared_data) + ir, primal_blocks, Treturn, ad_stmts_blocks, block_comms, info, _typeof(shared_data) ) - opt_fwd_ir = do_optimize ? optimise_ir!(IRCode(fwd_ir); do_inline) : IRCode(fwd_ir) - opt_rvs_ir = do_optimize ? optimise_ir!(IRCode(rvs_ir); do_inline) : IRCode(rvs_ir) + opt_fwd_ir = do_optimize ? optimise_ir!(fwd_ir; do_inline) : fwd_ir + opt_rvs_ir = do_optimize ? optimise_ir!(rvs_ir; do_inline) : rvs_ir return DerivedRuleInfo( ir, opt_fwd_ir, fwd_ret_type, opt_rvs_ir, rvs_ret_type, shared_data, info, isva ) @@ -1380,6 +1686,10 @@ end const ADStmts = Vector{Tuple{ID,Vector{ADStmtInfo}}} +# +# Forward-pass communication and CFG assembly +# + """ create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo) @@ -1399,12 +1709,11 @@ For each basic block represented in `ADStmts`: shared data by the instructions generated by the previous point, and assigned them to the `comms_id`s. -Returns two a `Tuple{Vector{IDInstPair}, Vector{IDInstPair}`. The nth element of each -`Vector` corresponds to the instructions to be inserted into the forwards- and reverse -passes resp. for the nth block in `ad_stmts_blocks`. +Returns a `Vector{BlockCommsInsts}`. The nth element contains the forward-pass suffix and +reverse-pass prefix associated to the nth block in `ad_stmts_blocks`. """ function create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo) - insts = map(ad_stmts_blocks) do (_, ad_stmts) + return map(ad_stmts_blocks) do (_, ad_stmts) # Get the communication channel for each stmt which has one. comms_channels = filter(!=(nothing), map(comms_channel, ad_stmts)) @@ -1427,130 +1736,542 @@ function create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo) (ID(), new_inst(Expr(:call, push!, comms_stack_id, tuple_id))), ] - # Create instructions for reverse-pass to pop comms stack and extract elements of - # tuple into comms ids. - rvs_insts = IDInstPair[ - (tuple_id, new_inst(Expr(:call, pop!, comms_stack_id))), - map(enumerate(comms_ids)) do (n, id) - (id, new_inst(Expr(:call, getfield, tuple_id, n))) - end..., - ] + # Create instructions for reverse-pass to pop comms stack and extract elements of + # tuple into comms ids. + rvs_insts = IDInstPair[ + (tuple_id, new_inst(Expr(:call, pop!, comms_stack_id))), + map(enumerate(comms_ids)) do (n, id) + (id, new_inst(Expr(:call, getfield, tuple_id, n))) + end..., + ] + + return BlockCommsInsts(fwds_insts, rvs_insts) + end +end + +""" + forwards_pass_ir( + ir::IRCode, + primal_blocks, + ad_stmts_blocks::ADStmts, + block_comms, + info::ADInfo, + Tshared_data, + ) + +Produce the IR associated to the `OpaqueClosure` which runs most of the forwards-pass. +""" +function forwards_pass_ir( + ir::IRCode, + primal_blocks, + ad_stmts_blocks::ADStmts, + block_comms, + info::ADInfo, + Tshared_data, +) + is_unique_pred, pred_is_unique_pred = _characterise_unique_predecessor_blocks( + primal_blocks + ) + + # Insert a block at the start which extracts all items from the captures field of the + # `OpaqueClosure`, which contains all of the data shared between the forwards- and + # reverse-passes. These are assigned to the `ID`s given by the `SharedDataPairs`. + # Push the entry id onto the block stack if needed. Create `LazyZeroRData` for each + # argument, and put it in the `Ref` for use on the reverse-pass. + sds = shared_data_stmts(info.shared_data_pairs) + if pred_is_unique_pred[primal_blocks[1].id] + push_block_stack_insts = IDInstPair[] + else + push_block_stack_stmt = Expr( + :call, __push_blk_stack!, info.block_stack_id, info.entry_id.id + ) + push_block_stack_insts = [(ID(), new_inst(push_block_stack_stmt))] + end + lazy_zero_rdata_stmt = Expr( + :call, + __assemble_lazy_zero_rdata, + info.lazy_zero_rdata_ref_id, + map(n -> Argument(n + 1), 1:num_args(info))..., + ) + lazy_zero_rdata_insts = [(ID(), new_inst(lazy_zero_rdata_stmt))] + entry_stmts = vcat(sds, lazy_zero_rdata_insts, push_block_stack_insts) + entry_block = CFGBlock(info.entry_id, entry_stmts) + + # Construct augmented version of each basic block from the primal. For each block: + # 1. pull the translated basic block statements from ad_stmts_blocks, + # 2. insert a series of statements to log the contents of the `comms_id`s -- see + # the `comms_id` field of `ADStmtInfo`, + # 3. insert a statement which logs the ID of the current block (if necessary to know + # how to perform the reverse-pass), + # 4. return the CFG block. + blocks = map(ad_stmts_blocks, block_comms) do (block_id, ad_stmts), comms + + # Extract the `fwds` fields from the stmts, and create the block for the fwds pass. + insts = reduce(vcat, map(x -> x.fwds, ad_stmts)) + + # Insert communication instructions. See `create_comms_insts!` for an explanation. + for stack_inst in comms.fwds_suffix + _insert_before_terminator!(insts, stack_inst) + end + + # Log the ID of the current basic block. This is needed to know which basic block to + # jump to during the reverse-pass if the current block is not the unique predecessor + # of each of its successors (in which case there is no need to log that control + # passed through this block as opposed to any other). + if !is_unique_pred[block_id] + ins_stmt = Expr(:call, __push_blk_stack!, info.block_stack_id, block_id.id) + _insert_before_terminator!(insts, (ID(), new_inst(ins_stmt))) + end + + return CFGBlock(block_id, insts) + end + + # Lower the forwards-pass CFG directly to `IRCode`. + arg_types = vcat(Tshared_data, map(fcodual_type ∘ CC.widenconst, ir.argtypes)) + return lower_cfg_blocks_to_ir( + ir, arg_types, vcat([entry_block], blocks); sort_cfg=false + ) +end + +""" + __push_blk_stack!(block_stack::BlockStack, id::Int32) + +Equivalent to `push!(block_stack, id)`. Going via this function, rather than just calling +push! directly, is helpful for debugging and performance analysis -- it makes it very +straightforward to figure out much time is spent pushing to the block stack when profiling. +""" +@inline __push_blk_stack!(block_stack::BlockStack, id::Int32) = push!(block_stack, id) + +__lazy_zero_rdata_primal(T, x) = lazy_zero_rdata(T, primal(x)) + +@inline @generated function __assemble_lazy_zero_rdata( + r::Ref{T}, args::Vararg{CoDual,N} +) where {T<:Tuple,N} + return :(r[] = tuple_map(__lazy_zero_rdata_primal, $(fieldtypes(T)), args)) +end + +# +# CFGBlock working IR +# +# Reverse mode assembles new control flow in this local representation first, then lowers the +# finished CFG back to compiler IR in one step. +# + +""" + CFGBlock(id::ID, insts::Vector{IDInstPair}) + +Reverse-mode-local basic block representation used while assembling reverse-mode CFGs before +lowering to compiler IR. +""" +struct CFGBlock + id::ID + insts::Vector{IDInstPair} +end + +function _remap_assigned_phi_values(f, values::Vector{Any})::Vector{Any} + # Keep dead-edge phi slots undefined while remapping the assigned incoming values. + new_values = Vector{Any}(undef, length(values)) + for n in eachindex(values) + isassigned(values, n) && (new_values[n] = f(values[n])) + end + return new_values +end + +# +# `IRCode` -> `CFGBlock` conversion +# + +function _ssa_to_ids(d::SSAToIdDict, inst::NewInstruction) + return NewInstruction(inst; stmt=_ssa_to_ids(d, inst.stmt)) +end +function _ssa_to_ids(d::SSAToIdDict, x::ReturnNode) + return isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x +end +_ssa_to_ids(d::SSAToIdDict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) +_ssa_to_ids(d::SSAToIdDict, x::PiNode) = PiNode(get(d, x.val, x.val), get(d, x.typ, x.typ)) +_ssa_to_ids(::SSAToIdDict, x) = x +function _ssa_to_ids(d::SSAToIdDict, x::PhiNode) + return PhiNode(x.edges, _remap_assigned_phi_values(v -> get(d, v, v), x.values)) +end +_ssa_to_ids(d::SSAToIdDict, x::GotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), x.dest) + +function _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID},InstVector} + ids = map(_ -> ID(), insts) + val_id_map = SSAToIdDict(zip(SSAValue.(eachindex(insts)), ids)) + return ids, map(Base.Fix1(_ssa_to_ids, val_id_map), insts) +end + +function _block_num_to_ids(d::BlockNumToIdDict, x::NewInstruction) + return NewInstruction(x; stmt=_block_num_to_ids(d, x.stmt)) +end +function _block_num_to_ids(d::BlockNumToIdDict, x::PhiNode) + return IDPhiNode(ID[d[e] for e in x.edges], x.values) +end +_block_num_to_ids(d::BlockNumToIdDict, x::GotoNode) = IDGotoNode(d[x.label]) +_block_num_to_ids(d::BlockNumToIdDict, x::GotoIfNot) = IDGotoIfNot(x.cond, d[x.dest]) +_block_num_to_ids(::BlockNumToIdDict, x) = x + +function _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID},InstVector} + ids = map(_ -> ID(), cfg.blocks) + block_num_id_map = BlockNumToIdDict(zip(eachindex(cfg.blocks), ids)) + return ids, map(Base.Fix1(_block_num_to_ids, block_num_id_map), insts) +end + +function _ircode_to_cfg_blocks(ir::IRCode)::Vector{CFGBlock} + # Reuse the shared cross-version stmt accessor rather than branching on field names here. + stmts = map( + (stmt, type, info, line, flag) -> NewInstruction(stmt, type, info, line, flag), + stmt(ir.stmts), + ir.stmts.type, + ir.stmts.info, + ir.stmts.line, + ir.stmts.flag, + ) + ssa_ids, stmts = _ssas_to_ids(stmts) + block_ids, stmts = _block_nums_to_ids(stmts, ir.cfg) + return map(zip(block_ids, ir.cfg.blocks)) do (block_id, bb) + CFGBlock(block_id, collect(zip(ssa_ids[bb.stmts], stmts[bb.stmts]))) + end +end + +_cfg_terminator(stmt) = stmt isa Union{Switch,IDGotoIfNot,IDGotoNode,ReturnNode} +function _cfg_terminator(block::CFGBlock) + isempty(block.insts) && return nothing + stmt = last(block.insts)[2].stmt + return _cfg_terminator(stmt) ? stmt : nothing +end + +function _cfg_phi_nodes(block::CFGBlock) + # Phi nodes are only valid at the start of a block, so stop at the first non-phi. + n_phi_nodes = findfirst(x -> !(x[2].stmt isa IDPhiNode), block.insts) + n_phi_nodes = isnothing(n_phi_nodes) ? length(block.insts) : n_phi_nodes - 1 + return first.(block.insts[1:n_phi_nodes]), last.(block.insts[1:n_phi_nodes]) +end + +# +# CFG analysis and canonicalization helpers +# + +function _compute_cfg_successors(blocks::Vector{CFGBlock})::Dict{ID,Vector{ID}} + succs = Dict{ID,Vector{ID}}() + for (n, block) in enumerate(blocks) + is_final_block = n == length(blocks) + t = _cfg_terminator(block) + if t === nothing + succs[block.id] = is_final_block ? ID[] : ID[blocks[n + 1].id] + elseif t isa IDGotoNode + succs[block.id] = ID[t.label] + elseif t isa IDGotoIfNot + succs[block.id] = is_final_block ? ID[t.dest] : ID[t.dest, blocks[n + 1].id] + elseif t isa ReturnNode + succs[block.id] = ID[] + elseif t isa Switch + succs[block.id] = vcat(t.dests, t.fallthrough_dest) + else + error("Unhandled terminator $t") + end + end + return succs +end + +function _compute_cfg_predecessors(blocks::Vector{CFGBlock})::Dict{ID,Vector{ID}} + successor_map = _compute_cfg_successors(blocks) + predecessor_map = Dict{ID,Vector{ID}}(block.id => ID[] for block in blocks) + for (k, succs) in successor_map + for succ in succs + push!(predecessor_map[succ], k) + end + end + return predecessor_map +end + +function _cfg_distance_to_entry(blocks::Vector{CFGBlock})::Vector{Int} + id_to_int = Dict(zip(map(block -> block.id, blocks), eachindex(blocks))) + successors = _compute_cfg_successors(blocks) + distances = fill(typemax(Int), length(blocks)) + distances[1] = 0 + queue = [blocks[1].id] + head = 1 + while head <= length(queue) + block_id = queue[head] + head += 1 + dist = distances[id_to_int[block_id]] + for successor in successors[block_id] + successor_idx = id_to_int[successor] + if distances[successor_idx] == typemax(Int) + distances[successor_idx] = dist + 1 + push!(queue, successor) + end + end + end + return distances +end + +function _sort_cfg_blocks!(blocks::Vector{CFGBlock})::Vector{CFGBlock} + I = sortperm(_cfg_distance_to_entry(blocks)) + blocks .= blocks[I] + return blocks +end + +function _remove_unreachable_cfg_blocks!(blocks::Vector{CFGBlock})::Vector{CFGBlock} + is_reachable = _cfg_distance_to_entry(blocks) .< typemax(Int) + remaining_blocks = blocks[is_reachable] + removed_block_ids = map(idx -> blocks[idx].id, findall(!, is_reachable)) + for block in remaining_blocks, (_, inst) in block.insts + stmt = inst.stmt + stmt isa IDPhiNode || continue + for n in reverse(1:length(stmt.edges)) + if stmt.edges[n] in removed_block_ids + deleteat!(stmt.edges, n) + deleteat!(stmt.values, n) + end + end + end + return remaining_blocks +end + +function _characterise_unique_predecessor_blocks( + blocks::Vector{CFGBlock} +)::Tuple{Dict{ID,Bool},Dict{ID,Bool}} + block_ids = ID[block.id for block in blocks] + preds = _compute_cfg_predecessors(blocks) + succs = _compute_cfg_successors(blocks) + + is_unique_pred = Dict{ID,Bool}() + for id in block_ids + ss = succs[id] + is_unique_pred[id] = !isempty(ss) && all(s -> length(preds[s]) == 1, ss) + end + + reachable_return_blocks = filter(blocks) do block + is_reachable_return_node(_cfg_terminator(block)) + end + if length(reachable_return_blocks) == 1 + is_unique_pred[only(reachable_return_blocks).id] = true + end - return fwds_insts, rvs_insts + pred_is_unique_pred = Dict{ID,Bool}() + for id in block_ids + pred_is_unique_pred[id] = length(preds[id]) == 1 && is_unique_pred[only(preds[id])] end - return map(first, insts), map(last, insts) + + entry_id = block_ids[1] + pred_is_unique_pred[entry_id] = isempty(preds[entry_id]) + return is_unique_pred, pred_is_unique_pred end -""" - forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) +function _insert_before_terminator!(insts::Vector{IDInstPair}, inst::IDInstPair) + if !isempty(insts) && _cfg_terminator(last(insts)[2].stmt) + insert!(insts, length(insts), inst) + else + push!(insts, inst) + end + return insts +end + +function _canonicalise_cfg_blocks(blocks::Vector{CFGBlock}; sort_cfg::Bool=true) + blocks = copy(blocks) + # Canonicalization is "sort first, then prune" so phi-edge cleanup sees final block order. + sort_cfg && _sort_cfg_blocks!(blocks) + return _remove_unreachable_cfg_blocks!(blocks) +end + +function _cfg_lower_switch_statements(blocks::Vector{CFGBlock})::Vector{CFGBlock} + new_blocks = CFGBlock[] + for block in blocks + t = _cfg_terminator(block) + if t isa Switch + push!(new_blocks, CFGBlock(block.id, block.insts[1:(end - 1)])) + foreach(t.conds, t.dests) do cond, dest + push!( + new_blocks, + CFGBlock(ID(), [(ID(), new_inst(IDGotoIfNot(cond, dest), Any))]), + ) + end + push!( + new_blocks, + CFGBlock(ID(), [(ID(), new_inst(IDGotoNode(t.fallthrough_dest), Any))]), + ) + else + push!(new_blocks, block) + end + end + return new_blocks +end -Produce the IR associated to the `OpaqueClosure` which runs most of the forwards-pass. -""" -function forwards_pass_ir( - ir::BBCode, ad_stmts_blocks::ADStmts, fwds_comms_insts, info::ADInfo, Tshared_data -) - is_unique_pred, pred_is_unique_pred = characterise_unique_predecessor_blocks(ir.blocks) +function _cfg_remove_double_edges(blocks::Vector{CFGBlock})::Vector{CFGBlock} + return map(enumerate(blocks)) do (n, block) + t = _cfg_terminator(block) + if n < length(blocks) && t isa IDGotoIfNot && t.dest == blocks[n + 1].id + term_id, term_inst = last(block.insts) + new_insts = copy(block.insts) + new_insts[end] = (term_id, NewInstruction(term_inst; stmt=IDGotoNode(t.dest))) + return CFGBlock(block.id, new_insts) + else + return block + end + end +end - # Insert a block at the start which extracts all items from the captures field of the - # `OpaqueClosure`, which contains all of the data shared between the forwards- and - # reverse-passes. These are assigned to the `ID`s given by the `SharedDataPairs`. - # Push the entry id onto the block stack if needed. Create `LazyZeroRData` for each - # argument, and put it in the `Ref` for use on the reverse-pass. - sds = shared_data_stmts(info.shared_data_pairs) - if pred_is_unique_pred[ir.blocks[1].id] - push_block_stack_insts = IDInstPair[] - else - push_block_stack_stmt = Expr( - :call, __push_blk_stack!, info.block_stack_id, info.entry_id.id - ) - push_block_stack_insts = [(ID(), new_inst(push_block_stack_stmt))] +function _cfg_control_flow_graph(blocks::Vector{CFGBlock})::CC.CFG + preds_ids = _compute_cfg_predecessors(blocks) + succs_ids = _compute_cfg_successors(blocks) + block_ids = map(block -> block.id, blocks) + id_to_num = Dict{ID,Int}(zip(block_ids, eachindex(block_ids))) + preds = map(id -> sort(map(p -> id_to_num[p], preds_ids[id])), block_ids) + succs = map(id -> sort(map(s -> id_to_num[s], succs_ids[id])), block_ids) + @static if VERSION >= v"1.11" + push!(preds[1], 0) end - lazy_zero_rdata_stmt = Expr( - :call, - __assemble_lazy_zero_rdata, - info.lazy_zero_rdata_ref_id, - map(n -> Argument(n + 1), 1:num_args(info))..., + index = vcat(0, cumsum(map(block -> length(block.insts), blocks))) .+ 1 + basic_blocks = map(eachindex(blocks)) do n + stmt_range = CC.StmtRange(index[n], index[n + 1] - 1) + return CC.BasicBlock(stmt_range, preds[n], succs[n]) + end + return CC.CFG(basic_blocks, index[2:(end - 1)]) +end + +function _cfg_to_ssas(d::Dict, inst::NewInstruction) + return NewInstruction(inst; stmt=_cfg_to_ssas(d, inst.stmt)) +end +function _cfg_to_ssas(d::Dict, x::ReturnNode) + isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x +end +_cfg_to_ssas(d::Dict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) +_cfg_to_ssas(d::Dict, x::PiNode) = PiNode(get(d, x.val, x.val), get(d, x.typ, x.typ)) +_cfg_to_ssas(d::Dict, x) = x +function _cfg_to_ssas(d::Dict, x::IDPhiNode) + return PhiNode( + map(edge -> Int32(getindex(d, edge).id), x.edges), + _remap_assigned_phi_values(v -> get(d, v, v), x.values), ) - lazy_zero_rdata_insts = [(ID(), new_inst(lazy_zero_rdata_stmt))] - entry_stmts = vcat(sds, lazy_zero_rdata_insts, push_block_stack_insts) - entry_block = BBlock(info.entry_id, entry_stmts) +end +_cfg_to_ssas(d::Dict, x::IDGotoNode) = GotoNode(d[x.label].id) +_cfg_to_ssas(d::Dict, x::IDGotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), d[x.dest].id) - # Construct augmented version of each basic block from the primal. For each block: - # 1. pull the translated basic block statements from ad_stmts_blocks, - # 2. insert a series of statements to log the contents of the `comms_id`s -- see - # the `comms_id` field of `ADStmtInfo`, - # 3. insert a statement which logs the ID of the current block (if necessary to know - # how to perform the reverse-pass), - # 4. return the BBlock. - blocks = map(ad_stmts_blocks, fwds_comms_insts) do (block_id, ad_stmts), comms_insts +function _cfg_ids_to_line_numbers(blocks::Vector{CFGBlock})::InstVector + block_ids = map(block -> block.id, blocks) + block_lengths = map(block -> length(block.insts), blocks) + block_start_ssas = SSAValue.(vcat(1, cumsum(block_lengths)[1:(end - 1)] .+ 1)) + lines = [inst for block in blocks for inst in block.insts] + line_ids = first.(lines) + line_ssas = SSAValue.(eachindex(line_ids)) + id_to_ssa_map = Dict(zip(vcat(block_ids, line_ids), vcat(block_start_ssas, line_ssas))) + return [_cfg_to_ssas(id_to_ssa_map, inst) for (_, inst) in lines] +end - # Extract the `fwds` fields from the stmts, and create the block for the fwds pass. - blk = BBlock(block_id, reduce(vcat, map(x -> x.fwds, ad_stmts))) +function _cfg_lines_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector + stmts = __line_numbers_to_block_numbers!(Any[x.stmt for x in insts], cfg) + return map((inst, stmt) -> NewInstruction(inst; stmt), insts, stmts) +end - # Insert communcation instructions. See `create_comms_insts!` for an explanation. - for stack_inst in comms_insts - insert_before_terminator!(blk, stack_inst[1], stack_inst[2]) - end +# +# CFG line/block numbering and compiler-IR reconstruction +# - # Log the ID of the current basic block. This is needed to know which basic block to - # jump to during the reverse-pass if the current block is not the unique predecessor - # of each of its successors (in which case there is no need to log that control - # passed through this block as opposed to any other). - if !is_unique_pred[block_id] - ins_stmt = Expr(:call, __push_blk_stack!, info.block_stack_id, block_id.id) - insert_before_terminator!(blk, ID(), new_inst(ins_stmt)) +function _cfg_instruction_stream(ir::IRCode, insts::InstVector) + @static if VERSION > v"1.12-" + lines = CC.copy(ir.debuginfo.codelocs) + n = length(insts) + if length(lines) > 3n + resize!(lines, 3n) + elseif length(lines) < 3n + for _ in (length(lines) + 1):3n + push!(lines, 0) + end end - - return blk + return CC.InstructionStream( + Any[x.stmt for x in insts], + Any[x.type for x in insts], + CC.CallInfo[x.info for x in insts], + lines, + UInt32[x.flag for x in insts], + ) + else + return CC.InstructionStream( + Any[x.stmt for x in insts], + Any[x.type for x in insts], + CC.CallInfo[x.info for x in insts], + Int32[x.line for x in insts], + UInt32[x.flag for x in insts], + ) end +end - # Create and return the `BBCode` for the forwards-pass. - arg_types = vcat(Tshared_data, map(fcodual_type ∘ CC.widenconst, ir.argtypes)) - new_ir = BBCode(ir, vcat(entry_block, blocks)) +function _rebuild_ircode(ir::IRCode, arg_types, cfg::CC.CFG, insts::InstVector)::IRCode + inst_stream = _cfg_instruction_stream(ir, insts) @static if VERSION > v"1.12-" - new_ir = BBCode( - new_ir.blocks, - arg_types, - new_ir.sptypes, - new_ir.debuginfo, - new_ir.meta, - new_ir.valid_worlds, + return IRCode( + inst_stream, + cfg, + CC.copy(ir.debuginfo), + Any[arg_types...], + CC.copy(ir.meta), + CC.copy(ir.sptypes), + ir.valid_worlds, ) else - new_ir = BBCode( - new_ir.blocks, arg_types, new_ir.sptypes, new_ir.linetable, new_ir.meta + return IRCode( + inst_stream, + cfg, + CC.copy(ir.linetable), + Any[arg_types...], + CC.copy(ir.meta), + CC.copy(ir.sptypes), ) end - return remove_unreachable_blocks!(new_ir) end -""" - __push_blk_stack!(block_stack::BlockStack, id::Int32) +# +# `CFGBlock` -> `IRCode` lowering +# -Equivalent to `push!(block_stack, id)`. Going via this function, rather than just calling -push! directly, is helpful for debugging and performance analysis -- it makes it very -straightforward to figure out much time is spent pushing to the block stack when profiling. """ -@inline __push_blk_stack!(block_stack::BlockStack, id::Int32) = push!(block_stack, id) + lower_cfg_blocks_to_ir(ir::IRCode, arg_types, blocks::Vector{CFGBlock}; sort_cfg=true) -__lazy_zero_rdata_primal(T, x) = lazy_zero_rdata(T, primal(x)) - -@inline @generated function __assemble_lazy_zero_rdata( - r::Ref{T}, args::Vararg{CoDual,N} -) where {T<:Tuple,N} - return :(r[] = tuple_map(__lazy_zero_rdata_primal, $(fieldtypes(T)), args)) +Lower reverse-mode-local CFG blocks directly to `IRCode`. +""" +function lower_cfg_blocks_to_ir( + ir::IRCode, arg_types, blocks::Vector{CFGBlock}; sort_cfg::Bool=true +)::IRCode + blocks = _canonicalise_cfg_blocks(blocks; sort_cfg) + blocks = _cfg_remove_double_edges(_cfg_lower_switch_statements(blocks)) + insts = _cfg_ids_to_line_numbers(blocks) + cfg = _cfg_control_flow_graph(blocks) + insts = _cfg_lines_to_blocks(insts, cfg) + return _rebuild_ircode(ir, arg_types, cfg, insts) end +# +# Pullback CFG assembly +# + """ - pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) + pullback_ir( + ir::IRCode, + primal_blocks, + Tret, + ad_stmts_blocks::ADStmts, + block_comms, + info::ADInfo, + Tshared_data, + ) Produce the IR associated to the `OpaqueClosure` which runs most of the pullback. """ function pullback_ir( - ir::BBCode, Tret, ad_stmts_blocks::ADStmts, pb_comms_insts, info::ADInfo, Tshared_data + ir::IRCode, + primal_blocks, + Tret, + ad_stmts_blocks::ADStmts, + block_comms, + info::ADInfo, + Tshared_data, ) # Compute the blocks which return in the primal. - primal_exit_blocks_inds = findall(is_reachable_return_node ∘ terminator, ir.blocks) + primal_exit_blocks_inds = findall( + is_reachable_return_node ∘ _cfg_terminator, primal_blocks + ) # # Short-circuit for non-terminating primals -- applies to a tiny fraction of primals: @@ -1560,14 +2281,8 @@ function pullback_ir( # terminates without throwing, meaning that if AD hits this function, it definitely # won't succeed on the forwards-pass. As such, the reverse-pass can just be a no-op. if isempty(primal_exit_blocks_inds) - blocks = [BBlock(ID(), [(ID(), new_inst(ReturnNode(nothing)))])] - @static if VERSION >= v"1.12-" - return BBCode( - blocks, Any[Any], ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds - ) - else - return BBCode(blocks, Any[Any], ir.sptypes, ir.linetable, ir.meta) - end + blocks = [CFGBlock(ID(), [(ID(), new_inst(ReturnNode(nothing)))])] + return lower_cfg_blocks_to_ir(ir, Any[Any], blocks) end # @@ -1585,9 +2300,9 @@ function pullback_ir( # no need to pop the block stack. data_stmts = shared_data_stmts(info.shared_data_pairs) rev_data_ref_stmts = reverse_data_ref_stmts(info) - exit_blocks_ids = map(n -> ir.blocks[n].id, primal_exit_blocks_inds) + exit_blocks_ids = map(n -> primal_blocks[n].id, primal_exit_blocks_inds) switch_stmts = make_switch_stmts(exit_blocks_ids, length(exit_blocks_ids) == 1, info) - entry_block = BBlock(ID(), vcat(data_stmts, rev_data_ref_stmts, switch_stmts)) + entry_block = CFGBlock(ID(), vcat(data_stmts, rev_data_ref_stmts, switch_stmts)) # For each basic block in the primal: # 1. if the block is reachable on the reverse-pass, the bulk of its statements are the @@ -1603,15 +2318,15 @@ function pullback_ir( # characterise_unique_predecessor_blocks is used in forwards_pass_ir). # 4. if the block began with one or more PhiNodes, then handle their rdata. # 5. jump to the predecessor block. - ps = compute_all_predecessors(ir) - _, pred_is_unique_pred = characterise_unique_predecessor_blocks(ir.blocks) + ps = _compute_cfg_predecessors(primal_blocks) + _, pred_is_unique_pred = _characterise_unique_predecessor_blocks(primal_blocks) main_blocks = map( - ad_stmts_blocks, enumerate(ir.blocks), pb_comms_insts - ) do (blk_id, ad_stmts), (n, blk), comms_insts + ad_stmts_blocks, enumerate(primal_blocks), block_comms + ) do (blk_id, ad_stmts), (n, blk), comms # Short-circuit if we know that this block cannot be reached on the reverse-pass. - if is_unreachable_return_node(terminator(blk)) - return BBlock(blk_id, [(ID(), new_inst(nothing))]) + if is_unreachable_return_node(_cfg_terminator(blk)) + return CFGBlock(blk_id, [(ID(), new_inst(nothing))]) end # Extract reverse-stmts from ad_stmts. @@ -1623,8 +2338,8 @@ function pullback_ir( additional_stmts, new_blocks = conclude_rvs_block(blk, pred_ids, tmp, info) # Combine all blocks and return. See `create_comms_insts!` for more info regarding - # `comms_insts`. - rvs_block = BBlock(blk_id, vcat(comms_insts, rvs_ad_stmts, additional_stmts)) + # `comms`. + rvs_block = CFGBlock(blk_id, vcat(comms.rvs_prefix, rvs_ad_stmts, additional_stmts)) return vcat(rvs_block, new_blocks) end main_blocks = reduce(vcat, main_blocks) @@ -1679,7 +2394,7 @@ function pullback_ir( # Construct return node and assemble final basic block. ret = new_inst(ReturnNode(assert_id)) - exit_block = BBlock( + exit_block = CFGBlock( info.entry_id, vcat( (lazy_zero_rdata_tuple_id, lazy_zero_rdata_tuple), @@ -1688,36 +2403,29 @@ function pullback_ir( ), ) - # Create and return `BBCode` for the pullback. Sort the blocks and remove any blocks - # which are unreachable, in the sense that they have no predecessors (except the entry - # block). This ought not to be necessary, but _appears_ to be necessary in order to - # avoid annoying the Julia compiler. - blks = vcat(entry_block, main_blocks, exit_block) - @static if VERSION >= v"1.12-" - pb_ir = BBCode(blks, arg_types, ir.sptypes, ir.debuginfo, ir.meta, ir.valid_worlds) - else - pb_ir = BBCode(blks, arg_types, ir.sptypes, ir.linetable, ir.meta) - end - return remove_unreachable_blocks!(sort_blocks!(pb_ir)) + # Lower the pullback CFG directly to `IRCode`. + return lower_cfg_blocks_to_ir( + ir, arg_types, vcat([entry_block], main_blocks, [exit_block]) + ) end """ conclude_rvs_block( - blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo + blk::CFGBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo ) Generates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to. """ function conclude_rvs_block( - blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo + blk::CFGBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo ) # Get the PhiNodes and their IDs. - phi_ids, phis = phi_nodes(blk) + phi_ids, phis = _cfg_phi_nodes(blk) # If there are no PhiNodes in this block, switch directly to the predecessor. if length(phi_ids) == 0 - return make_switch_stmts(pred_ids, pred_is_unique_pred, info), BBlock[] + return make_switch_stmts(pred_ids, pred_is_unique_pred, info), CFGBlock[] end # Create statements which extract + zero the rdata refs associated to them. @@ -1729,7 +2437,7 @@ function conclude_rvs_block( end deref_stmts = reduce(vcat, tmp; init=IDInstPair[]) - # For each predecessor, create a `BBlock` which processes its corresponding edge in + # For each predecessor, create a `CFGBlock` which processes its corresponding edge in # each of the `PhiNode`s. new_blocks = map(pred_ids) do pred_id values = Any[__get_value(pred_id, p.stmt) for p in phis] @@ -1770,7 +2478,7 @@ end """ rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo) -Produces a `BBlock` which runs the reverse-pass for the edge associated to `pred_id` in a +Produces a `CFGBlock` which runs the reverse-pass for the edge associated to `pred_id` in a collection of `IDPhiNode`s, and then goes to the block associated to `pred_id`. For example, suppose that we encounter the following collection of `PhiNode`s at the start @@ -1816,7 +2524,7 @@ function rvs_phi_block( end inc_stmts = reduce(vcat, filter(x -> !(x === nothing), tmp); init=IDInstPair[]) goto_stmt = (ID(), new_inst(IDGotoNode(pred_id))) - return BBlock(ID(), vcat(inc_stmts, goto_stmt)) + return CFGBlock(ID(), vcat(inc_stmts, goto_stmt)) end """ @@ -1892,156 +2600,6 @@ Helper function emitted by `make_switch_stmts`. """ __switch_case(id::Int32, predecessor_id::Int32) = !(id === predecessor_id) -""" - DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool) - -For internal use only. - -A callable data structure which, when invoked, calls an rrule specific to the dynamic types -of its arguments. Stores rules in an internal cache to avoid re-deriving. - -This is used to implement dynamic dispatch. -""" -struct DynamicDerivedRule{V} - cache::V - debug_mode::Bool -end - -DynamicDerivedRule(debug_mode::Bool) = DynamicDerivedRule(Dict{Any,Any}(), debug_mode) - -# Create new dynamic rule with empty cache and same debug mode -_copy(x::P) where {P<:DynamicDerivedRule} = P(Dict{Any,Any}(), x.debug_mode) - -function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any,N}) where {N} - - # `Base._stable_typeof` is used here, rather than `typeof` or `Mooncake._typeof`. Its - # precise behaviour (equivalent to `typeof` for everything except `Type`s, for which it - # returns `Type{P}` rather than `typeof(P)`) is needed to ensure that this signature - # matches the types that `rule` sees when `rule(args...)` is called below. If you get - # this wrong, an assertion is violated, causing a hard-to-debug error (see issue 660). - sig = Tuple{map(Base._stable_typeof ∘ primal, args)...} - - rule = get(dynamic_rule.cache, sig, nothing) - if rule === nothing - interp = get_interpreter(ReverseMode) - rule = build_rrule(interp, sig; debug_mode=dynamic_rule.debug_mode) - dynamic_rule.cache[sig] = rule - end - return __call_rule(rule, args) -end - -""" - LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool) - -For internal use only. - -A type-stable wrapper around a `DerivedRule`, which only instantiates the `DerivedRule` -when it is first called. This is useful, as it means that if a rule does not get run, it -does not have to be derived. - -If `debug_mode` is `true`, then the rule constructed will be a `DebugRRule`. This is useful -when debugging, but should usually be switched off for production code as it (in general) -incurs some runtime overhead. - -Note: the signature of the primal for which this is a rule is stored in the type. The only -reason to keep this around is for debugging -- it is very helpful to have this type visible -in the stack trace when something goes wrong, as it allows you to trivially determine which -bit of your code is the culprit. - -# Extended Help - -There are two main reasons why deferring the construction of a `DerivedRule` until we need -to use it is crucial. - -The first is to do with recursion. Consider the following function: -```julia -f(x) = x > 0 ? f(x - 1) : x -``` -If we generate the `IRCode` for this function, we will see something like the following: -```julia -julia> Base.code_ircode_by_type(Tuple{typeof(f), Float64})[1][1] -1 1 ─ %1 = Base.lt_float(0.0, _2)::Bool - │ %2 = Base.or_int(%1, false)::Bool - └── goto #6 if not %2 - 2 ─ %4 = Base.sub_float(_2, 1.0)::Float64 - │ %5 = Base.lt_float(0.0, %4)::Bool - │ %6 = Base.or_int(%5, false)::Bool - └── goto #4 if not %6 - 3 ─ %8 = Base.sub_float(%4, 1.0)::Float64 - │ %9 = invoke Main.f(%8::Float64)::Float64 - └── goto #5 - 4 ─ goto #5 - 5 ┄ %12 = φ (#3 => %9, #4 => %4)::Float64 - └── return %12 - 6 ─ return _2 -``` -Suppose that we decide to construct a `DerivedRule` immediately whenever we find an -`:invoke` statement in a rule that we're currently building a `DerivedRule` for. -In the above example, we produce an infinite recursion when we attempt to produce a -`DerivedRule` for %9, because it has the same signature as the call which generates this IR. -By instead adopting a policy of constructing a `LazyDerivedRule` whenever we encounter an -`:invoke` statement, we avoid this problem. - -The second reason that delaying the construction of a `DerivedRule`, is essential is that it -ensures that we don't derive rules for method instances which aren't run. Suppose that -function B contains code for which we can't derive a rule -- perhaps it contains an -unsupported language feature like a `PhiCNode` or an `UpsilonNode`. Suppose that function A -contains an `:invoke` which refers to function `B`, but that this call is on a branch which -deals with error handling, and doesn't get run run unless something goes wrong. By deferring -the derivation of the rule for B, we only ever attempt to derive it if we land on this -error handling branch. Conversely, if we attempted to derive the rule for B when we derive -the rule for A, we would be unable to complete the derivation of the rule for A. -""" -mutable struct LazyDerivedRule{primal_sig,Trule} - debug_mode::Bool - mi::Core.MethodInstance - rule::Trule - function LazyDerivedRule(mi::Core.MethodInstance, debug_mode::Bool) - interp = get_interpreter(ReverseMode) - return new{mi.specTypes,rule_type(interp, mi;debug_mode)}(debug_mode, mi) - end - function LazyDerivedRule{Tprimal_sig,Trule}( - mi::Core.MethodInstance, debug_mode::Bool - ) where {Tprimal_sig,Trule} - return new{Tprimal_sig,Trule}(debug_mode, mi) - end -end - -# Create new lazy rule with same method instance and debug mode -_copy(x::P) where {P<:LazyDerivedRule} = P(x.mi, x.debug_mode) - -# On Julia 1.10, the generic __call_rule fallback is @stable-checked and returns Any for -# LazyDerivedRule, triggering TypeInstabilityError when dispatch_doctor_mode = "error". -# Add type-asserting specialisations so callers in @stable contexts see a concrete type. -# LazyDerivedRule doesn't contain an OpaqueClosure directly, so no inferencebarrier needed. -@static if VERSION < v"1.11-" - @inline function __call_rule( - rule::LazyDerivedRule{sig,DerivedRule{Tp,FA,FR,RA,RR,isva,Val{pnargs}}}, args::A - ) where {sig,Tp,FA,FR,RA,RR,isva,pnargs,A<:Tuple} - return rule(args...)::Tuple{FR,Pullback{Tp,RA,RR,isva,fieldcount(A)}} - end - @inline function __call_rule( - rule::LazyDerivedRule{ - sig,DebugRRule{DerivedRule{Tp,FA,CoDual{P,FD},RA,RR,isva,Val{pnargs}}} - }, - args::A, - ) where {sig,Tp,FA,P,FD,RA,RR,isva,pnargs,A<:Tuple} - return rule( - args... - )::Tuple{CoDual{P,FD},DebugPullback{Pullback{Tp,RA,RR,isva,fieldcount(A)},P}} - end -end - -@inline function (rule::LazyDerivedRule)(args::Vararg{Any,N}) where {N} - return isdefined(rule, :rule) ? __call_rule(rule.rule, args) : _build_rule!(rule, args) -end - -@noinline function _build_rule!(rule::LazyDerivedRule{sig,Trule}, args) where {sig,Trule} - interp = get_interpreter(ReverseMode) - rule.rule = build_rrule(interp, rule.mi; debug_mode=rule.debug_mode) - return __call_rule(rule.rule, args) -end - """ rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C} diff --git a/src/precompile.jl b/src/precompile.jl index dca70a0737..e3aeec801c 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -12,9 +12,10 @@ using PrecompileTools: @setup_workload, @compile_workload # `prepare_derivative_cache` → `value_and_derivative!!` pipelines (which internally call # `build_rrule`/`build_frule`, `generate_ir`, and all the IR-transformation infrastructure) # for both a simple scalar and a simple vector function. Because the IR-manipulation -# methods (`normalise!`, `BBCode`, `make_ad_stmts!`, …) work on `IRCode`/`BBCode` objects -# whose *Julia type* is the same regardless of which function is being differentiated, one -# call through the pipeline is enough to pre-warm the bulk of the compilation work. +# methods (`normalise!`, `make_ad_stmts!`, CFG lowering, …) work on a stable `IRCode`/CFG +# pipeline whose *Julia type* is the same regardless of which function is being +# differentiated, one call through the pipeline is enough to pre-warm the bulk of the +# compilation work. @setup_workload begin # A non-primitive scalar function: exercises the derived-rule code path end-to-end. diff --git a/src/skill_utils.jl b/src/skill_utils.jl index beff4c2e02..b29c536228 100644 --- a/src/skill_utils.jl +++ b/src/skill_utils.jl @@ -9,8 +9,7 @@ module SkillUtils using ..Mooncake: CC, IRCode, - BBCode, - BasicBlockCode, + CFGBlock, ForwardMode, ReverseMode, MooncakeInterpreter, @@ -19,7 +18,9 @@ using ..Mooncake: lookup_ir, is_vararg_and_sparam_names, normalise!, - remove_unreachable_blocks!, + _ircode_to_cfg_blocks, + _remove_unreachable_cfg_blocks!, + _compute_cfg_successors, generate_dual_ir, generate_ir, optimise_ir!, @@ -66,27 +67,29 @@ end # --- Stage Graphs --- -forward_stage_order() = [:raw, :normalized, :bbcode, :dual_ir, :optimized] +forward_stage_order() = [:raw, :normalized, :cfg_blocks, :dual_ir, :optimized] function forward_stage_graph() return [ :raw => :normalized, - :normalized => :bbcode, - :bbcode => :dual_ir, + :normalized => :cfg_blocks, + :cfg_blocks => :dual_ir, :dual_ir => :optimized, ] end function reverse_stage_order() - return [:raw, :normalized, :bbcode, :fwd_ir, :rvs_ir, :optimized_fwd, :optimized_rvs] + return [ + :raw, :normalized, :cfg_blocks, :fwd_ir, :rvs_ir, :optimized_fwd, :optimized_rvs + ] end function reverse_stage_graph() return [ :raw => :normalized, - :normalized => :bbcode, - :bbcode => :fwd_ir, - :bbcode => :rvs_ir, + :normalized => :cfg_blocks, + :cfg_blocks => :fwd_ir, + :cfg_blocks => :rvs_ir, :fwd_ir => :optimized_fwd, :rvs_ir => :optimized_rvs, ] @@ -100,11 +103,11 @@ function render_ir(ir::IRCode)::String return String(take!(io)) end -function render_ir(bb::BBCode)::String +function render_ir(blocks::Vector{CFGBlock})::String io = IOBuffer() - for (i, block) in enumerate(bb.blocks) + for (i, block) in enumerate(blocks) println(io, "Block $(i) (id=$(block.id)):") - for (id, inst) in zip(block.inst_ids, block.insts) + for (id, inst) in block.insts println(io, " $id: $(inst.stmt) :: $(inst.type)") end end @@ -134,11 +137,11 @@ function extract_meta(ir::IRCode)::StageMeta ) end -function extract_meta(bb::BBCode)::StageMeta - succs = BasicBlockCode.compute_all_successors(bb) +function extract_meta(blocks::Vector{CFGBlock})::StageMeta + succs = _compute_cfg_successors(blocks) return StageMeta(; - block_count=length(bb.blocks), - inst_count=sum(length(b.inst_ids) for b in bb.blocks), + block_count=length(blocks), + inst_count=sum(length(block.insts) for block in blocks), edge_count=sum(length(v) for v in values(succs)), ) end @@ -183,8 +186,8 @@ function primal_stages(interp, sig) normalized_ir = CC.copy(raw_ir) normalise!(normalized_ir, spnames) - bbcode = remove_unreachable_blocks!(BBCode(normalized_ir)) - return raw_ir, normalized_ir, bbcode + cfg_blocks = _remove_unreachable_cfg_blocks!(_ircode_to_cfg_blocks(normalized_ir)) + return raw_ir, normalized_ir, cfg_blocks end function primitive_dispatch_note(mode::Symbol, sig::Type)::String @@ -267,12 +270,14 @@ function inspect_ir( # Propagate generation failures so callers do not mistake partial inspection output # for a successful run. # Stage 1: Raw IR - raw_ir, normalized_ir, bbcode = primal_stages(interp, sig) + raw_ir, normalized_ir, cfg_blocks = primal_stages(interp, sig) stages[:raw] = IRStage(:raw, raw_ir, render_ir(raw_ir), extract_meta(raw_ir)) stages[:normalized] = IRStage( :normalized, normalized_ir, render_ir(normalized_ir), extract_meta(normalized_ir) ) - stages[:bbcode] = IRStage(:bbcode, bbcode, render_ir(bbcode), extract_meta(bbcode)) + stages[:cfg_blocks] = IRStage( + :cfg_blocks, cfg_blocks, render_ir(cfg_blocks), extract_meta(cfg_blocks) + ) # Mode-specific stages if mode == :forward diff --git a/test/front_matter.jl b/test/front_matter.jl index 80255ffa29..c88ac960fa 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -57,25 +57,20 @@ using Mooncake: TestResources, CoDual, DefaultCtx, - rrule!!, lgetfield, lsetfield!, Stack, _typeof, - BBCode, ID, IDPhiNode, IDGotoNode, IDGotoIfNot, - BBlock, make_ad_stmts!, ADStmtInfo, ad_stmt_info, ADInfo, SharedDataPairs, increment_field!!, - NoFData, - NoRData, zero_fcodual, zero_like_rdata_from_type, zero_rdata, @@ -83,7 +78,6 @@ using Mooncake: LazyZeroRData, lazy_zero_rdata, new_inst, - characterise_unique_predecessor_blocks, NoPullback, characterise_used_ids, InvalidFDataException, diff --git a/test/interpreter/bbcode.jl b/test/interpreter/bbcode.jl deleted file mode 100644 index d5296acb10..0000000000 --- a/test/interpreter/bbcode.jl +++ /dev/null @@ -1,273 +0,0 @@ -module BBCodeTestCases -test_phi_node(x::Ref{Union{Float32,Float64}}) = sin(x[]) -end - -@testset "bbcode" begin - @testset "ID" begin - id1 = ID() - id2 = ID() - @test id1 == id1 - @test id1 != id2 - end - @testset "BBlock" begin - bb = BBlock( - ID(), - ID[ID(), ID()], - CC.NewInstruction[ - CC.NewInstruction(IDPhiNode([ID(), ID()], Any[true, false]), Any), - CC.NewInstruction(:(println("hello")), Any), - ], - ) - @test bb isa BBlock - @test length(bb) == 2 - - ids, phi_nodes = Mooncake.phi_nodes(bb) - @test only(ids) == bb.inst_ids[1] - @test only(phi_nodes) == bb.insts[1] - - insert!(bb, 1, ID(), CC.NewInstruction(nothing, Nothing)) - @test length(bb) == 3 - @test bb.insts[1].stmt === nothing - - bb_copy = copy(bb) - @test bb_copy.inst_ids !== bb.inst_ids - - @test Mooncake.terminator(bb) === nothing - - # Final statement is a regular instruction, so the newly inserted instruction should go at - # the end of the block. - @test Mooncake.insert_before_terminator!(bb, ID(), new_inst(ReturnNode(5))) === - nothing - @test bb.insts[end].stmt === ReturnNode(5) - - # Final statement is now a Terminator, so insertion should happen before it. - @test Mooncake.insert_before_terminator!(bb, ID(), new_inst(nothing)) === nothing - @test bb.insts[end].stmt === ReturnNode(5) - @test bb.insts[end - 1].stmt === nothing - end - @testset "BBCode $f" for (f, P) in [ - (TestResources.test_while_loop, Tuple{Float64}), - (sin, Tuple{Float64}), - (BBCodeTestCases.test_phi_node, Tuple{Ref{Union{Float32,Float64}}}), - ] - ir = Base.code_ircode(f, P)[1][1] - bb_code = BBCode(ir) - @test bb_code isa BBCode - @test length(bb_code.blocks) == length(ir.cfg.blocks) - new_ir = Mooncake.IRCode(bb_code) - @test length(stmt(new_ir.stmts)) == length(stmt(ir.stmts)) - @test all(map(==, stmt(ir.stmts), stmt(new_ir.stmts))) - @test all(map(==, ir.stmts.type, new_ir.stmts.type)) - @test all(map(==, ir.stmts.info, new_ir.stmts.info)) - @test all(map(==, ir.stmts.line, new_ir.stmts.line)) - @test all(map(==, ir.stmts.flag, new_ir.stmts.flag)) - @test length(Mooncake.collect_stmts(bb_code)) == length(stmt(ir.stmts)) - @test Mooncake.BasicBlockCode.id_to_line_map(bb_code) isa Dict{ID,Int} - end - @testset "control_flow_graph" begin - ir = Base.code_ircode_by_type(Tuple{typeof(sin),Float64})[1][1] - bb = BBCode(ir) - new_ir = Core.Compiler.IRCode(bb) - cfg = Mooncake.BasicBlockCode.control_flow_graph(bb) - @test all(map((l, r) -> l.stmts == r.stmts, ir.cfg.blocks, cfg.blocks)) - @test all(map((l, r) -> sort(l.preds) == sort(r.preds), ir.cfg.blocks, cfg.blocks)) - @test all(map((l, r) -> sort(l.succs) == sort(r.succs), ir.cfg.blocks, cfg.blocks)) - @test ir.cfg.index == cfg.index - end - @testset "_characterise_unique_predecessor_blocks" begin - @testset "single block" begin - blk_id = ID() - blks = BBlock[BBlock(blk_id, [ID()], [new_inst(ReturnNode(5))])] - upreds, pred_is_upred = characterise_unique_predecessor_blocks(blks) - @test upreds[blk_id] == true - @test pred_is_upred[blk_id] == true - end - @testset "pair of blocks" begin - blk_id_1 = ID() - blk_id_2 = ID() - blks = BBlock[ - BBlock(blk_id_1, [ID()], [new_inst(IDGotoNode(blk_id_2))]), - BBlock(blk_id_2, [ID()], [new_inst(ReturnNode(5))]), - ] - upreds, pred_is_upred = characterise_unique_predecessor_blocks(blks) - @test upreds[blk_id_1] == true - @test upreds[blk_id_2] == true - @test pred_is_upred[blk_id_1] == true - @test pred_is_upred[blk_id_2] == true - end - @testset "Non-Unique Exit Node" begin - blk_id_1 = ID() - blk_id_2 = ID() - blk_id_3 = ID() - blks = BBlock[ - BBlock(blk_id_1, [ID()], [new_inst(IDGotoIfNot(true, blk_id_3))]), - BBlock(blk_id_2, [ID()], [new_inst(ReturnNode(5))]), - BBlock(blk_id_3, [ID()], [new_inst(ReturnNode(5))]), - ] - upreds, pred_is_upred = characterise_unique_predecessor_blocks(blks) - @test upreds[blk_id_1] == true - @test upreds[blk_id_2] == false - @test upreds[blk_id_3] == false - @test pred_is_upred[blk_id_1] == true - @test pred_is_upred[blk_id_2] == true - @test pred_is_upred[blk_id_3] == true - end - @testset "diamond structure of four blocks" begin - blk_id_1 = ID() - blk_id_2 = ID() - blk_id_3 = ID() - blk_id_4 = ID() - blks = BBlock[ - BBlock(blk_id_1, [ID()], [new_inst(IDGotoIfNot(true, blk_id_3))]), - BBlock(blk_id_2, [ID()], [new_inst(IDGotoNode(blk_id_4))]), - BBlock(blk_id_3, [ID()], [new_inst(IDGotoNode(blk_id_4))]), - BBlock(blk_id_4, [ID()], [new_inst(ReturnNode(0))]), - ] - upreds, pred_is_upred = characterise_unique_predecessor_blocks(blks) - @test upreds[blk_id_1] == true - @test upreds[blk_id_2] == false - @test upreds[blk_id_3] == false - @test upreds[blk_id_4] == true - @test pred_is_upred[blk_id_1] == true - @test pred_is_upred[blk_id_2] == true - @test pred_is_upred[blk_id_3] == true - @test pred_is_upred[blk_id_4] == false - end - @testset "simple loop back to first block" begin - blk_id_1 = ID() - blk_id_2 = ID() - blks = BBlock[ - BBlock(blk_id_1, [ID()], [new_inst(IDGotoIfNot(true, blk_id_1))]), - BBlock(blk_id_2, [ID()], [new_inst(ReturnNode(5))]), - ] - upreds, pred_is_upred = characterise_unique_predecessor_blocks(blks) - @test upreds[blk_id_1] == true - @test upreds[blk_id_2] == true - @test pred_is_upred[blk_id_1] == false - @test pred_is_upred[blk_id_2] == true - end - end - @testset "characterise_used_ids" begin - @testset "_find_id_uses!" begin - @testset "Expr" begin - id = ID() - d = Dict{ID,Bool}(id => false) - Mooncake.BasicBlockCode._find_id_uses!(d, Expr(:call, sin, 5)) - @test d[id] == false - Mooncake.BasicBlockCode._find_id_uses!(d, Expr(:call, sin, id)) - @test d[id] == true - end - @testset "IDGotoIfNot" begin - id = ID() - d = Dict{ID,Bool}(id => false) - Mooncake.BasicBlockCode._find_id_uses!(d, IDGotoIfNot(ID(), ID())) - @test d[id] == false - Mooncake.BasicBlockCode._find_id_uses!(d, IDGotoIfNot(true, ID())) - @test d[id] == false - Mooncake.BasicBlockCode._find_id_uses!(d, IDGotoIfNot(id, ID())) - @test d[id] == true - end - @testset "IDGotoNode" begin - id = ID() - d = Dict{ID,Bool}(id => false) - Mooncake.BasicBlockCode._find_id_uses!(d, IDGotoNode(ID())) - @test d[id] == false - end - @testset "IDPhiNode" begin - id = ID() - d = Dict{ID,Bool}(id => false) - Mooncake.BasicBlockCode._find_id_uses!( - d, IDPhiNode([ID()], Vector{Any}(undef, 1)) - ) - @test d[id] == false - Mooncake.BasicBlockCode._find_id_uses!(d, IDPhiNode([ID()], Any[id])) - @test d[id] == true - end - @testset "PiNode" begin - id = ID() - d = Dict{ID,Bool}(id => false) - Mooncake.BasicBlockCode._find_id_uses!(d, PiNode(false, Bool)) - @test d[id] == false - Mooncake.BasicBlockCode._find_id_uses!(d, PiNode(id, Bool)) - @test d[id] == true - end - @testset "ReturnNode" begin - id = ID() - d = Dict{ID,Bool}(id => false) - Mooncake.BasicBlockCode._find_id_uses!(d, ReturnNode()) - @test d[id] == false - Mooncake.BasicBlockCode._find_id_uses!(d, ReturnNode(5)) - @test d[id] == false - Mooncake.BasicBlockCode._find_id_uses!(d, ReturnNode(id)) - @test d[id] == true - end - end - @testset "some used some unused" begin - id_1 = ID() - id_2 = ID() - id_3 = ID() - stmts = Tuple{ID,Core.Compiler.NewInstruction}[ - (id_1, new_inst(Expr(:call, sin, Argument(1)))), - (id_2, new_inst(Expr(:call, cos, id_1))), - (id_3, new_inst(ReturnNode(id_2))), - ] - result = characterise_used_ids(stmts) - @test result[id_1] == true - @test result[id_2] == true - @test result[id_3] == false - end - end - @testset "_is_reachable" begin - ir = Mooncake.ircode( - Any[ - ReturnNode(nothing), - Expr(:call, sin, 5), - Core.GotoNode(4), - ReturnNode(SSAValue(2)), - ], - Any[Any for _ in 1:4], - ) - @test Mooncake.BasicBlockCode._is_reachable(BBCode(ir).blocks) == - [true, false, false] - end - @testset "remove_unreachable_blocks!" begin - - # This test case has two important features: - # 1. the second basic block (the second statement) cannot be reached, and - # 2. the PhiNode in the third basic block refers to the second basic block. Since - # the second block will be removed, the edge / value in the PhiNode corresponding - # to the second block must be removed as part of the call to - # remove_unreachable_blocks. - ir = Mooncake.ircode( - Any[ - GotoNode(3), - nothing, - PhiNode(Int32[2, 1], Any[false, true]), - ReturnNode(SSAValue(3)), - ], - Any[Any for _ in 1:4], - ) - CC.verify_ir(ir) - bb_ir = BBCode(ir) - new_bb_ir = Mooncake.remove_unreachable_blocks!(bb_ir) - - # Check that only the first and third block remain in the new IR. - @test length(new_bb_ir.blocks) == 2 - @test bb_ir.blocks[1].id == new_bb_ir.blocks[1].id - @test bb_ir.blocks[3].id == new_bb_ir.blocks[2].id - - # Check that the reference to the second block in the PhiNode has been removed. - # Do this by checking that the only - updated_id_phi_node = new_bb_ir.blocks[2].insts[1].stmt - @test length(updated_id_phi_node.edges) == 1 - @test length(updated_id_phi_node.values) == 1 - @test only(updated_id_phi_node.values) == true - - # Get the IRCode, and ensure that the statements in it agree with what is expected. - new_ir = CC.IRCode(new_bb_ir) - expected_stmts = Any[ - GotoNode(2), PhiNode(Int32[1], Any[true]), ReturnNode(SSAValue(2)) - ] - @test Mooncake.stmt(new_ir.stmts) == expected_stmts - end -end diff --git a/test/interpreter/cfg_builder.jl b/test/interpreter/cfg_builder.jl new file mode 100644 index 0000000000..21f470283e --- /dev/null +++ b/test/interpreter/cfg_builder.jl @@ -0,0 +1,193 @@ +@testset "cfg_builder" begin + @testset "_characterise_unique_predecessor_blocks" begin + @testset "single block" begin + blk_id = ID() + blocks = Mooncake.CFGBlock[Mooncake.CFGBlock( + blk_id, [(ID(), new_inst(ReturnNode(5)))] + )] + upreds, pred_is_upred = Mooncake._characterise_unique_predecessor_blocks(blocks) + @test upreds[blk_id] == true + @test pred_is_upred[blk_id] == true + end + + @testset "pair of blocks" begin + blk_id_1 = ID() + blk_id_2 = ID() + blocks = Mooncake.CFGBlock[ + Mooncake.CFGBlock(blk_id_1, [(ID(), new_inst(IDGotoNode(blk_id_2)))]), + Mooncake.CFGBlock(blk_id_2, [(ID(), new_inst(ReturnNode(5)))]), + ] + upreds, pred_is_upred = Mooncake._characterise_unique_predecessor_blocks(blocks) + @test upreds[blk_id_1] == true + @test upreds[blk_id_2] == true + @test pred_is_upred[blk_id_1] == true + @test pred_is_upred[blk_id_2] == true + end + + @testset "non-unique exit node" begin + blk_id_1 = ID() + blk_id_2 = ID() + blk_id_3 = ID() + blocks = Mooncake.CFGBlock[ + Mooncake.CFGBlock( + blk_id_1, [(ID(), new_inst(IDGotoIfNot(true, blk_id_3)))] + ), + Mooncake.CFGBlock(blk_id_2, [(ID(), new_inst(ReturnNode(5)))]), + Mooncake.CFGBlock(blk_id_3, [(ID(), new_inst(ReturnNode(5)))]), + ] + upreds, pred_is_upred = Mooncake._characterise_unique_predecessor_blocks(blocks) + @test upreds[blk_id_1] == true + @test upreds[blk_id_2] == false + @test upreds[blk_id_3] == false + @test pred_is_upred[blk_id_1] == true + @test pred_is_upred[blk_id_2] == true + @test pred_is_upred[blk_id_3] == true + end + + @testset "diamond structure of four blocks" begin + blk_id_1 = ID() + blk_id_2 = ID() + blk_id_3 = ID() + blk_id_4 = ID() + blocks = Mooncake.CFGBlock[ + Mooncake.CFGBlock( + blk_id_1, [(ID(), new_inst(IDGotoIfNot(true, blk_id_3)))] + ), + Mooncake.CFGBlock(blk_id_2, [(ID(), new_inst(IDGotoNode(blk_id_4)))]), + Mooncake.CFGBlock(blk_id_3, [(ID(), new_inst(IDGotoNode(blk_id_4)))]), + Mooncake.CFGBlock(blk_id_4, [(ID(), new_inst(ReturnNode(0)))]), + ] + upreds, pred_is_upred = Mooncake._characterise_unique_predecessor_blocks(blocks) + @test upreds[blk_id_1] == true + @test upreds[blk_id_2] == false + @test upreds[blk_id_3] == false + @test upreds[blk_id_4] == true + @test pred_is_upred[blk_id_1] == true + @test pred_is_upred[blk_id_2] == true + @test pred_is_upred[blk_id_3] == true + @test pred_is_upred[blk_id_4] == false + end + + @testset "simple loop back to first block" begin + blk_id_1 = ID() + blk_id_2 = ID() + blocks = Mooncake.CFGBlock[ + Mooncake.CFGBlock( + blk_id_1, [(ID(), new_inst(IDGotoIfNot(true, blk_id_1)))] + ), + Mooncake.CFGBlock(blk_id_2, [(ID(), new_inst(ReturnNode(5)))]), + ] + upreds, pred_is_upred = Mooncake._characterise_unique_predecessor_blocks(blocks) + @test upreds[blk_id_1] == true + @test upreds[blk_id_2] == true + @test pred_is_upred[blk_id_1] == false + @test pred_is_upred[blk_id_2] == true + end + end + + @testset "_cfg_distance_to_entry and _canonicalise_cfg_blocks" begin + blk_id_1 = ID() + blk_id_2 = ID() + blk_id_3 = ID() + blk_id_4 = ID() + blocks = Mooncake.CFGBlock[ + Mooncake.CFGBlock(blk_id_1, [(ID(), new_inst(IDGotoNode(blk_id_4)))]), + Mooncake.CFGBlock(blk_id_3, [(ID(), new_inst(ReturnNode(3)))]), + Mooncake.CFGBlock(blk_id_2, [(ID(), new_inst(ReturnNode(2)))]), + Mooncake.CFGBlock(blk_id_4, [(ID(), new_inst(IDGotoNode(blk_id_2)))]), + ] + + @test Mooncake._cfg_distance_to_entry(blocks) == [0, typemax(Int), 2, 1] + + sorted_blocks = Mooncake._sort_cfg_blocks!(copy(blocks)) + @test map(block -> block.id, sorted_blocks) == + [blk_id_1, blk_id_4, blk_id_2, blk_id_3] + + canonical_blocks = Mooncake._canonicalise_cfg_blocks(blocks) + @test map(block -> block.id, canonical_blocks) == [blk_id_1, blk_id_4, blk_id_2] + end + + @testset "_cfg_control_flow_graph and lower_cfg_blocks_to_ir" begin + ir = Mooncake.ircode(Any[ReturnNode(nothing)], Any[Any]) + mid_id = ID() + end_id = ID() + blocks = Mooncake.CFGBlock[ + Mooncake.CFGBlock(ID(), [(ID(), new_inst(IDGotoNode(mid_id)))]), + Mooncake.CFGBlock(mid_id, [(ID(), new_inst(IDGotoNode(end_id)))]), + Mooncake.CFGBlock(end_id, [(ID(), new_inst(ReturnNode(1)))]), + ] + + lowered_ir = Mooncake.lower_cfg_blocks_to_ir(ir, Any[Any], blocks) + cfg = Mooncake._cfg_control_flow_graph(Mooncake._canonicalise_cfg_blocks(blocks)) + + @test all( + map((lhs, rhs) -> lhs.stmts == rhs.stmts, lowered_ir.cfg.blocks, cfg.blocks) + ) + @test all( + map((lhs, rhs) -> lhs.preds == rhs.preds, lowered_ir.cfg.blocks, cfg.blocks) + ) + @test all( + map((lhs, rhs) -> lhs.succs == rhs.succs, lowered_ir.cfg.blocks, cfg.blocks) + ) + @test lowered_ir.cfg.index == cfg.index + end + + @testset "sort_cfg=false preserves block order" begin + entry_id, mid_id, exit_id = ID(), ID(), ID() + blocks = Mooncake.CFGBlock[ + Mooncake.CFGBlock(entry_id, [(ID(), new_inst(IDGotoNode(exit_id)))]), + Mooncake.CFGBlock(mid_id, [(ID(), new_inst(ReturnNode(1)))]), + Mooncake.CFGBlock(exit_id, [(ID(), new_inst(IDGotoNode(mid_id)))]), + ] + unsorted = Mooncake._canonicalise_cfg_blocks(blocks; sort_cfg=false) + @test map(blk -> blk.id, unsorted) == [entry_id, mid_id, exit_id] + end + + @testset "_insert_before_terminator!" begin + mid_id = ID() + insts = [(ID(), new_inst(IDGotoNode(mid_id)))] + inserted = (ID(), new_inst(ReturnNode(3))) + Mooncake._insert_before_terminator!(insts, inserted) + @test insts[1] == inserted + @test insts[2][2].stmt == IDGotoNode(mid_id) + end + + @testset "_cfg_terminator and _cfg_phi_nodes" begin + entry_id, mid_id = ID(), ID() + phi_block = Mooncake.CFGBlock( + ID(), + [ + (ID(), new_inst(IDPhiNode([entry_id, mid_id], Any[Argument(1), 2]))), + (ID(), new_inst(ReturnNode(nothing))), + ], + ) + phi_ids, phis = Mooncake._cfg_phi_nodes(phi_block) + @test length(phi_ids) == 1 + @test only(phis).stmt == IDPhiNode([entry_id, mid_id], Any[Argument(1), 2]) + @test Mooncake._cfg_terminator(phi_block) == ReturnNode(nothing) + end + + @testset "phi-edge cleanup on dead predecessors" begin + ir = Mooncake.ircode(Any[ReturnNode(nothing)], Any[Any]) + entry_id, dead_id, join_id = ID(), ID(), ID() + blocks = Mooncake.CFGBlock[ + Mooncake.CFGBlock(entry_id, [(ID(), new_inst(IDGotoNode(join_id)))]), + Mooncake.CFGBlock(dead_id, [(ID(), new_inst(IDGotoNode(join_id)))]), + Mooncake.CFGBlock( + join_id, + [ + (ID(), new_inst(IDPhiNode([entry_id, dead_id], Any[Argument(1), 2]))), + (ID(), new_inst(ReturnNode(1))), + ], + ), + ] + preds = Mooncake._compute_cfg_predecessors(blocks) + @test Set(preds[join_id]) == Set([entry_id, dead_id]) + + lowered = Mooncake._canonicalise_cfg_blocks(blocks) + @test lowered[2].insts[1][2].stmt == IDPhiNode([entry_id], Any[Argument(1)]) + + lowered_ir = Mooncake.lower_cfg_blocks_to_ir(ir, Any[Any], blocks) + @test stmt(lowered_ir.stmts)[2] == PhiNode(Int32[1], Any[Argument(1)]) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 67272c4045..8b25700cc0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,7 +46,7 @@ include("front_matter.jl") include(joinpath("interpreter", "contexts.jl")) include(joinpath("interpreter", "abstract_interpretation.jl")) include(joinpath("interpreter", "ir_utils.jl")) - include(joinpath("interpreter", "bbcode.jl")) + include(joinpath("interpreter", "cfg_builder.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) include(joinpath("interpreter", "zero_like_rdata.jl")) include(joinpath("interpreter", "forward_mode.jl")) diff --git a/test/skill_utils.jl b/test/skill_utils.jl index 013705039d..7f0c214086 100644 --- a/test/skill_utils.jl +++ b/test/skill_utils.jl @@ -49,7 +49,7 @@ Mooncake.@zero_derivative Mooncake.MinimalCtx Tuple{typeof(zero_derivative_llvmc @test isempty(ins.notes) expected_stages = [ - :raw, :normalized, :bbcode, :fwd_ir, :rvs_ir, :optimized_fwd, :optimized_rvs + :raw, :normalized, :cfg_blocks, :fwd_ir, :rvs_ir, :optimized_fwd, :optimized_rvs ] @test ins.stage_order == expected_stages for s in expected_stages @@ -74,7 +74,7 @@ Mooncake.@zero_derivative Mooncake.MinimalCtx Tuple{typeof(zero_derivative_llvmc ins = inspect_fwd(test_fn, 1.0) @test ins.mode == :forward - expected_stages = [:raw, :normalized, :bbcode, :dual_ir, :optimized] + expected_stages = [:raw, :normalized, :cfg_blocks, :dual_ir, :optimized] @test ins.stage_order == expected_stages for s in expected_stages @test haskey(ins.stages, s) @@ -244,8 +244,8 @@ Mooncake.@zero_derivative Mooncake.MinimalCtx Tuple{typeof(zero_derivative_llvmc @testset "render_ir" begin ins = inspect_ir(test_fn, 1.0) @test !isempty(render_ir(ins.stages[:raw].ir)) - @test !isempty(render_ir(ins.stages[:bbcode].ir)) - @test occursin("Block", render_ir(ins.stages[:bbcode].ir)) + @test !isempty(render_ir(ins.stages[:cfg_blocks].ir)) + @test occursin("Block", render_ir(ins.stages[:cfg_blocks].ir)) end @testset "convenience functions" begin @@ -269,21 +269,23 @@ Mooncake.@zero_derivative Mooncake.MinimalCtx Tuple{typeof(zero_derivative_llvmc fg = forward_stage_graph() @test fg == [ :raw => :normalized, - :normalized => :bbcode, - :bbcode => :dual_ir, + :normalized => :cfg_blocks, + :cfg_blocks => :dual_ir, :dual_ir => :optimized, ] rg = reverse_stage_graph() @test (:raw => :normalized) in rg - @test (:bbcode => :fwd_ir) in rg - @test (:bbcode => :rvs_ir) in rg + @test (:cfg_blocks => :fwd_ir) in rg + @test (:cfg_blocks => :rvs_ir) in rg @test (:fwd_ir => :optimized_fwd) in rg @test (:rvs_ir => :optimized_rvs) in rg - @test forward_stage_order() == [:raw, :normalized, :bbcode, :dual_ir, :optimized] - @test reverse_stage_order() == - [:raw, :normalized, :bbcode, :fwd_ir, :rvs_ir, :optimized_fwd, :optimized_rvs] + @test forward_stage_order() == + [:raw, :normalized, :cfg_blocks, :dual_ir, :optimized] + @test reverse_stage_order() == [ + :raw, :normalized, :cfg_blocks, :fwd_ir, :rvs_ir, :optimized_fwd, :optimized_rvs + ] end @testset "StageMeta" begin @@ -306,9 +308,9 @@ Mooncake.@zero_derivative Mooncake.MinimalCtx Tuple{typeof(zero_derivative_llvmc @test raw_meta.inst_count > 0 @test raw_meta.edge_count >= 0 - bb_meta = extract_meta(ins.stages[:bbcode].ir) - @test bb_meta.block_count > 0 - @test bb_meta.inst_count > 0 + cfg_meta = extract_meta(ins.stages[:cfg_blocks].ir) + @test cfg_meta.block_count > 0 + @test cfg_meta.inst_count > 0 fallback = extract_meta("not an IR") @test fallback.block_count == 0