
[FA4][CuTe DSL] Add head_dim=256 support with exp2 FMA emulation optimization (forward + backward, SM100)#2454

Draft
Johnsonms wants to merge 6 commits into Dao-AILab:main from Johnsonms:exp2-emu-hd256

Conversation


@Johnsonms Johnsonms commented Apr 13, 2026

Problem

The SM100 hdim=256 forward kernel uses exp2 (hardware SFU) for softmax in the inner loop. On Blackwell, the SFU pipeline can be a throughput bottleneck relative to the TMEM/MMA pipelines.

This PR depends on https://github.com/Dao-AILab/flash-attention/pull/2412/commits; review will start after #2412 is merged.

Solution

Replace a fraction of exp2 SFU calls with a degree-3 FMA polynomial emulation (ex2_emulation_2) to shift work from the XU/SFU pipeline onto the FMA pipeline, improving overall instruction-level parallelism.

Implementation details:

  • utils.py provides ex2_emulation_2(x, y) — a packed-pair FP32 degree-3 polynomial using evaluate_polynomial_2 +
    combine_int_frac_ex2 (bit manipulation for integer+fractional recombination)
  • In sm100_hd256_2cta_fmha_forward.py, the softmax inner loop processes S score fragments. Each fragment pair is
    emulated based on: k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res
  • The last fragment of each tile always uses hardware SFU for numerical stability
  • Tuned parameters (freq=4, res=3): because k steps by 2, k % 4 ∈ {0, 2}, so the condition fires 50% of the time →
    50% SFU / 50% FMA emulation
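
For readers unfamiliar with the technique, the list above can be pictured with a minimal scalar Python sketch of the integer/fraction-split exp2 emulation. The helper names mirror the PR's (`ex2_emulation_2`, `combine_int_frac_ex2`), but the polynomial coefficients are an illustrative least-squares fit — not the kernel's actual constants — and no exponent over/underflow handling is shown; the real packed-pair FP32 codegen lives in utils.py:

```python
import math
import struct

# Illustrative degree-3 fit to 2**f on [0, 1); NOT the kernel's actual constants.
C0, C1, C2, C3 = 1.0, 0.69583354, 0.22606716, 0.07809476

def combine_int_frac_ex2(poly: float, i: int) -> float:
    """Multiply poly by 2**i by adding i to the IEEE-754 binary32 exponent
    field (bits 23..30). Assumes no exponent over/underflow."""
    bits = struct.unpack("<I", struct.pack("<f", poly))[0]
    bits += i << 23
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def ex2_emulated(x: float) -> float:
    """Approximate 2**x with FMA-style arithmetic only:
    2**x = 2**i * p(f), where i = floor(x) and f = x - i in [0, 1)."""
    i = math.floor(x)
    f = x - i
    poly = C0 + f * (C1 + f * (C2 + f * C3))  # three FMAs on hardware
    return combine_int_frac_ex2(poly, i)

def ex2_emulation_2(x: float, y: float):
    """Scalar stand-in for the packed-pair helper's shape."""
    return ex2_emulated(x), ex2_emulated(y)
```

The gating then interleaves this with the hardware SFU path: with ex2_emu_freq=4 and ex2_emu_res=3, the emulated branch fires when k % 4 < 1, and since k steps by 2 that is every other fragment pair.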

NCU analysis (seqlen=4096, non-causal):

(NCU profile screenshot)

TMEM is the bottleneck post-optimization. Sweeping beyond 50% emulation increases instruction overhead without
additional speedup (benchmarked).

Performance (B200, hdim=256, bf16, locked clocks)

  • Forward — non-causal: +2% to +22% across seqlens 1k–128k (sustained +4–6% at seqlen ≥ 4k)
  • Forward — causal: +0.7% to +8.4% across seqlens 1k–128k
  • Backward: unchanged — no regression (within ±2% noise)

The exp2 emulation is applied only in the softmax forward path (compute_softmax_p_warp). Backward uses a separate
kernel — no modification needed.

e2e Benchmark

Base (3c015f1):

  PYTHONPATH=/tmp/fa-3c015f1 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 \
    python benchmarks/benchmark_attn.py \
    --fwd --bwd --backend fa4 --headdim 256 \
    --seqlen 1024,2048,4096,8192,16384,32768,65536,98304,131072 \
    --causal both --warmup 5 --rep 20

Exp2 (e122e67):

  PYTHONPATH=/tmp/fa-e122e67 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 \
    python benchmarks/benchmark_attn.py \
    --fwd --bwd --backend fa4 --headdim 256 \
    --seqlen 1024,2048,4096,8192,16384,32768,65536,98304,131072 \
    --causal both --warmup 5 --rep 20

(benchmark screenshot)

Conclusions (3-run avg, locked clocks @ 1965 MHz):

  • FWD non-causal: +24% at 1k, +3–8% at ≥8k seqlen; strongest at 96k–128k (+7–8%).
  • FWD causal: flat at 1k–32k (within noise); clear gains at 64k–128k (+2–5% 🟢).
  • BWD: neutral across all configs (within ±1% noise). BWD causal 64k −3.2% 🔴 is a single-batch outlier inconsistent with the exp2 change scope — noise.

wangsiyu and others added 6 commits April 13, 2026 19:53
Migrate the hd256 2CTA kernels (fwd, dQ, dKdV) to use upstream's CLC
pipeline management abstractions while preserving the existing tile
coordinate format and grid layout.

CLC pipeline management:
- Replace manual PipelineClcFetchAsync consumer_wait/release/advance
  and producer_acquire/get_barrier/issue_clc_query/advance calls with
  upstream's ClcState wrapper (prefetch_next_work, consumer_wait,
  consumer_release, producer_tail).
- Use upstream's ClcDynamicPersistentTileScheduler as the hardware
  scheduler inside ClcState.

Unified scheduler interface:
- Both Sm100FmhaStaticTileScheduler and Sm100FmhaClcDynamicTileScheduler
  now expose the same protocol: advance_to_next_work() returns WorkTileInfo,
  prefetch_next_work() and producer_tail() are no-ops for static.
- Consumer warps use a single advance_to_next_work() call (no if/else).
- Scheduler warp uses prefetch_next_work() + advance_to_next_work() +
  producer_tail(), matching the upstream clc_scheduler_warp pattern.
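
As a rough illustration of the unified protocol described above — a sketch only, with an illustrative WorkTileInfo shape; the schedulers' real fields and construction differ in the kernel code:

```python
from dataclasses import dataclass
from typing import Protocol

@dataclass
class WorkTileInfo:
    # Field names are illustrative, not the PR's actual definition.
    tile_coord: tuple
    is_valid: bool

class TileScheduler(Protocol):
    """Shared protocol both schedulers expose, per the commit description."""
    def advance_to_next_work(self) -> WorkTileInfo: ...
    def prefetch_next_work(self) -> None: ...
    def producer_tail(self) -> None: ...

class StaticTileScheduler:
    """Static path: walks a fixed tile list; prefetch/tail are no-ops."""
    def __init__(self, tiles):
        self._tiles = list(tiles)
        self._idx = 0

    def advance_to_next_work(self) -> WorkTileInfo:
        if self._idx < len(self._tiles):
            tile = self._tiles[self._idx]
            self._idx += 1
            return WorkTileInfo(tile, True)
        return WorkTileInfo((), False)

    def prefetch_next_work(self) -> None:
        pass  # no-op for static scheduling

    def producer_tail(self) -> None:
        pass  # no-op for static scheduling

def consumer_loop(sched: TileScheduler):
    """Consumer warps need only one advance_to_next_work() call (no if/else
    on scheduler kind), since both schedulers satisfy the same protocol."""
    done = []
    work = sched.advance_to_next_work()
    while work.is_valid:
        done.append(work.tile_coord)
        work = sched.advance_to_next_work()
    return done
```

The CLC dynamic scheduler would implement the same three methods with real prefetch and tail behavior, which is what lets consumer warps stay scheduler-agnostic.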

API/config alignment:
- Rename use_clc_dynamic_scheduler -> use_clc_scheduler across all
  hd256 kernel constructors and interface.py.
- Unify env var: interface.py hd256 path now reads FA_CLC (upstream
  default, off) instead of the old USE_CLC (default on).
- Add Sm100FmhaClcDynamicTileSchedulerParams.clc_hw_params() to
  encapsulate ClcDynamicPersistentTileSchedulerParams construction.
- Remove thin wrapper functions (create_sm100_fmha_static_tile_scheduler,
  create_sm100_fmha_clc_dynamic_tile_scheduler, etc.); callers use
  constructors and .create() directly.
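
The env-var flip can be pictured with a hypothetical helper (the real gating lives in interface.py; this only demonstrates the default change from USE_CLC-on to FA_CLC-off):

```python
import os

def read_use_clc_scheduler(env=None) -> bool:
    """Hypothetical sketch of the described gating: FA_CLC (upstream name)
    defaults to off, replacing the old USE_CLC which defaulted to on."""
    env = os.environ if env is None else env
    return env.get("FA_CLC", "0") == "1"
```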

Unchanged:
- Static scheduling path (grid shape, coordinate format, tile scheduler).
- CLC tile coordinate format: (m, 0, (bid, hid)) with packed L=B*H.
- CLC grid shape: round_up((M, 1, B*H), cluster_shape).
- Kernel warp structure, register allocation, tensor indexing.

Tests: 434 passed / 90 skipped for both static and FA_CLC=1 paths.
Merged into existing flash attn test files. Added skips for unsupported head_dim=256 kernels. All tests passing.
… speedup)

Replace a fraction of hardware exp2 (SFU) instructions with a
polynomial FMA emulation (ex2_emulation_2) in the softmax P-tile
computation. The key insight: SM100's SFU throughput is a bottleneck
for hdim=256 due to the large tile size. By substituting 3 out of
every 4 exp2 calls (ex2_emu_freq=4, ex2_emu_res=3) with packed FMA
polynomial approximation, we shift pressure onto the underutilized
FMA pipeline.

Additionally, the P write slot acquisition is moved earlier to
overlap any pipeline stall with the exp2 compute.

Benchmark (B200, bf16, hdim=256, 8 Q-heads, batch ~32k tokens,
avg 9 runs, locked clocks @ 1965 MHz):

FWD Non-Causal (TFLOPS):
  seqlen  :   1k    2k    4k    8k   16k   32k   64k   96k  128k
  base    :  585  1258  1438  1525  1575  1602  1419  1448  1372
  exp2    :  728  1265  1488  1608  1680  1726  1569  1557  1560
  delta   : +25%   +0%   +3%   +5%   +7%   +8%  +11%   +8%  +14%

FWD Causal (TFLOPS):
  seqlen  :   1k    2k    4k    8k   16k   32k   64k   96k  128k
  base    :  347   702  1175  1356  1476  1552  1612  1453  1399
  exp2    :  343   709  1190  1384  1505  1586  1628  1434  1444
  delta   :  -1%   +1%   +1%   +2%   +2%   +2%   +1%   -1%   +3%

BWD: negligible impact (< 0.5% across all seqlens), no regression.
@Johnsonms Johnsonms marked this pull request as ready for review April 13, 2026 21:28
@Johnsonms Johnsonms marked this pull request as draft April 13, 2026 21:29