
feat: setup_context for FlashAttnFunc (torch.func.grad)#2405

Open
NJX-njx wants to merge 2 commits into Dao-AILab:main from NJX-njx:pr/functorch-setup-context

Conversation


@NJX-njx NJX-njx commented Mar 28, 2026

Summary

Implements setup_context on FlashAttnFunc so torch.func.grad / functorch transforms work (#2071), per maintainer guidance.

  • When gradients are enabled, forward returns extra tensors for context; non-differentiable outputs are marked accordingly.
  • The user-visible output is cloned when saving for backward so it is not a view of out_padded (which is marked non-differentiable).
  • flash_attn_func unpacks the 5-tuple training path so the public API is unchanged.
  • Adds tests/test_flash_attn_functorch.py (skipped if flash_attn_2_cuda is not built).
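For readers unfamiliar with the separate-form autograd API this PR adopts, here is a minimal sketch on a toy function (not FlashAttnFunc itself): forward takes only the inputs, and all ctx writes move into setup_context, which is what lets torch.func trace the forward pass.

```python
import torch

class Square(torch.autograd.Function):
    """Toy Function using the separate-form API that torch.func requires."""

    @staticmethod
    def forward(x):
        # Separate form: forward receives only the inputs, no ctx.
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        # All ctx writes (save_for_backward, mark_non_differentiable, ...)
        # live here, never in forward.
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

# torch.func.grad can now differentiate through the custom Function.
g = torch.func.grad(lambda x: Square.apply(x).sum())(torch.tensor([3.0]))
print(g)  # tensor([6.])
```

With the legacy combined `forward(ctx, ...)` form, the same `torch.func.grad` call raises instead of returning a gradient.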

Testing

  • pytest tests/test_flash_attn_functorch.py (skipped here without CUDA extension build).
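The test file itself is not reproduced in this view; as a hedged sketch of the kind of parity check such a test typically performs, one can compare torch.func.grad against the classic backward pass. A plain-PyTorch attention is used here as a stand-in for flash_attn_func, since the flash_attn_2_cuda extension may not be built.

```python
import torch

def attn(q, k, v):
    # Plain scaled-dot-product attention, standing in for flash_attn_func.
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 8)
v = torch.randn(2, 4, 8)

# Gradient via torch.func (the path this PR enables for FlashAttnFunc).
g_func = torch.func.grad(lambda q: attn(q, k, v).sum())(q)

# Reference gradient via the classic autograd backward pass.
q_ref = q.clone().requires_grad_(True)
attn(q_ref, k, v).sum().backward()

assert torch.allclose(g_func, q_ref.grad, atol=1e-5)
```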

Fixes #2071


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 8e84a94d6b


Comment on lines +875 to +876
def setup_context(ctx, inputs, output):
if not isinstance(output, tuple) or len(output) != 5:

P1: Switch FlashAttnFunc to the separate forward/setup_context API

setup_context was added, but FlashAttnFunc.forward is still defined in the combined form (forward(ctx, ...)). In PyTorch’s torch.func integration, custom autograd Functions must use the separate form (forward(*args) + setup_context) to participate in transforms; with this mixed definition, torch.func.grad/vjp on flash_attn_func will still fail instead of using the new context path. Please convert forward to the separate signature and keep all ctx writes in setup_context.



NJX-njx commented Mar 28, 2026

Thanks for the review. I've removed `ctx` from `FlashAttnFunc.forward` so it uses the separate `forward(*args)` + `setup_context` form required for `torch.func` transforms (commit c1ff5c5 on this branch).

NJX-njx added a commit to NJX-njx/flash-attention that referenced this pull request Mar 28, 2026
PyTorch functorch requires forward(*args) + setup_context without the legacy
forward(ctx, ...) combined form. Addresses Codex review on PR Dao-AILab#2405.

Made-with: Cursor
NJX-njx added 2 commits March 29, 2026 15:34
- Move save_for_backward to setup_context for functorch/torch.func.grad compatibility.
- When grad is enabled, return auxiliary outputs and mark non-differentiable tensors;
  clone trimmed output so it is not a view of out_padded.
- Unpack flash_attn_func return when the 5-tuple grad path is used.
- Add tests/test_flash_attn_functorch.py (skipped if CUDA ext missing).

Made-with: Cursor
PyTorch functorch requires forward(*args) + setup_context without the legacy
forward(ctx, ...) combined form. Addresses Codex review on PR Dao-AILab#2405.

Made-with: Cursor
NJX-njx force-pushed the pr/functorch-setup-context branch from c1ff5c5 to e12235f on March 29, 2026 at 07:35

NJX-njx commented Mar 29, 2026

Rebased onto main and removed unrelated files (Chinese documentation and hopper/flash.h comments) that were accidentally included. The Codex review about the separate forward/setup_context API was already addressed in the latest commit.


Development

Successfully merging this pull request may close these issues.

setup_context method required for FA autograd Function for working with torch.func.grad
