fix(flash_fwd_sm90): zero partial V smem to prevent 0*NaN=NaN in PV GEMM#2407
Open

NJX-njx wants to merge 11 commits into Dao-AILab:main from
Conversation
…missing

- Pass /Zc:preprocessor to cl via nvcc when the CUDA toolkit is 13.0+ on Windows (CCCL requires a conforming preprocessor; fixes Dao-AILab#2395).
- Emit a UserWarning when the ninja package is not installed (Dao-AILab#2165).

Made-with: Cursor
- FlashAttentionForwardSm80: disable the TMA output store; the runtime arch can be sm_120 while the kernel uses the cpasync epilogue (avoids tma_atom_O=None).
- _flash_attn_bwd SM120: set dQ_single_wg and num_stages_PdS for compile_key.
- atomic_add_fp32: use red.global.add.f32 inline asm instead of nvvm.atomicrmw for binding compatibility.

Made-with: Cursor
- Move save_for_backward to setup_context for functorch / torch.func.grad compatibility.
- When grad is enabled, return auxiliary outputs and mark non-differentiable tensors; clone the trimmed output so it is not a view of out_padded.
- Unpack the flash_attn_func return when the 5-tuple grad path is used.
- Add tests/test_flash_attn_functorch.py (skipped if the CUDA extension is missing).

Made-with: Cursor
…EMM (Dao-AILab#2374)

TMA loads a full tile_n block regardless of seqlen_k. When seqlen_k is not a tile_n multiple, rows [seqlen_k % tile_n, tile_n) of the last V block may contain NaN from an uninitialized KV cache (e.g. seqused_k < cache_size). Masking sets P = 0 for those positions, but WGMMA still computes 0 * NaN = NaN (IEEE 754), propagating NaN to the output.

Fix: after pipeline_v.consumer_wait() and before mma_pv_fn(), each thread in the MMA warpgroup zeroes its assigned rows in V smem. The wgmma.fence.before_group.sync.aligned inside mma_pv_fn ensures those stores are visible to WGMMA without requiring an additional barrier.

Three code paths fixed:
- mma_one_n_block (non-overlap): zeroes when is_first_n_block=True
- mma_one_n_block_intrawg_overlap: zeroes V for block (n_block+1) when seqlen_in_block_v < tile_n (runtime guard, only fires for the last V block)
- last_half_block_overlap: zeroes when seqlen_in_block_v < tile_n (single-block case or last block in the overlap path)

SM80/SM120 are unaffected: cp.async with pred=False already zero-fills. SM100 (Blackwell) uses a separate kernel and is not changed here.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
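The 0 * NaN = NaN behavior the commit describes is plain IEEE 754 multiplication semantics and can be reproduced outside the kernel. A minimal NumPy sketch (not code from this PR) modeling one row of the PV GEMM:

```python
import numpy as np

# Simulate one dot-product row of the PV GEMM: P has been masked to zero
# for key positions past seqlen_k, but V still holds garbage (NaN) there.
tile_n = 8
seqlen_k = 5  # not a tile_n multiple -> rows [5, 8) are uninitialized

p = np.zeros(tile_n, dtype=np.float32)
p[:seqlen_k] = 0.25  # valid (post-mask) probabilities
v = np.full(tile_n, np.nan, dtype=np.float32)
v[:seqlen_k] = 1.0   # only the valid rows were ever written

# Masking P alone does not help: 0 * NaN is NaN, so the sum is NaN.
assert np.isnan(np.dot(p, v))

# Zeroing the uninitialized V rows (the approach in this PR) fixes it.
v[seqlen_k:] = 0.0
assert np.isfinite(np.dot(p, v))
```

This is why the fix targets V smem rather than the already-masked P tile.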
…nts (Dao-AILab#2183)

Move benchmark_block_sparsity.py and benchmark_mask_mod.py from tests/cute/ to benchmarks/cute/ where they belong. Update their sys.path to locate mask_mod_definitions.py, which remains in tests/cute/ (shared with the tests).

Replace verbose print() statements in test_block_sparsity.py with logging.debug() so they are suppressed by default and only visible with pytest --log-cli-level=DEBUG or similar.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…test prints (Dao-AILab#2183)"

This reverts commit 52e4bf9.
Summary

TMA loads a full tile_n-row V block even when seqlen_k is not a tile_n multiple. Rows [seqlen_k % tile_n, tile_n) of the last V block may contain NaN from an uninitialized KV cache (e.g. seqused_k < cache_size). After masking, those P values are zero, but WGMMA still computes 0 × NaN = NaN (IEEE 754), propagating NaN to the output.

Fix: after pipeline_v.consumer_wait() and before mma_pv_fn(), each MMA-warpgroup thread zeros its assigned rows in V smem for the last (partial) V block. The wgmma.fence.before_group.sync.aligned inside mma_pv_fn ensures those stores are visible to WGMMA; no additional barrier is needed.

SM80/SM120 are unaffected thanks to cp.async predicate zero-fill. SM100 (Blackwell) uses a separate kernel and is not affected.

Three code paths fixed

- mma_one_n_block (non-overlap): is_first_n_block=True (compile-time) + seqlen_in_block < tile_n (runtime)
- mma_one_n_block_intrawg_overlap: block n_block+1 is partial: seqlen_k - (n_block+1)*tile_n < tile_n (runtime, only fires for the last V block)
- last_half_block_overlap: seqlen_in_block_v < tile_n (covers the single-block case)

Known limitations

- The block-sparse path (consume_block_sparse_loads) does not pass seqlen_in_block_v to last_half_block_overlap, so the NaN fix is not applied there. This can be addressed in a follow-up PR.

Test plan

- Reproduce #2374 (NaN output with seqused_k): flash_attn_func(q, k, v, seqused_k=...) output is NaN-free after the fix
- Run pytest tests/cute/test_flash_attn.py and tests/cute/test_flash_attn_varlen.py to check for no regressions

Fixes #2374
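The runtime guard and row range described in the summary can be modeled on the host. A NumPy sketch of the zeroing logic, assuming a hypothetical helper name (`zero_partial_v_rows`); the real fix lives in the cute DSL kernel and operates on V smem per thread:

```python
import numpy as np

def zero_partial_v_rows(v_smem: np.ndarray, seqlen_k: int, n_block: int,
                        tile_n: int) -> None:
    """Zero rows of a V tile that lie past seqlen_k.

    Host-side model of the guard in this PR: only the last V block can be
    partial, and only its rows [seqlen_k % tile_n, tile_n) are zeroed.
    """
    seqlen_in_block = seqlen_k - n_block * tile_n
    if 0 < seqlen_in_block < tile_n:  # runtime guard: fires only for a partial last block
        v_smem[seqlen_in_block:, :] = 0.0

# Example: tile_n=4, seqlen_k=6 -> block 1 covers keys [4, 8); only
# rows [0, 2) of that tile are valid, rows [2, 4) hold NaN garbage.
tile_n, seqlen_k, head_dim = 4, 6, 8
v_tile = np.full((tile_n, head_dim), np.nan, dtype=np.float32)
v_tile[: seqlen_k - tile_n] = 1.0  # valid rows loaded by TMA

zero_partial_v_rows(v_tile, seqlen_k, n_block=1, tile_n=tile_n)
assert np.isfinite(v_tile).all()
```

For a full (non-partial) block, `seqlen_in_block >= tile_n` and the guard is a no-op, matching the claim that only the last V block is touched.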
🤖 Generated with Claude Code