[Perf] SM103 tcgen05.ld.red for fused TMEM load + row-max in softmax#2449
LopezCastroRoberto wants to merge 2 commits into Dao-AILab:main
Conversation
Thanks! I think cute-dsl has a copy atom that does this, instead of having to call PTX directly?
@tridao Yeah, that was actually my first option, but I couldn't make the cute-dsl copy work. I verified that CUTLASS issue #3090 is very likely the reason. CuTe's The raw PTX workaround on this PR sets
CUTLASS has an example of LdRed; presumably that's working?
Yes, the CUTLASS MLA example works; I verified it on B300. I tried integrating I think CuTeDSL lowers This would mean LLVM's CSE pass is free to merge multiple My guess here is that, in FA4, since stage is a compile-time constant, every call to The CUTLASS example doesn't hit this since MLA indexes the TMEM tensor with a runtime-varying pipeline state:

```python
# mla_decode_fp16.py, softmax():
tStS = tStS_staged[None, None, None, mma_s_consumer_state.index]  # runtime Int32
tAcc = tStS[(None, None), 0, 0]
```
So the theory I have is: LLVM can't prove that two MLA iterations produce the same address, so it keeps both reads, and that approach works. FA4 uses stage=0 (a compile-time constant for all iterations within one softmax_loop call), producing a fixed TMEM address that LLVM can merge across iterations. NVIDIA/cutlass#3090 seems to have new activity since yesterday.
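To make the theory concrete, here is a toy model of common-subexpression elimination (a sketch only; it is not CuTeDSL or LLVM internals, and the ops/addresses are invented for illustration). A CSE pass may merge two loads only when it can prove they read the same address — trivial for a compile-time constant, impossible for a runtime-varying index.

```python
# Toy CSE over (dest, op, addr) triples. Python ints model addresses the
# compiler can prove constant; strings model runtime-dependent addresses.
def cse_loads(ops):
    seen = {}  # constant address -> register that already holds the value
    out = []
    for dest, op, addr in ops:
        if op == "load" and isinstance(addr, int):
            if addr in seen:
                # Provably the same address: reuse the earlier load.
                out.append((dest, "copy", seen[addr]))
                continue
            seen[addr] = dest
        out.append((dest, op, addr))
    return out

# FA4-like: stage=0 yields the same constant TMEM address twice -> merged.
fa4 = [("r0", "load", 0x100), ("r1", "load", 0x100)]
# MLA-like: address depends on a runtime pipeline index -> both kept.
mla = [("r0", "load", "base + state.index"),
       ("r1", "load", "base + state.index")]

print(cse_loads(fa4))  # second load replaced by a copy of r0
print(cse_loads(mla))  # both loads survive
```

Under this model, hiding the address behind a runtime value (as MLA's pipeline state does) is exactly what keeps the second read alive.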
Summary
Uses the SM103-only `tcgen05.ld.red` instruction to fuse the TMEM load with a hardware max reduction in FA4's softmax step, eliminating `fmax` ALU ops per tile. The max is computed in the TMEM controller at zero ALU cost.

Benchmark (B300, seqlen=8192, upstream `bench_sm90.py` config with `do_bench_cudagraph`):
bench_sm90.pyconfig withdo_bench_cudagraph):Zero regressions across 2,000+ different configs tried (exhaustive sweep: 3 hdims × 7 head configs × 2 causal × 7 BS × 7 SL).