Skip to content

Add ForwardModeSplit support#3024

Draft
vchuravy wants to merge 1 commit intomainfrom
vc/fwdsplit
Draft

Add ForwardModeSplit support#3024
vchuravy wants to merge 1 commit intomainfrom
vc/fwdsplit

Conversation

@vchuravy
Copy link
Copy Markdown
Member

Implements DEM_ForwardModeSplit (value 4 in CApi.h) which splits forward
mode AD into two separate LLVM functions — an augmented forward pass
(runs primal, stores tape) and a forward derivative pass (takes tape,
returns shadow) — mirroring ReverseModeSplit for the reverse case.

Changes:

  • EnzymeCore: add ForwardModeSplit mode type, ForwardSplitNoPrimal /
    ForwardSplitWithPrimal constants, and ForwardSplitWidth/Modified helpers
  • api.jl: DEM_ForwardModeSplit = 4, extend EnzymeCreateForwardDiff with
    aug parameter for split mode
  • compiler.jl: ForwardModeSplitThunk struct; thunkbase branch calling
    EnzymeCreateAugmentedPrimal then EnzymeCreateForwardDiff(split);
    ABI wrapper and enzyme_call flags; fix AugmentedForwardThunk rettype to
    Const{actualRetType} so the primal-only augmented forward wrapper does
    not attempt to extract a shadow
  • interpreter.jl, reflection.jl: treat DEM_ForwardModeSplit like
    DEM_ForwardMode (forward rules, Forward dispatch)
  • Enzyme.jl: autodiff_thunk(::ForwardModeSplit, ...) returning
    (AugmentedForwardThunk, ForwardModeSplitThunk)
  • errors.jl, customrules.jl, parallelrules.jl, activity.jl: extend all
    DEM_ForwardMode-only checks to also cover DEM_ForwardModeSplit
  • test/forwardmodesplit.jl: new test file (42 tests)
  • test/tests.jl: internal thunk tests for DEM_ForwardModeSplit

Co-Authored-By: Claude Sonnet 4.6 noreply@anthropic.com

Comment thread src/api.jl
additionalArg,
typeInfo,
uncacheable_args,
aug = C_NULL,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a jll change for this?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would there be? The C-API has already accepted this parameter and we are just exposing it to the user here?

Implements DEM_ForwardModeSplit (value 4 in CApi.h) which splits forward
mode AD into two separate LLVM functions — an augmented forward pass
(runs primal, stores tape) and a forward derivative pass (takes tape,
returns shadow) — mirroring ReverseModeSplit for the reverse case.

Changes:
- EnzymeCore: add ForwardModeSplit mode type, ForwardSplitNoPrimal /
  ForwardSplitWithPrimal constants, and ForwardSplitWidth/Modified helpers
- api.jl: DEM_ForwardModeSplit = 4, extend EnzymeCreateForwardDiff with
  aug parameter for split mode
- compiler.jl: ForwardModeSplitThunk struct; thunkbase branch calling
  EnzymeCreateAugmentedPrimal then EnzymeCreateForwardDiff(split);
  ABI wrapper and enzyme_call flags; fix AugmentedForwardThunk rettype to
  Const{actualRetType} so the primal-only augmented forward wrapper does
  not attempt to extract a shadow
- interpreter.jl, reflection.jl: treat DEM_ForwardModeSplit like
  DEM_ForwardMode (forward rules, Forward dispatch)
- Enzyme.jl: autodiff_thunk(::ForwardModeSplit, ...) returning
  (AugmentedForwardThunk, ForwardModeSplitThunk)
- errors.jl, customrules.jl, parallelrules.jl, activity.jl: extend all
  DEM_ForwardMode-only checks to also cover DEM_ForwardModeSplit
- test/forwardmodesplit.jl: new test file (42 tests)
- test/tests.jl: internal thunk tests for DEM_ForwardModeSplit

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants