Forward-mode AD #389

Merged
yebai merged 226 commits into chalk-lab:main from gdalle:gd/forward
Aug 12, 2025
@gdalle
Collaborator

@gdalle gdalle commented Nov 24, 2024

This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation.

Will's edits (apologies for editing your thing @gdalle -- I just want to make sure that the todo list is at the top of the PR):

Todo:

  • make FunctionWrappers work correctly (not going to do this in this PR)
  • add support for MistyClosures
  • add tests for Hessian vector products
  • define is_primitive separately for forwards and reverse pass.
  • do a complete pass to review design -- are there any high-level things we ought to modify?
  • improve DRY-ness of code, in the testing infrastructure in particular.
  • check GPU compatibility, make sure no major design issues prevent future GPU compatibility, and be explicit about what needs to be done in the future.
  • what name should we use for @from_rule: @from_chainrules or @from_chain_rule? See comments below.
  • add support for UpsilonNodes and PhiCNodes.
  • get all tests passing
  • bump to version 0.5 (actually not needed)

Once the above are complete, I'll request reviews.

@codecov

codecov Bot commented Nov 24, 2024

Codecov Report

Attention: Patch coverage is 94.04070% with 82 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/interpreter/s2s_forward_mode_ad.jl 88.77% 22 Missing ⚠️
src/test_utils.jl 86.66% 16 Missing ⚠️
src/rrules/foreigncall.jl 75.75% 8 Missing ⚠️
src/rrules/memory.jl 87.69% 8 Missing ⚠️
src/utils.jl 76.92% 6 Missing ⚠️
src/rrules/tasks.jl 64.28% 5 Missing ⚠️
src/dual.jl 85.71% 3 Missing ⚠️
src/rrules/builtins.jl 97.82% 3 Missing ⚠️
src/developer_tools.jl 0.00% 2 Missing ⚠️
src/interpreter/s2s_reverse_mode_ad.jl 71.42% 2 Missing ⚠️
... and 5 more
Files with missing lines Coverage Δ
src/Mooncake.jl 100.00% <ø> (ø)
src/interpreter/ir_utils.jl 89.68% <100.00%> (+2.81%) ⬆️
src/rrules/array_legacy.jl 100.00% <100.00%> (ø)
src/rrules/avoiding_non_differentiable_code.jl 100.00% <100.00%> (ø)
src/rrules/blas.jl 99.64% <100.00%> (+0.84%) ⬆️
src/rrules/fastmath.jl 100.00% <100.00%> (ø)
src/rrules/lapack.jl 100.00% <100.00%> (+0.56%) ⬆️
src/rrules/linear_algebra.jl 100.00% <100.00%> (ø)
src/rrules/low_level_maths.jl 100.00% <100.00%> (ø)
src/rrules/new.jl 91.30% <100.00%> (+2.84%) ⬆️
... and 20 more

... and 2 files with indirect coverage changes

Collaborator

@willtebbutt willtebbutt left a comment

This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think?

@willtebbutt
Collaborator

I think this could work.

You could just replace the frule!! calls with a call to a function call_frule!! which would be something like

@inline function call_frule!!(rule::R, fargs::Vararg{Any,N}) where {R,N}
    # Wrap every argument that is not already a Dual in a Dual with zero tangent.
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end

The optimisation pass will lower this to what we were thinking about writing out in the IR anyway.
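A minimal, self-contained sketch of this wrapping idea is given below. It does not use Mooncake: `ToyDual`, `toy_zero_dual`, `toy_call_frule` and `toy_mul_rule` are hypothetical stand-ins invented for illustration, and the real `Dual` and `zero_dual` may well behave differently.

```julia
# Toy stand-in for a dual number (NOT Mooncake's Dual type).
struct ToyDual{P,T}
    primal::P
    tangent::T
end

# Hypothetical helper: lift a plain real number to a dual with zero tangent.
toy_zero_dual(x::Real) = ToyDual(x, zero(x))

# Wrap every argument that is not already a ToyDual, then call the rule.
@inline function toy_call_frule(rule::R, fargs::Vararg{Any,N}) where {R,N}
    return rule(map(x -> x isa ToyDual ? x : toy_zero_dual(x), fargs)...)
end

# Hand-written forward rule for multiplication (product rule).
function toy_mul_rule(a::ToyDual, b::ToyDual)
    return ToyDual(a.primal * b.primal, a.tangent * b.primal + a.primal * b.tangent)
end

# The literal 2.0 is lifted automatically; the dual argument is passed through.
y = toy_call_frule(toy_mul_rule, ToyDual(3.0, 1.0), 2.0)
```

Here `y.primal` is `3.0 * 2.0` and `y.tangent` is the derivative with respect to the first argument, since the literal carries a zero tangent.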

I think the other important kinds of nodes would be largely straightforward to handle.

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

I think we might need to be slightly more subtle. If an argument to the :call or :invoke expression is a CC.Argument or a CC.SSAValue, we don't wrap it in a Dual because we assume it will already be one, right?

@willtebbutt
Collaborator

willtebbutt commented Nov 26, 2024

Yes. I think my proposed code handles this though, or am I missing something?

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

In the spirit of higher-order AD, we may encounter Dual inputs that we want to wrap with a second Dual, and Dual inputs that we want to leave as-is. So I think this wrapping needs to be decided from the type of each argument in the IR?
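To make the ambiguity concrete, here is a small sketch with a toy dual type. `ToyDual`, `wrap_by_value` and `wrap_always` are hypothetical names made up for this illustration, not part of Mooncake. It only shows that a runtime `isa` check cannot tell a dual that is itself a primal value (and so needs a second layer) from one that has already been lifted.

```julia
# Toy stand-in for a dual number (NOT Mooncake's Dual type).
struct ToyDual{P,T}
    primal::P
    tangent::T
end

# Zero of a toy dual, so that duals can themselves be lifted again.
Base.zero(d::ToyDual) = ToyDual(zero(d.primal), zero(d.tangent))

# Strategy A: decide from the runtime value, as in the call_frule!! proposal.
wrap_by_value(x) = x isa ToyDual ? x : ToyDual(x, zero(x))

# Strategy B: always wrap; the caller decides per argument from the IR.
wrap_always(x) = ToyDual(x, zero(x))

# A dual that appears as a *primal* value when differentiating a second time.
inner = ToyDual(1.0, 1.0)

left_alone = wrap_by_value(inner)  # stays a first-order dual
lifted = wrap_always(inner)        # dual of a dual, with a zero outer tangent
```

With strategy A the value is passed through untouched, so the outer layer of differentiation never sees it as a constant. Strategy B does the right thing for this value, but would be wrong for an argument that is already lifted, which is why the decision has to come from somewhere other than the value's type.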

@willtebbutt
Collaborator

Very good point.

So I think this wrapping needs to be decided from the type of each argument in the IR?

Agreed. Specifically, I think we need to distinguish between literals / QuoteNodes / GlobalRefs, and Argument / SSAValues?
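One way to picture this distinction is a dispatch on the kind of IR operand rather than on the runtime value. The sketch below is an assumption about how such a check could look; `needs_wrapping` is a hypothetical name, and the actual implementation may classify operands differently.

```julia
# Operands that refer to earlier results are assumed to be dual already.
needs_wrapping(::Core.Argument) = false
needs_wrapping(::Core.SSAValue) = false

# Everything else (literals, QuoteNodes, GlobalRefs, ...) is a constant
# from the point of view of the dual IR and would need lifting.
needs_wrapping(::Any) = true

# Example operands as they might appear in a :call expression.
operands = Any[
    Core.Argument(2),       # function argument
    Core.SSAValue(5),       # result of an earlier statement
    1.0,                    # literal
    QuoteNode(:a),          # quoted value
    GlobalRef(Base, :sin),  # reference to a global
]

flags = map(needs_wrapping, operands)
```

Because the decision is made while transforming the IR, a literal whose value happens to be a dual would still be classified as a constant.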

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

I still need to dig into the different node types we might encounter (and I still don't understand QuoteNodes) but yeah, Argument and SSAValue don't need to be wrapped.

@gdalle gdalle mentioned this pull request Nov 27, 2024
@willtebbutt
Collaborator

I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" won't work for Core.GotoIfNot nodes. See https://compintell.github.io/Mooncake.jl/previews/PR386/developer_documentation/forwards_mode_design/#Statement-Transformation .

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

I think that's okay, the main trouble is adding new lines which insert new variables because it requires manual renumbering. A GoTo should be much simpler.

@willtebbutt
Collaborator

Were the difficulties around renumbering etc not resolved by not compact!ing until the end? I feel like I might be missing something.

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

No they weren't. I experimented with compact! in various places and I was struggling a lot, so I asked Frames for advice. She agreed that insertion should usually be avoided.
If we have to insert something for GoTo, I think it will still be easier because we're not defining a new SSAValue so we don't have to adapt future statements that refer to it.

@willtebbutt
Collaborator

willtebbutt commented Nov 27, 2024

Ah, right, but we do need to insert a new SSAValue. Suppose that the GotoIfNot of interest is

GotoIfNot(%5, #3)

i.e. jump to block 3 if not %5. In the forwards-mode IR this would become

%new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3)

Does this not cause the same kind of problems?
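The transformation described above can be sketched with plain `Core` node constructors. The names `toy_primal`, `new_ssa` and the SSA id `6` are assumptions for illustration only; in real IR the id of the inserted statement is assigned by the compiler's insertion and compaction machinery.

```julia
# Hypothetical accessor that would extract the primal from a dual value.
toy_primal(d) = d.primal

# The original terminator: jump to block 3 if %5 is false.
old_goto = Core.GotoIfNot(Core.SSAValue(5), 3)

# New statement that extracts the primal of the (now dual) condition %5.
extract_stmt = Expr(:call, toy_primal, old_goto.cond)

# Assume the new statement ends up with SSA id 6 once it has been inserted.
new_ssa = Core.SSAValue(6)

# The rewritten terminator branches on the extracted primal instead.
new_goto = Core.GotoIfNot(new_ssa, old_goto.dest)
```

The destination block is unchanged; only the condition is redirected to the new SSA value, which is used in exactly one place.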

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

Oh yes you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after. I'll take a look

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

Do you know what I should do about expressions of type :code_coverage_effect? I assume they're inserted automatically and they're alone on their lines?

@willtebbutt
Collaborator

willtebbutt commented Nov 27, 2024

Yup -- I just strip them out of the IR entirely in reverse-mode. See https://github.com/compintell/Mooncake.jl/blob/0f37c079bd1ae064e7b84696eed4a1f7eb763f1f/src/interpreter/s2s_reverse_mode_ad.jl#L728

The way to remove an instruction from an IRCode is just to replace the instruction with nothing.
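As a rough illustration of that idea, the sketch below works on a plain vector of statements rather than a real `IRCode`, so it makes no claims about the compiler's internal API. `strip_coverage` is a hypothetical helper name.

```julia
# A toy statement list standing in for the statements of an IRCode.
stmts = Any[
    Expr(:code_coverage_effect),
    Expr(:call, :sin, Core.Argument(2)),
    Core.ReturnNode(Core.SSAValue(2)),
]

# Replace coverage markers with `nothing`; leave everything else untouched.
strip_coverage(stmt) = Meta.isexpr(stmt, :code_coverage_effect) ? nothing : stmt

stripped = map(strip_coverage, stmts)
```

Replacing a statement instead of deleting it keeps the statement count the same, so existing SSA references such as `%2` still point at the right line.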

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

I think this works for GotoIfNot:

  1. make all the insertions necessary
  2. compact! once to make sure they applied
  3. shift the conditions of all GotoIfNot nodes to refer to the node right before them (where we get the primal value of the condition)

MWE (requires this branch of Mooncake):

const CC = Core.Compiler
using Mooncake
using MistyClosures

f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)

# Step 1: insert a placeholder statement (standing in for get_primal) before statement 3.
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any)
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)

# Step 2: compact! once so that the insertion is applied and SSA values are renumbered.
ir = CC.compact!(ir)

# Step 3: point each GotoIfNot condition at the statement immediately before it.
for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if inst isa Core.GotoIfNot
        Mooncake.replace_call!(ir, CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k - 1), inst.dest))
    end
end
ir
julia> initial_ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool       │╻╷╷ >
  │   %2 = Base.or_int(%1, false)::Bool       ││╻   <
  └──      goto #3 if not %2                  │
  2 ─ %4 = Base.mul_float(2.0, _2)::Float64   ││╻   *
  └──      return %4
  3 ─ %6 = Base.add_float(3.0, _2)::Float64   ││╻   +
  └──      return %6                          │

julia> ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool       │╻╷╷ >
  │        Base.or_int(%1, false)::Bool       ││╻   <
  │   %3 = (+)(1, 2)::Any                     │
  └──      goto #3 if not %3                  │
  2 ─ %5 = Base.mul_float(2.0, _2)::Float64   ││╻   *
  └──      return %5
  3 ─ %7 = Base.add_float(3.0, _2)::Float64   ││╻   +
  └──      return %7

@willtebbutt
Collaborator

Okay. I think I've now addressed all of @gdalle's feedback, and this PR is in a state that's basically ready to go.

I'm now going to be offline for a couple of weeks, so I'm happy for @gdalle or @yebai to handle any remaining issues and merge if there's a strong need to merge before I'm back.

@gdalle
Collaborator Author

gdalle commented Aug 8, 2025

After I review again, I'd be in favor of merging but explicitly marking this as experimental, so that DI can add the functionality and then people can stress-test it on real problems. This will probably be a much better way to find bugs than refining a 4k-LOC PR.

@yebai
Member

yebai commented Aug 8, 2025

Great work, @willtebbutt! I’m also supportive of merging this PR now as an experimental feature.

@yebai yebai requested a review from Copilot August 8, 2025 21:58
Contributor

Copilot AI left a comment

Pull Request Overview

This pull request implements forward-mode automatic differentiation (AD) support for Mooncake.jl. The implementation provides a complete framework for forward-mode AD that operates alongside the existing reverse-mode capabilities.

Key changes include:

  • Addition of Dual type for forward-mode computations and comprehensive frule!! infrastructure
  • Introduction of mode-specific is_primitive function and unified @zero_derivative/@from_chainrules macros
  • Implementation of forward-mode rules (frule!!) for all existing primitives and mathematical operations

Reviewed Changes

Copilot reviewed 82 out of 83 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| src/tools_for_rules.jl | Adds @zero_derivative and @from_chainrules macros, frule_wrapper functionality, and mooncake_tangent conversion |
| src/test_utils.jl | Updates testing infrastructure to support both forward and reverse mode testing with new test_rule interface |
| src/rrules/*.jl | Implements frule!! methods for all mathematical operations, linear algebra, memory operations, and built-ins |
| test/*.jl | Updates all test files to use new testing interface and mode-specific functionality |

Comments suppressed due to low confidence (3)

src/tools_for_rules.jl:162

  • [nitpick] The explicit return type annotation ::Tuple{Bool,Vector{Symbol}} is unnecessary and adds visual clutter. Julia's type inference can determine this automatically from the function body.
passing all of its arguments (including the function itself) to this function. For example:

src/tools_for_rules.jl:194

  • [nitpick] The explicit return type annotation ::Tuple{Bool,Vector{Symbol}} is redundant since it's already specified in the previous method and Julia can infer this type.
```jldoctest

src/rrules/lapack.jl:97

  • This commented-out code should be removed rather than left as a comment, as it adds clutter and may cause confusion about the intended behavior.
    # Restore initial state.

Comment thread test/rrules/blas.jl
Comment thread src/rrules/foreigncall.jl
@yebai
Member

yebai commented Aug 11, 2025

@sunxd3, can you help take a look at the Turing integration test failures?

yebai and others added 3 commits August 12, 2025 19:18
Signed-off-by: Hong Ge <3279477+yebai@users.noreply.github.com>
@yebai yebai merged commit 83b4ff7 into chalk-lab:main Aug 12, 2025
88 checks passed
Comment thread HISTORY.md
Comment on lines +3 to +4
## Public Interface
- Mooncake offers forward mode AD.
Collaborator

@penelopeysm penelopeysm Aug 12, 2025

That's amazing work, though it's not clear how one can use it? I assume that ADTypes / DifferentiationInterface support will take a bit of time to arrive, but in the meantime do I just replace value_and_gradient!! with value_and_derivative!!?

And congratulations to all involved! 🎉

Collaborator Author

DI support landing today

@gdalle gdalle deleted the gd/forward branch August 13, 2025 15:20
@yebai yebai mentioned this pull request Aug 16, 2025
Technici4n added a commit that referenced this pull request Nov 6, 2025
Seems to have been mistakenly duplicated as part of #389.

Signed-off-by: Bruno Ploumhans <13494793+Technici4n@users.noreply.github.com>
yebai pushed a commit that referenced this pull request Nov 6, 2025
Seems to have been mistakenly duplicated as part of #389.

Signed-off-by: Bruno Ploumhans <13494793+Technici4n@users.noreply.github.com>
penelopeysm pushed a commit that referenced this pull request Nov 10, 2025
Seems to have been mistakenly duplicated as part of #389.

Signed-off-by: Bruno Ploumhans <13494793+Technici4n@users.noreply.github.com>
yebai added a commit that referenced this pull request Mar 25, 2026
* Start forward mode prototype

* First working autodiff

* Docstring

* Apply suggestions from code review

Co-authored-by: Will Tebbutt <willtebbutt00@gmail.com>
Signed-off-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>

* Moving files around

* Primitives already known

* Keep pushing forward (pun intended)

* Still buggy, don't touch

* Keep instruction mapping one to one

* Use replace_call

* Ignore code cov

* No Aqua piracies test

* Start control flow

* Fix intrinsic

* Import

* Typos

* Figure out incremental additions

* Initial test case additions

* Formatting

* Add verify_dual_type

* test_frule_interface runs

* Fix ReturnNode

* Correctness testing runs

* Add randn_dual

* Improve sin and cos frules

* Performance tests run

* Tidy up implementation

* Standard testing infrastructure

* Fix typos

* Fix return node to return dual

* Handle PiNode

* Deleted line

* Case 7 solved

* Fix precompile issue

* Fix isa rule

* Fix is_primitive

* More test cases

* progress

* fixes

* Bump patch vesion

* Fix terminators

* More cases

* More cases

* Tuple rule

* Formatting

* Code to view forwards-mode IR from a signature

* Use widenconst to get actual argtype from ircode argtypes

* MyInstruction -> new_instruction

* Formatting

* Various improvements

* Rules for foreigncalls

* Fix pointer tests with forwards mode

* Enable more tests

* All derivation tests pass

* Initial pass over legacy array functionality

* Fix tangent usage in tests

* Rules for nice BLAS functions

* Tweak test inputs slightly

* Enable CI for BLAS and foreigncalls

* Enable linear_algebra rules

* More stuff works

* Make IdDict work

* Code to identify SSA uses

* Fix failing test via special case

* Remove outdated TODO note

* Fix typo

* BLAS support nearly finished

* All BLAS rules passing

* Initial work on getrf

* getrf frule sketch

* Improve getrf performance

* trtrs implementation + type stability checks

* Type stability checks for BLAS rules

* Note Seth's blog

* getrs frule implementation

* getri frule implementation

* potrs

* Enable lapack CI

* Fix pivoting

* Enable diff tests integration tests

* Only run extra CI on 1

* More lapack fixes

* widenconst

* Replace field access with method call

* Catch __vec_to_tuple edge case

* Display more stuff when correctness test fails

* Enable more integration tests

* Make output on test error sensible

* Tidy up blas implementations

* Fix pointerset error

* Fix ^ rule

* Implement from_chain_rule macro

* Get SpecialFunctions extension working

* Enable SpecialFunctions in CI

* logexpfunctions

* Run gpu jobs on 1.11 only

* Restrict FD step for forward mode

* Enable GP tests

* More integration testing

* bijectors

* Enable battery of tests

* Distributions integration testing

* Enable DI CI

* Enable reverse-mode integration tests for Lux etc

* Enable 1.10

* Fix LAPACK on 1.10

* Implement copytrito for 1.10

* formatting

* Tidying up

* Remove type piracy

* Initial forwards-mode timings

* Constrain JuliaInterpreter

* Basic MistyClosure support

* Do not use MistyClosure internals inside reverse-mode

* Forwards-over-reverse mwe

* Remove overly strict performance check

* Docstring and improved field naming

* Separate forward-mode and reverse-mode primitives

* Fix docs and rrule creation

* Fix low_level_maths

* Fix SpecialFunctions tests cases

* Fix more testing

* Fix formatting

* Make symbols available in tests

* Fix GP test suite

* Fix SpecialFunctions test suite

* Fix performance

* Fix array tests

* Fix formatting

* Fix forward-mode benchmarking

* Fix benchmarking

* forward mode interface

* Add frule for eps

* Merge in main

* Remove redundant ignore

* Rename from_chain_rule macro to from_chainrules

* Finish renaming chainrules macro

* Improve docstring for value_and_derivative

* DRY out global interpreter cache

* Fix typo in docstring

* Doctests for is_primitive macro

* Tidying up

* Fix typo in bijectors

* Fix test_rule call

* Fix formatting

* Fix dispatch doctor in forward mode

* Fix import

* Fix doctests

* Fix BLAS tests

* Fix DispatchDoctor tests

* Fix broken tests

* Include the mode in testset string

* Support try-catch statements

* Fix on LTS

* Support enter expression

* Formatting

* Enable FunctionWrappers

* Enable GPU CI on LTS

* Enable function wrappers in CI

* Bump patch version and create HISTORY

* Fix typo in prepare_derivative_cache docstring

* Note new components of public interface

* Add docstring to Dual

* Add value_and_derivative and preparation function to exports / public interface

* Improve interface docstring

* Remove comment about which we have an open issue

* Typo

* Improve documentation

* Rename const_dual to make clear that it mutates

* Clarify use of insert_node

* Improve docstring

* Tie todo note to github issue

* Apply suggestions from code review

Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
Signed-off-by: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com>

* Make use of inc_args for PiNode in reverse mode

* Nospecialise on rules and remove redundant comment

* More avoidance of specialisation

* Remove errant nospecialize and directly test rule caching

* Add not deepcopying behaviour in test

* Improve zero_derivative implementation

* Tidy up zero_adjoint and add deprecated file

* Rename some files

* Fix caching test bug

* Fix avoiding non diff code

* Tidy up from_chainrules

* Update HISTORY

* Update history

* Test interface kwargs

* Refine forward-over-reverse test

* Include MistyClosures in CI

* Remove incorrect test

* Ensure that rrule for MistyClosure errors loudly

* Formatting

* Fix flux integration tests

* Formatting

* Fix mode access

* Remove unnecessary ReverseMode import

* Fix macro bug

* Formatting

* Update misty closure set_to_zero implementation

* Update Project.toml

* typofix

---------

Signed-off-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
Signed-off-by: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com>
Signed-off-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Co-authored-by: Will Tebbutt <willtebbutt00@gmail.com>
Co-authored-by: willtebbutt <wtebbutt@turing.ac.uk>
Co-authored-by: Will Tebbutt <wct23@cam.ac.uk>
Co-authored-by: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com>
Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Co-authored-by: Hong Ge <hg344@cam.ac.uk>
yebai pushed a commit that referenced this pull request Mar 25, 2026
Seems to have been mistakenly duplicated as part of #389.

Signed-off-by: Bruno Ploumhans <13494793+Technici4n@users.noreply.github.com>
8 participants