@@ -2305,6 +2305,273 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches
23052305 return sch .mod ["main" ].with_attr ("tirx.is_scheduled" , True )
23062306
23072307
def _attention_sequence_prefill_with_mask(h_kv, h_q, d, dtype, target: Target, sm_scale=1.0):  # pylint: disable=line-too-long
    """Tiled sequence prefill kernel with a per-batch right-padding mask.

    This is the counterpart of :func:`_attention_sequence_prefill` for batched
    encoder-style inputs where each sample in the batch is padded to a common
    ``seq_len`` but only the first ``valid_lens[b]`` tokens carry real content.
    The kernel takes an extra ``valid_lens`` buffer of shape ``(batch_size,)``
    and applies the padding mask inside the QKV load path and the online
    softmax update, so no explicit mask tensor broadcast or additive bias is
    needed on the host side.

    Semantics: for batch ``b``, positions ``[0, valid_lens[b])`` are real and
    positions ``[valid_lens[b], seq_len)`` are padding. Padding queries and
    keys/values are zeroed at load time; padded ``(row, col)`` pairs are
    excluded from the max/sum of the online softmax via a ``-inf`` slot
    (approximated with the finite constant ``-5e4``).

    Parameters
    ----------
    h_kv : int
        Number of key/value heads (one ``blockIdx.y`` per KV head).
    h_q : int
        Number of query heads; ``h_q // h_kv`` query heads share one KV head.
    d : int
        Head dimension.
    dtype : str
        Element dtype of the Q/K/V/output/LSE buffers.
    target : Target
        Compilation target, used to derive the tiling configuration.
    sm_scale : float
        Scale applied to the QK product before the softmax (default ``1.0``).

    Returns
    -------
    The scheduled ``main`` PrimFunc, tagged with ``tirx.is_scheduled``.
    """
    # Tiling / launch configuration comes from the shared prefill config
    # helper, same as the unmasked sequence-prefill kernel.
    (
        _,
        LOAD_VEC,
        group_size,
        bdx,
        num_warps,
        tile_x,
        tile_y,
        tile_z,
    ) = _get_prefill_kernel_config(h_kv, h_q, d, dtype, target)

    def _valid_length_mask(valid_len, row, col, qo_len):
        """Return True when both the query row and the key col are unpadded."""
        return tirx.And(
            tirx.And(row < qo_len, row < valid_len),
            col < valid_len,
        )

    # fmt: off
    @T.prim_func
    def batch_sequence_prefill_kv_masked(  # pylint: disable=too-many-branches
        var_q: T.handle,  # [batch_size, qo_len, h_q, d]
        var_k: T.handle,  # [batch_size, kv_len, h_kv, d]
        var_v: T.handle,  # [batch_size, kv_len, h_kv, d]
        var_valid_lens: T.handle,  # [batch_size], int32
        var_output: T.handle,  # [batch_size, qo_len, h_q, d]
        var_lse: T.handle  # [batch_size, qo_len, h_q]
    ):
        batch_size = T.int32(is_size_var=True)
        qo_len = T.int32(is_size_var=True)
        kv_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (batch_size, qo_len, h_q, d), dtype)
        k = T.match_buffer(var_k, (batch_size, kv_len, h_kv, d), dtype)
        v = T.match_buffer(var_v, (batch_size, kv_len, h_kv, d), dtype)
        valid_lens = T.match_buffer(var_valid_lens, (batch_size,), "int32")
        output = T.match_buffer(var_output, (batch_size, qo_len, h_q, d), dtype)
        lse = T.match_buffer(var_lse, (batch_size, qo_len, h_q), dtype)

        # Number of tile_x-sized tiles per batch sample; each tile row covers
        # one fused (query position, grouped query head) pair.
        batch_tiles: T.int32 = T.ceildiv(qo_len * group_size, tile_x)

        # Grid: one blockIdx.x per (batch sample, query tile) pair, one
        # blockIdx.y per KV head.
        for lbx in T.thread_binding(T.cast(batch_size, "int32") * batch_tiles, thread="blockIdx.x"):
            for lby in T.thread_binding(h_kv, thread="blockIdx.y"):
                for lty in T.thread_binding(num_warps, thread="threadIdx.y"):
                    for ltx in T.thread_binding(bdx, thread="threadIdx.x"):
                        with T.sblock("attn"):
                            vbx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            # Access regions deliberately left empty here;
                            # they are resolved at schedule time.
                            T.reads()
                            T.writes()

                            # Shared-memory tiles for Q, K, V and the score matrix S.
                            Q_smem = T.sblock_alloc_buffer((tile_x, d), dtype, scope="shared")
                            K_smem = T.sblock_alloc_buffer((tile_z, d), dtype, scope="shared")
                            V_smem = T.sblock_alloc_buffer((tile_z, d), dtype, scope="shared")
                            S_smem = T.sblock_alloc_buffer((tile_x, tile_z), "float32", scope="shared")

                            S_local = T.sblock_alloc_buffer((tile_x, tile_z), "float32", scope="local")
                            O_local = T.sblock_alloc_buffer((tile_x, d), "float32", scope="local")

                            # Online-softmax state per tile row: running max (m),
                            # previous max, and running denominator (d).
                            m_smem = T.sblock_alloc_buffer((tile_x,), "float32", scope="shared")
                            m_prev_smem = T.sblock_alloc_buffer((tile_x,), "float32", scope="shared")
                            d_smem = T.sblock_alloc_buffer((tile_x,), "float32", scope="shared")

                            # Per-thread working copies of the softmax state; each
                            # thread owns ceil(tile_x / (bdx * num_warps)) rows.
                            m_new = T.sblock_alloc_buffer(
                                (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local"
                            )
                            m_prev = T.sblock_alloc_buffer(
                                (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local"
                            )
                            d_new = T.sblock_alloc_buffer(
                                (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local"
                            )

                            # Decode blockIdx.x into (batch sample, query tile).
                            b_idx: T.int32 = vbx // batch_tiles
                            valid_len: T.int32 = valid_lens[b_idx]  # unpadded length of this sample
                            tile_id: T.int32 = vbx % batch_tiles
                            LH_start: T.int32 = tile_id * tile_x  # first fused (L, H) row of this tile
                            T.tvm_storage_sync("shared")

                            # init states: -5e4 is the finite stand-in for -inf in
                            # the running max; the denominator starts at 1.0.
                            for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
                                row: T.int32 = i * bdx * num_warps + ty * bdx + tx
                                if row < tile_x:
                                    m_smem[row] = -5e4
                                    d_smem[row] = 1.0

                            for li, lj in T.grid(tile_x, tile_y):
                                with T.sblock("O_init"):
                                    i, j = T.axis.remap("SS", [li, lj])
                                    O_local[i, j] = 0.0
                            T.tvm_storage_sync("shared")

                            # Load Q; padded rows are zeroed so they contribute nothing downstream.
                            for li, lj in T.grid(tile_x, tile_y):
                                with T.sblock("Q_load"):
                                    i, j = T.axis.remap("SS", [li, lj])
                                    T.reads()
                                    T.writes()
                                    # Tile row i maps to query position cur_L and
                                    # query head cur_H_qo (grouped-head layout).
                                    cur_L = (LH_start + i) // group_size
                                    cur_H_qo = by * group_size + (LH_start + i) % group_size
                                    if tirx.And(cur_L < qo_len, cur_L < valid_len):
                                        Q_smem[i, j] = q[b_idx, cur_L, cur_H_qo, j]
                                    else:
                                        Q_smem[i, j] = 0.0
                            T.tvm_storage_sync("shared")

                            # March over the KV sequence in tiles of tile_z columns.
                            for iterator in T.serial(T.ceildiv(kv_len, tile_z)):
                                L_kv_start: T.int32 = iterator * tile_z
                                # NOTE(review): kept at 0 here — presumably a KV
                                # offset used by ragged/paged variants of this
                                # kernel family; confirm against those kernels.
                                L_kv_base: T.int32 = 0
                                for lz, ly in T.grid(tile_z, tile_y):
                                    with T.sblock("K_load"):
                                        i, j = T.axis.remap("SS", [lz, ly])
                                        T.reads()
                                        T.writes()
                                        cur_L = L_kv_start + i
                                        # Zero out-of-range and padded key rows.
                                        if tirx.And(cur_L < kv_len, cur_L < valid_len):
                                            K_smem[i, j] = k[b_idx, L_kv_base + cur_L, by, j]
                                        else:
                                            K_smem[i, j] = 0.0
                                T.tvm_storage_sync("shared")
                                for lz, ly in T.grid(tile_z, tile_y):
                                    with T.sblock("V_load"):
                                        i, j = T.axis.remap("SS", [lz, ly])
                                        T.reads()
                                        T.writes()
                                        cur_L = L_kv_start + i
                                        # Zero out-of-range and padded value rows.
                                        if tirx.And(cur_L < kv_len, cur_L < valid_len):
                                            V_smem[i, j] = v[b_idx, L_kv_base + cur_L, by, j]
                                        else:
                                            V_smem[i, j] = 0.0
                                T.tvm_storage_sync("shared")

                                # Compute S = (Q @ K^T) * sm_scale, pre-multiplied by
                                # log2(e) so the softmax below can use exp2/log2.
                                with T.sblock():
                                    for li, lj, lk in T.grid(tile_x, tile_z, tile_y):
                                        with T.sblock("S_gemm"):
                                            # NOTE: the reduction axis is named `k`,
                                            # shadowing the K buffer inside this block.
                                            i, j, k = T.axis.remap("SSR", [li, lj, lk])
                                            with T.init():
                                                S_local[i, j] = 0.0
                                            S_local[i, j] += (
                                                T.cast(Q_smem[i, k], "float32")
                                                * T.cast(K_smem[j, k], "float32")
                                                * sm_scale
                                                * math.log2(math.exp(1))
                                            )
                                    T.tvm_storage_sync("shared")
                                    for li, lj in T.grid(tile_x, tile_z):
                                        with T.sblock("S_store"):
                                            i, j = T.axis.remap("SS", [li, lj])
                                            S_smem[i, j] = S_local[i, j]
                                    T.tvm_storage_sync("shared")

                                # Update S, m, d — use padding mask instead of causal:
                                # only unpadded (row, col) pairs enter the running max.
                                for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
                                    row: T.int32 = i * bdx * num_warps + ty * bdx + tx
                                    if row < tile_x:
                                        with T.sblock("update1"):
                                            m_prev[i] = m_smem[row]
                                            m_new[i] = m_smem[row]
                                            row_: T.int32 = (LH_start + row) // group_size
                                            for j in T.serial(tile_z):
                                                if _valid_length_mask(
                                                    valid_len,
                                                    row=row_,
                                                    col=L_kv_start + j,
                                                    qo_len=qo_len,
                                                ):
                                                    m_new[i] = T.max(
                                                        m_new[i], S_smem[row, j]
                                                    )
                                            # Rescale the old denominator to the new max.
                                            d_new[i] = d_smem[row] * T.exp2(
                                                m_prev[i] - m_new[i]
                                            )

                                # Exponentiate S in place; masked slots get
                                # exp2(-5e4 - m), which is effectively zero.
                                for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
                                    row: T.int32 = i * bdx * num_warps + ty * bdx + tx
                                    with T.sblock("update"):
                                        for j in T.serial(tile_z):
                                            # sync is outside the branch, so the predicate is inside
                                            if row < tile_x:
                                                row_: T.int32 = (
                                                    LH_start + row
                                                ) // group_size
                                                if _valid_length_mask(
                                                    valid_len,
                                                    row=row_,
                                                    col=L_kv_start + j,
                                                    qo_len=qo_len,
                                                ):
                                                    S_smem[row, j] = T.exp2(
                                                        S_smem[row, j] - m_new[i]
                                                    )
                                                else:
                                                    S_smem[row, j] = T.exp2(-5e4 - m_new[i])

                                # Fold this tile's probabilities into the denominator
                                # and publish the updated state to shared memory.
                                for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
                                    row: T.int32 = i * bdx * num_warps + ty * bdx + tx
                                    if row < tile_x:
                                        with T.sblock("update"):
                                            for j in T.serial(tile_z):
                                                d_new[i] += S_smem[row, j]
                                            m_smem[row] = m_new[i]
                                            d_smem[row] = d_new[i]
                                            m_prev_smem[row] = m_prev[i]
                                T.tvm_storage_sync("shared")

                                # Update O: rescale the accumulator to the new max,
                                # then accumulate S @ V for this KV tile.
                                with T.sblock():
                                    for li, lj, lk in T.grid(tile_x, tile_y, tile_z):
                                        with T.sblock("O_gemm"):
                                            i, j, k = T.axis.remap("SSR", [li, lj, lk])
                                            with T.init():
                                                O_local[i, j] *= T.exp2(
                                                    m_prev_smem[i] - m_smem[i]
                                                )
                                            O_local[i, j] += S_smem[i, k] * T.cast(
                                                V_smem[k, j], "float32"
                                            )

                            # Store O, normalized by the softmax denominator. Only
                            # the qo_len bound is checked here; fully padded rows
                            # still write, but their accumulator is ~0 by masking.
                            for li, lj in T.grid(tile_x, tile_y):
                                with T.sblock("O_store"):
                                    i, j = T.axis.remap("SS", [li, lj])
                                    cur_L: T.int32 = 0 + (LH_start + i) // group_size
                                    cur_H_qo: T.int32 = (
                                        by * group_size + (LH_start + i) % group_size
                                    )
                                    if cur_L < qo_len:
                                        output[b_idx, cur_L, cur_H_qo, j] = (
                                            O_local[i, j] / d_smem[i]
                                        )

                            # Store LSE in base 2: lse = m + log2(d), matching the
                            # exp2/log2 domain used throughout the kernel.
                            for li in T.grid(tile_x):
                                with T.sblock("lse_store"):
                                    i = T.axis.remap("S", [li])
                                    cur_L: T.int32 = 0 + (LH_start + i) // group_size
                                    cur_H_qo: T.int32 = (
                                        by * group_size + (LH_start + i) % group_size
                                    )
                                    if cur_L < qo_len:
                                        lse[b_idx, cur_L, cur_H_qo] = m_smem[i] + T.log2(
                                            d_smem[i]
                                        )

    # fmt: on
    # Apply the shared prefill schedule; the two trailing False flags disable
    # the causal/ragged code paths of the scheduler for this padded variant.
    sch = tvm.s_tir.Schedule(batch_sequence_prefill_kv_masked)
    sch = _schedule_prefill_kernel(
        sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, False, False
    )
    return sch.mod["main"].with_attr("tirx.is_scheduled", True)
2573+
2574+
23082575def _attention_prefill_ragged_cpu (h_kv , h_q , d_qk , d_v , dtype , rope_scaling : dict [str , Any ]):
23092576 group_size = h_q // h_kv
23102577
0 commit comments