
fix(flash_fwd_sm90): zero partial V smem to prevent 0*NaN=NaN in PV GEMM#2407

Open
NJX-njx wants to merge 11 commits into Dao-AILab:main from NJX-njx:claude/agitated-franklin

Conversation

NJX-njx (Contributor) commented Mar 29, 2026

Summary

  • Root cause: TMA loads a full tile_n-row V block even when seqlen_k is not a tile_n multiple. Rows [seqlen_k % tile_n, tile_n) of the last V block may contain NaN from an uninitialized KV cache (e.g. seqused_k < cache_size). After masking, those P values are zero, but WGMMA still computes 0 × NaN = NaN (IEEE 754), propagating NaN to the output.
  • Fix: After pipeline_v.consumer_wait() and before mma_pv_fn(), each MMA-warpgroup thread zeros its assigned rows in V smem for the last (partial) V block. The wgmma.fence.before_group.sync.aligned inside mma_pv_fn ensures those stores are visible to WGMMA — no additional barrier needed.
  • Scope: SM90 (Hopper) only. SM80/SM120 already handle this correctly via cp.async predicate zero-fill. SM100 (Blackwell) uses a separate kernel and is not affected.
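The failure mode is easy to reproduce on the CPU. A minimal NumPy sketch (illustrative only; `tile_n`, `seqlen_k`, and the shapes here are toy values, not the kernel's): masking zeroes P, but the PV product still reads the NaN rows of V, and 0 × NaN = NaN under IEEE 754, so only zeroing V itself helps.

```python
import numpy as np

# Toy PV GEMM: the last V tile is partial, and the rows beyond seqlen_k
# hold NaN garbage from an uninitialized KV cache.
tile_n, head_dim = 4, 2
seqlen_k = 3                   # rows [3, 4) of the V tile are "uninitialized"

V = np.ones((tile_n, head_dim), dtype=np.float32)
V[seqlen_k:] = np.nan          # garbage beyond seqused_k

P = np.full((1, tile_n), 0.25, dtype=np.float32)
P[:, seqlen_k:] = 0.0          # masking zeroes the probabilities...

out_buggy = P @ V              # ...but 0 * NaN = NaN still poisons the GEMM
assert np.isnan(out_buggy).all()

V[seqlen_k:] = 0.0             # the fix: zero the partial V rows before the GEMM
out_fixed = P @ V
assert np.allclose(out_fixed, 0.75)
```

This is why the fix targets the V smem rows rather than the masked P values: the mask is already correct, but multiplication cannot launder a NaN operand.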

Three code paths fixed

| Path | Condition |
| --- | --- |
| `mma_one_n_block` (non-overlap) | `is_first_n_block=True` (compile-time) and `seqlen_in_block < tile_n` (runtime) |
| `mma_one_n_block_intrawg_overlap` | V block `n_block+1` is partial: `seqlen_k - (n_block+1)*tile_n < tile_n` (runtime; only fires for the last V block) |
| `last_half_block_overlap` | `seqlen_in_block_v < tile_n` (covers the single-block case) |
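In all three paths the zeroing itself is a strided loop over rows. A plain-Python sketch (illustrative only — the actual code is CuTe DSL executing on device, and the thread/row assignment there may differ) of how 128 MMA-warpgroup threads could split the partial rows of a V tile:

```python
import numpy as np

NUM_THREADS = 128  # one warpgroup

def zero_partial_v_rows(smem_v: np.ndarray, seqlen_in_block: int, tidx: int) -> None:
    """Thread tidx zeroes rows r with r % NUM_THREADS == tidx and r >= seqlen_in_block."""
    tile_n = smem_v.shape[0]
    if tidx >= seqlen_in_block:
        start = tidx
    else:
        # smallest row >= seqlen_in_block congruent to tidx mod NUM_THREADS
        steps = (seqlen_in_block - tidx + NUM_THREADS - 1) // NUM_THREADS
        start = tidx + steps * NUM_THREADS
    for row in range(start, tile_n, NUM_THREADS):
        smem_v[row, :] = 0.0

# Simulate a partial last V tile: valid rows loaded by TMA, the rest NaN.
tile_n, head_dim, seqlen_in_block = 192, 64, 100
smem_v = np.full((tile_n, head_dim), np.nan, dtype=np.float32)
smem_v[:seqlen_in_block] = 1.0
for tidx in range(NUM_THREADS):          # every warpgroup thread runs the loop
    zero_partial_v_rows(smem_v, seqlen_in_block, tidx)

assert not np.isnan(smem_v).any()
assert (smem_v[:seqlen_in_block] == 1.0).all()
assert (smem_v[seqlen_in_block:] == 0.0).all()
```

The valid rows are left untouched, so the stores are cheap and need no extra synchronization beyond the `wgmma.fence` already issued inside `mma_pv_fn`.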

Known limitations

  • Block-sparse path (consume_block_sparse_loads) does not pass seqlen_in_block_v to last_half_block_overlap, so the NaN fix is not applied there. This can be addressed in a follow-up PR.
  • Fix is not tested on SM90 hardware (only SM120 available). The logic is analogous to the existing SM80 predicate zero-fill which is confirmed correct.

Test plan

  • Reproduce with the script from issue #2374 ("FA3/4 does not guard against NaN in unused V cache entries"); requires an SM90 / Hopper GPU with a NaN-initialized KV cache and seqused_k
  • Confirm flash_attn_func(q, k, v, seqused_k=...) output is NaN-free after fix
  • Run pytest tests/cute/test_flash_attn.py and tests/cute/test_flash_attn_varlen.py to confirm there are no regressions
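A CPU reference (assumed semantics for illustration, not the actual `flash_attn_func` API) makes the property being tested concrete: with a NaN-initialized cache and `seqused_k < cache_size`, attention that never reads the unused K/V rows must produce a NaN-free output.

```python
import numpy as np

def ref_attention(q, k, v, seqused_k):
    """Reference attention over only the first seqused_k cache rows."""
    s = q @ k[:seqused_k].T                      # scores over valid keys only
    s = s - s.max(axis=-1, keepdims=True)        # numerically stable softmax
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v[:seqused_k]                     # unused cache rows never read

rng = np.random.default_rng(0)
cache_size, seqused_k, seqlen_q, head_dim = 8, 5, 2, 4
q = rng.standard_normal((seqlen_q, head_dim)).astype(np.float32)
k = np.full((cache_size, head_dim), np.nan, dtype=np.float32)  # NaN-initialized cache
v = np.full((cache_size, head_dim), np.nan, dtype=np.float32)
k[:seqused_k] = rng.standard_normal((seqused_k, head_dim))
v[:seqused_k] = rng.standard_normal((seqused_k, head_dim))

out = ref_attention(q, k, v, seqused_k)
assert not np.isnan(out).any()   # the property the test plan verifies on SM90
```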

Fixes #2374

🤖 Generated with Claude Code

NJX-njx and others added 11 commits March 9, 2026 08:05
…missing

- Pass /Zc:preprocessor to cl via nvcc when CUDA toolkit is 13.0+ on Windows
  (CCCL requires conforming preprocessor; fixes Dao-AILab#2395).
- Emit a UserWarning when the ninja package is not installed (Dao-AILab#2165).

Made-with: Cursor
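A hedged sketch of the two changes described in this commit (illustrative only — the function names and the shape of the real setup.py code are assumptions, though /Zc:preprocessor and nvcc's -Xcompiler pass-through are real MSVC/nvcc flags):

```python
import sys
import warnings

def extra_nvcc_flags(cuda_version: tuple, platform: str = sys.platform) -> list:
    """Extra host-compiler flags to pass through nvcc (hypothetical helper)."""
    flags = []
    if platform == "win32" and cuda_version >= (13, 0):
        # CCCL in CUDA 13+ requires MSVC's conforming preprocessor.
        flags += ["-Xcompiler", "/Zc:preprocessor"]
    return flags

def warn_if_ninja_missing():
    try:
        import ninja  # noqa: F401
    except ImportError:
        warnings.warn("ninja is not installed; the build will fall back to a "
                      "much slower path. Install it with `pip install ninja`.")

assert extra_nvcc_flags((13, 0), "win32") == ["-Xcompiler", "/Zc:preprocessor"]
assert extra_nvcc_flags((12, 8), "win32") == []
assert extra_nvcc_flags((13, 1), "linux") == []
```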
- FlashAttentionForwardSm80: disable TMA output store; runtime arch can be sm_120
  while the kernel uses cpasync epilogue (avoids tma_atom_O=None).
- _flash_attn_bwd SM120: set dQ_single_wg and num_stages_PdS for compile_key.
- atomic_add_fp32: use red.global.add.f32 inline asm instead of nvvm.atomicrmw
  for binding compatibility.

Made-with: Cursor
- Move save_for_backward to setup_context for functorch/torch.func.grad compatibility.
- When grad is enabled, return auxiliary outputs and mark non-differentiable tensors;
  clone trimmed output so it is not a view of out_padded.
- Unpack flash_attn_func return when the 5-tuple grad path is used.
- Add tests/test_flash_attn_functorch.py (skipped if CUDA ext missing).

Made-with: Cursor
…EMM (Dao-AILab#2374)

TMA loads a full tile_n block regardless of seqlen_k. When seqlen_k is not
a tile_n multiple, rows [seqlen_k%tile_n, tile_n) of the last V block may
contain NaN from an uninitialized KV cache (e.g. seqused_k < cache_size).
Masking sets P=0 for those positions, but WGMMA still computes 0*NaN=NaN
(IEEE 754), propagating NaN to the output.

Fix: after pipeline_v.consumer_wait() and before mma_pv_fn(), each thread
in the MMA warpgroup zeroes its assigned rows in V smem. The
wgmma.fence.before_group.sync.aligned inside mma_pv_fn ensures those
stores are visible to WGMMA without requiring an additional barrier.

Three code paths fixed:
- mma_one_n_block (non-overlap): zeroes when is_first_n_block=True
- mma_one_n_block_intrawg_overlap: zeroes V for block (n_block+1) when
  seqlen_in_block_v < tile_n (runtime guard, only fires for last V block)
- last_half_block_overlap: zeroes when seqlen_in_block_v < tile_n
  (single-block case or last block in overlap path)

SM80/SM120 are unaffected: cp.async with pred=False already zero-fills.
SM100 (Blackwell) uses a separate kernel and is not changed here.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…nts (Dao-AILab#2183)

Move benchmark_block_sparsity.py and benchmark_mask_mod.py from tests/cute/
to benchmarks/cute/ where they belong. Update their sys.path to locate
mask_mod_definitions.py which remains in tests/cute/ (shared with tests).

Replace verbose print() statements in test_block_sparsity.py with
logging.debug() so they are suppressed by default and only visible with
pytest --log-cli-level=DEBUG or similar.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
