
[Perf] SM103 tcgen05.ld.red for fused TMEM load + row-max in softmax#2449

Open
LopezCastroRoberto wants to merge 2 commits into Dao-AILab:main from LopezCastroRoberto:perf/ld.red-upstream

Conversation


@LopezCastroRoberto LopezCastroRoberto commented Apr 9, 2026

Summary

Uses the SM103-only tcgen05.ld.red instruction to fuse the TMEM load with a hardware max reduction in FA4's softmax step, eliminating fmax ALU ops per tile. The max is computed in the TMEM controller at zero ALU cost.
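To make the semantics concrete, here is a minimal NumPy sketch (not the kernel code) of what the fusion buys: a plain load leaves the row-max reduction to per-lane `fmax` ops, while the `.red` variant hands back the tile and its row-max together. The function names are illustrative, not real APIs.

```python
import numpy as np

def load_tile(tmem):
    """Plain tcgen05.ld analogue: tile is loaded, row-max reduced afterwards."""
    tile = np.asarray(tmem, dtype=np.float32)
    row_max = tile.max(axis=1)  # this reduction is the per-tile fmax ALU cost
    return tile, row_max

def load_tile_red(tmem):
    """tcgen05.ld.red analogue: TMEM controller returns data and row-max together."""
    tile = np.asarray(tmem, dtype=np.float32)
    return tile, tile.max(axis=1)  # modeled as free: no fmax ops in the kernel

rng = np.random.default_rng(0)
tmem = rng.standard_normal((32, 64)).astype(np.float32)
t1, m1 = load_tile(tmem)
t2, m2 = load_tile_red(tmem)
assert np.array_equal(t1, t2) and np.array_equal(m1, m2)  # same data, same max
```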

Benchmark (B300, seqlen=8192, upstream bench_sm90.py config with do_bench_cudagraph):

| hdim | causal | Speedup | Baseline TFLOPS | ld.red TFLOPS |
|------|--------|---------|-----------------|---------------|
| 64 | non-causal | +2.9% | 1110 | 1142 |
| 64 | causal | +6.5% | 1024 | 1091 |
| 96 | non-causal | +5.4% | 1344 | 1416 |
| 96 | causal | +4.5% | 1222 | 1277 |
| 128 | non-causal | +3.1% | 1524 | 1571 |
| 128 | causal | +1.0% | 1353 | 1366 |

Zero regressions across 2,000+ configurations in an exhaustive sweep (3 hdims × 7 head configs × 2 causal settings × 7 batch sizes × 7 sequence lengths).
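The Speedup column can be re-derived from the two TFLOPS columns in the table above:

```python
# Sanity-check the reported speedups: speedup = (ld.red / baseline - 1) * 100.
rows = [  # (baseline TFLOPS, ld.red TFLOPS, reported speedup %)
    (1110, 1142, 2.9), (1024, 1091, 6.5), (1344, 1416, 5.4),
    (1222, 1277, 4.5), (1524, 1571, 3.1), (1353, 1366, 1.0),
]
for base, ldred, reported in rows:
    speedup = round((ldred / base - 1) * 100, 1)
    assert speedup == reported, (base, ldred, speedup, reported)
```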

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@LopezCastroRoberto LopezCastroRoberto marked this pull request as draft April 9, 2026 20:00
@tridao
Member

tridao commented Apr 9, 2026

Thanks! I think cute-dsl has a copy atom that does this, instead of having to call PTX directly?

@LopezCastroRoberto
Author

@tridao yeah, that was actually my first option, but I couldn't make the cute-dsl copy atom work.

I verified that Ld and LdRed produce identical data, identical layouts, and correct hardware max values. But cute.copy(LdRed) produces wrong results on subsequent tiles.

CUTLASS issue #3090 is very likely the reason. CuTe's LdRed is lowered to llvm.inline_asm without has_side_effects=True, causing LLVM's CSE pass to merge multiple reads from the same TMEM address even though the MMA warp writes new data between reads.

The raw PTX workaround in this PR sets has_side_effects=True explicitly, preventing CSE. Once #3090 is fixed, the raw PTX can be replaced with the copy atom for cleaner code and CuTe's native instruction scheduling.

@tridao
Member

tridao commented Apr 14, 2026

@LopezCastroRoberto
Author

Yes, the CUTLASS MLA example works, I verified it on B300. I tried integrating LdRed32x32bOp into FA4 following the exact same pattern. The correct PTX gets emitted, but the kernel produces wrong results. After some debugging, I still think the issue is the same as described in NVIDIA/cutlass#3090

I think CuTeDSL lowers tcgen05.ld and tcgen05.ld.red differently?

This would mean LLVM's CSE pass is free to merge multiple ld.red calls that share the same TMEM source address, even when the TMEM contents change between reads due to intervening MMA writes.

My guess here is that, in FA4, since stage is a compile-time constant, every call to softmax_step produces the same TMEM address. LLVM sees identical ld.red calls across loop iterations and merges them. I confirmed this theory with PTX instruction counts (same observation as NVIDIA/cutlass#3090):

```
Raw PTX (has_side_effects=true): 8 tcgen05.ld.red instructions (4 tiles × 2 unrolled iterations)
CuTe LdRed32x32bOp:              4 tcgen05.ld.red instructions (half eliminated by CSE)
```

The CUTLASS example doesn't hit this since MLA indexes the TMEM tensor with a runtime-varying pipeline state:

```python
# mla_decode_fp16.py, softmax():
tStS = tStS_staged[None, None, None, mma_s_consumer_state.index]  # runtime Int32
tAcc = tStS[(None, None), 0, 0]
```

mma_s_consumer_state.index alternates at runtime (MLA has mma_s_stage=2). This produces add-based addressing in PTX where the TMEM address depends on a runtime value:

```
; MLA: runtime stage index → address varies per iteration
shl.b32  %r2174, %r4157, 6        ; runtime offset from pipeline state
add.s32  %r2164, %r426, %r2174    ; addr = base + runtime_offset
tcgen05.ld.red.sync.aligned.32x32b.x64.max.f32 {...}, %r2163, [%r2164];
```

So my theory is: LLVM can't prove that two iterations produce the same address, so it keeps both reads, which is why this approach works.

FA4 uses stage=0 (compile-time constant for all iterations within one softmax_loop call), producing a fixed TMEM address that LLVM can merge across iterations.
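The addressing difference can be modeled in a few lines. The `<< 6` stage stride comes from the PTX above; the base address and helper are hypothetical, just to illustrate why CSE fires in one case and not the other:

```python
BASE = 0x1000  # hypothetical TMEM base address

def tmem_addr(stage_index: int) -> int:
    # Mirrors the PTX above: addr = base + (stage_index << 6)
    return BASE + (stage_index << 6)

# MLA: runtime pipeline state alternates between 2 stages → addresses differ,
# so LLVM cannot prove two ld.red calls read the same TMEM location.
mla_addrs = [tmem_addr(i % 2) for i in range(4)]
assert len(set(mla_addrs)) == 2  # two distinct addresses: CSE cannot merge

# FA4: stage is the compile-time constant 0 → every iteration reads the same
# address, and identical side-effect-free asm calls get merged by CSE.
fa4_addrs = [tmem_addr(0) for _ in range(4)]
assert len(set(fa4_addrs)) == 1  # one address: CSE merges the ld.red calls
```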

NVIDIA/cutlass#3090 seems to have new activity since yesterday.
