[FA4][CuTe DSL] Add head_dim=256 support with exp2 FMA emulation optimization (forward + backward, SM100)#2454
Draft
Johnsonms wants to merge 6 commits into Dao-AILab:main from
Conversation
Migrate the hd256 2CTA kernels (fwd, dQ, dKdV) to use upstream's CLC pipeline management abstractions while preserving the existing tile coordinate format and grid layout.

CLC pipeline management:
- Replace manual PipelineClcFetchAsync consumer_wait/release/advance and producer_acquire/get_barrier/issue_clc_query/advance calls with upstream's ClcState wrapper (prefetch_next_work, consumer_wait, consumer_release, producer_tail).
- Use upstream's ClcDynamicPersistentTileScheduler as the hardware scheduler inside ClcState.

Unified scheduler interface:
- Both Sm100FmhaStaticTileScheduler and Sm100FmhaClcDynamicTileScheduler now expose the same protocol: advance_to_next_work() returns WorkTileInfo; prefetch_next_work() and producer_tail() are no-ops for the static scheduler.
- Consumer warps use a single advance_to_next_work() call (no if/else).
- The scheduler warp uses prefetch_next_work() + advance_to_next_work() + producer_tail(), matching the upstream clc_scheduler_warp pattern.

API/config alignment:
- Rename use_clc_dynamic_scheduler -> use_clc_scheduler across all hd256 kernel constructors and interface.py.
- Unify the env var: the interface.py hd256 path now reads FA_CLC (upstream default, off) instead of the old USE_CLC (default on).
- Add Sm100FmhaClcDynamicTileSchedulerParams.clc_hw_params() to encapsulate ClcDynamicPersistentTileSchedulerParams construction.
- Remove thin wrapper functions (create_sm100_fmha_static_tile_scheduler, create_sm100_fmha_clc_dynamic_tile_scheduler, etc.); callers use constructors and .create() directly.

Unchanged:
- Static scheduling path (grid shape, coordinate format, tile scheduler).
- CLC tile coordinate format: (m, 0, (bid, hid)) with packed L=B*H.
- CLC grid shape: round_up((M, 1, B*H), cluster_shape).
- Kernel warp structure, register allocation, tensor indexing.

Tests: 434 passed / 90 skipped for both the static and FA_CLC=1 paths.
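A minimal host-side Python sketch of the unified scheduler protocol described above. This is illustrative only: the real Sm100FmhaStaticTileScheduler / Sm100FmhaClcDynamicTileScheduler are CuTe DSL device-side classes, and the WorkTileInfo field names, validity handling, and linear tile order here are assumptions.

```python
# Hypothetical sketch of the unified scheduler protocol (illustrative only;
# field names, validity handling, and tile order are assumptions).

class WorkTileInfo:
    """Tile coordinate in the (m, 0, (bid, hid)) format with packed L = B*H."""
    def __init__(self, m, bid, hid, is_valid):
        self.coord = (m, 0, (bid, hid))
        self.is_valid = is_valid


class StaticTileSchedulerSketch:
    """Static path: prefetch_next_work()/producer_tail() are no-ops;
    advance_to_next_work() walks the grid in a fixed order."""
    def __init__(self, num_m_tiles, batch, heads):
        self._tiles = [(m, b, h)
                       for b in range(batch)
                       for h in range(heads)
                       for m in range(num_m_tiles)]
        self._idx = 0

    def prefetch_next_work(self):
        pass  # no-op for the static scheduler

    def producer_tail(self):
        pass  # no-op for the static scheduler

    def advance_to_next_work(self):
        if self._idx == len(self._tiles):
            return WorkTileInfo(0, 0, 0, is_valid=False)
        m, b, h = self._tiles[self._idx]
        self._idx += 1
        return WorkTileInfo(m, b, h, is_valid=True)


# Consumer warps need only the single advance_to_next_work() call (no
# if/else); the scheduler warp additionally brackets it with the (no-op
# here) prefetch_next_work() and producer_tail() calls.
sched = StaticTileSchedulerSketch(num_m_tiles=4, batch=2, heads=2)
sched.prefetch_next_work()
work = sched.advance_to_next_work()
while work.is_valid:
    work = sched.advance_to_next_work()
sched.producer_tail()
```

Because both schedulers expose this same protocol, the consumer-warp loop above is identical for the static and CLC dynamic paths.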
Merged into existing flash attn test files. Added skips for unsupported head_dim=256 kernels. All tests passing.
… speedup)

Replace a fraction of hardware exp2 (SFU) instructions with a polynomial FMA emulation (ex2_emulation_2) in the softmax P-tile computation. The key insight: SM100's SFU throughput is a bottleneck for hdim=256 due to the large tile size. By substituting 3 out of every 4 exp2 calls (ex2_emu_freq=4, ex2_emu_res=3) with a packed FMA polynomial approximation, we shift pressure onto the underutilized FMA pipeline. Additionally, the P write slot acquisition is moved earlier to overlap any pipeline stall with the exp2 compute.

Benchmark (B200, bf16, hdim=256, 8 Q-heads, batch ~32k tokens, avg of 9 runs, locked clocks @ 1965 MHz):

FWD Non-Causal (TFLOPS):
seqlen :   1k    2k    4k    8k   16k   32k   64k   96k  128k
base   :  585  1258  1438  1525  1575  1602  1419  1448  1372
exp2   :  728  1265  1488  1608  1680  1726  1569  1557  1560
delta  : +25%   +0%   +3%   +5%   +7%   +8%  +11%   +8%  +14%

FWD Causal (TFLOPS):
seqlen :   1k    2k    4k    8k   16k   32k   64k   96k  128k
base   :  347   702  1175  1356  1476  1552  1612  1453  1399
exp2   :  343   709  1190  1384  1505  1586  1628  1434  1444
delta  :  -1%   +1%   +1%   +2%   +2%   +2%   +1%   -1%   +3%

BWD: negligible impact (< 0.5% across all seqlens), no regression.
Problem
The SM100 hdim=256 forward kernel uses exp2 (hardware SFU) for softmax in the inner loop. On Blackwell, the SFU pipeline can be a throughput bottleneck relative to the TMEM/MMA pipelines.
This PR depends on https://github.com/Dao-AILab/flash-attention/pull/2412/commits; review will start after #2412 is merged.
Solution
Replace a fraction of exp2 SFU calls with a degree-3 FMA polynomial emulation (ex2_emulation_2) to shift work from the XU/SFU pipeline onto the FMA pipeline, improving overall instruction-level parallelism.
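A host-side Python sketch of what a degree-3 FMA-style exp2 emulation can look like. This is an assumption-laden illustration: the kernel's actual ex2_emulation_2 code, its polynomial coefficients, and its combine_int_frac_ex2 bit manipulation are not shown in this PR, so the constants and recombination below are illustrative only.

```python
import math
import struct

# Hypothetical exp2 emulation sketch (NOT the kernel's ex2_emulation_2):
# split x into integer and fractional parts, approximate 2**frac with a
# degree-3 polynomial (three FMAs), then scale by 2**int_part via
# exponent-bit injection instead of a multiply.

def exp2_emulated(x: float) -> float:
    xi = math.floor(x)   # integer part
    xf = x - xi          # fractional part in [0, 1)
    # Degree-3 coefficients for 2**xf on [0, 1) -- illustrative values,
    # not the kernel's actual constants.
    c3, c2, c1, c0 = 0.0789, 0.2246, 0.6966, 0.9999
    p = c0 + xf * (c1 + xf * (c2 + xf * c3))   # evaluates as 3 FMAs
    # Recombine integer and fractional parts: adding xi to the IEEE-754
    # exponent field of p multiplies by 2**xi without an FMUL.
    bits = struct.unpack('<I', struct.pack('<f', p))[0]
    bits = (bits + (int(xi) << 23)) & 0xFFFFFFFF
    return struct.unpack('<I', struct.pack('<f', bits))[0] if False else \
        struct.unpack('<f', struct.pack('<I', bits))[0]
```

On the GPU the polynomial evaluation lands on the FMA pipeline and only the final recombination is integer/bit work, which is what lets this trade SFU throughput for FMA throughput.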
Implementation details:
- combine_int_frac_ex2 (bit manipulation for integer + fractional recombination)
- Emulation predicate: k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res
- 50% SFU / 50% FMA emulation
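The per-iteration SFU/FMA selection can be sketched as below. This is a literal reading of the predicate quoted above; which branch (True or False) maps to the emulated path, and how ex2_emu_res relates to the "3 out of every 4" substitution mentioned in the commit message, are assumptions here.

```python
# Hypothetical reading of the interleaving predicate (assumed branch mapping:
# True = take the FMA polynomial emulation, False = hardware exp2 on the SFU).

def use_fma_emulation(k: int, ex2_emu_freq: int, ex2_emu_res: int) -> bool:
    """Decide, for inner-loop iteration k, whether to emulate exp2 with FMAs."""
    return k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res

# With ex2_emu_freq=4 and ex2_emu_res=2, half of all iterations are
# emulated, matching the 50% SFU / 50% FMA split above.
flags = [use_fma_emulation(k, 4, 2) for k in range(8)]
```

Because the decision depends only on k modulo ex2_emu_freq, the split is uniform across the tile and the two pipelines are interleaved at a fixed cadence rather than in long bursts.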
NCU analysis (seqlen=4096, non-causal):
TMEM is the bottleneck post-optimization. Sweeping beyond 50% emulation increases instruction overhead without additional speedup (benchmarked).
The exp2 emulation is applied only in the softmax forward path (compute_softmax_p_warp). Backward uses a separate kernel, so no modification is needed there.
e2e Benchmark
Base (3c015f1):
Exp2 (e122e67):
--- Conclusions (3-run avg, locked clocks @ 1965 MHz):