feat: setup_context for FlashAttnFunc (torch.func.grad) #2405
NJX-njx wants to merge 2 commits into Dao-AILab:main
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 8e84a94d6b
def setup_context(ctx, inputs, output):
    if not isinstance(output, tuple) or len(output) != 5:
Switch FlashAttnFunc to separate forward/setup_context API
setup_context was added, but FlashAttnFunc.forward is still defined in the combined form (forward(ctx, ...)). In PyTorch’s torch.func integration, custom autograd Functions must use the separate form (forward(*args) + setup_context) to participate in transforms; with this mixed definition, torch.func.grad/vjp on flash_attn_func will still fail instead of using the new context path. Please convert forward to the separate signature and keep all ctx writes in setup_context.
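The separate form the review asks for can be illustrated with a minimal, self-contained sketch (a toy `Square` function standing in for `FlashAttnFunc`; names here are illustrative, not from the PR): `forward` takes only the inputs, and all `ctx` writes live in `setup_context`, which is what lets `torch.func.grad` trace the function.

```python
import torch

class Square(torch.autograd.Function):
    # Separate form: forward takes only the tensor args, no ctx.
    @staticmethod
    def forward(x):
        return x * x

    # All ctx writes happen here so torch.func transforms can use forward.
    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.tensor(3.0)
g = torch.func.grad(lambda t: Square.apply(t))(x)
print(g)  # tensor(6.)
```

With the combined `forward(ctx, ...)` signature, the same `torch.func.grad` call raises an error instead of dispatching through `setup_context`.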
Thanks for the review. I've removed `ctx` from `FlashAttnFunc.forward` so it uses the separate `forward(*args)` + `setup_context` form required for `torch.func` transforms (commit c1ff5c5 on this branch).
PyTorch functorch requires forward(*args) + setup_context; the legacy combined forward(ctx, ...) form is not supported. Addresses Codex review on PR Dao-AILab#2405. Made-with: Cursor
- Move save_for_backward to setup_context for functorch/torch.func.grad compatibility.
- When grad is enabled, return auxiliary outputs and mark non-differentiable tensors; clone the trimmed output so it is not a view of out_padded.
- Unpack the flash_attn_func return when the 5-tuple grad path is used.
- Add tests/test_flash_attn_functorch.py (skipped if the CUDA extension is missing).
Made-with: Cursor
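The "mark non-differentiable tensors" step above can be sketched in isolation with a toy function (names like `ScaleWithStats` are hypothetical, standing in for the auxiliary-output path of `FlashAttnFunc`): auxiliary outputs are returned from `forward` but excluded from gradient tracking inside `setup_context`.

```python
import torch

class ScaleWithStats(torch.autograd.Function):
    """Returns the scaled tensor plus an auxiliary statistic (non-differentiable)."""

    @staticmethod
    def forward(x, scale):
        out = x * scale
        return out, out.max()

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, scale = inputs
        out, stats = output
        ctx.scale = scale
        # Auxiliary outputs must be excluded from gradient tracking.
        ctx.mark_non_differentiable(stats)

    @staticmethod
    def backward(ctx, grad_out, grad_stats):
        # grad_stats corresponds to the non-differentiable output and is ignored.
        return grad_out * ctx.scale, None

x = torch.arange(3.0, requires_grad=True)
out, stats = ScaleWithStats.apply(x, 2.0)
out.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```

In the PR itself, the auxiliary outputs on the 5-tuple grad path (e.g. the softmax LSE) would play the role of `stats` here.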
Force-pushed c1ff5c5 to e12235f
Rebased onto main and removed unrelated files (Chinese documentation and hopper/flash.h comments) that were accidentally included. The Codex review about the separate forward/setup_context API was already addressed in the latest commit.
Summary
Implements setup_context on FlashAttnFunc so torch.func.grad / functorch transforms work (#2071), per maintainer guidance.
Testing
Fixes #2071