Commit a6e2ea8
[Relax][Frontend][KVCache] Add masked sequence prefill helper for encoder valid lengths (#19392)
Adds `_attention_sequence_prefill_with_mask` in `python/tvm/relax/frontend/nn/llm/kv_cache.py` — a masked variant of the existing sequence prefill kernel that supports right-padded encoder batches with per-sample `valid_lens`. The existing `_attention_sequence_prefill` assumes all positions in `[0, seq_len)` are valid, which breaks for padded encoder inputs where each batch element has a different valid prefix length. This helper adds the masking semantics needed for correctness:

- accepts a per-batch `valid_lens` input
- ignores padded query rows and padded key/value positions
- excludes padded `(row, col)` pairs from the online softmax update

It reuses the existing prefill kernel config and schedule — no new tuning knobs, no target-specific changes, no performance claims. Correctness only.

## Motivation: encoder batch prefill for downstream consumers

This is the TVM-side primitive needed to support **encoder batch prefill** in downstream projects like `mlc-llm`, where padded encoder batches with `valid_lens` need to be lowered without materializing an explicit broadcast attention mask on the host. The helper is generic and useful for any encoder-style sequence prefill consumer with per-sample valid lengths.
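For reference, here is a minimal NumPy sketch of the intended masking semantics (not part of this commit; `masked_sequence_prefill_ref` and its signature are illustrative). It assumes the kernel's GQA head grouping and its base-2 log-sum-exp output, and treats padded output rows as undefined (left at zero here):

```python
import numpy as np

def masked_sequence_prefill_ref(q, k, v, valid_lens, sm_scale=1.0):
    """q: (B, L, Hq, D); k, v: (B, L, Hkv, D); valid_lens: (B,) int32.

    Returns (output, lse). Rows/cols at index >= valid_lens[b] are padding
    and are excluded from the softmax entirely.
    """
    B, L, Hq, D = q.shape
    group = Hq // k.shape[2]                  # GQA: query heads per KV head
    out = np.zeros((B, L, Hq, D), dtype=np.float32)
    lse = np.zeros((B, L, Hq), dtype=np.float32)
    for b in range(B):
        n = int(valid_lens[b])                # only [0, n) carries real content
        for h in range(Hq):
            kh = h // group
            # Scores restricted to the unpadded (n, n) region.
            s = q[b, :n, h, :].astype(np.float32) @ k[b, :n, kh, :].astype(np.float32).T
            s *= sm_scale
            m = s.max(axis=1, keepdims=True)  # numerically stable softmax
            p = np.exp(s - m)
            d = p.sum(axis=1, keepdims=True)
            out[b, :n, h, :] = (p / d) @ v[b, :n, kh, :].astype(np.float32)
            # The kernel works in the exp2 domain, so LSE comes out in base 2.
            lse[b, :n, h] = ((m + np.log(d)) / np.log(2.0))[:, 0]
    return out, lse
```

With `valid_lens` all equal to `L`, this reduces to ordinary full-length sequence prefill, matching the relationship to `_attention_sequence_prefill` described above.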
1 parent 87fdeaf commit a6e2ea8

2 files changed

Lines changed: 464 additions & 0 deletions

File tree

python/tvm/relax/frontend/nn/llm/kv_cache.py

Lines changed: 267 additions & 0 deletions
@@ -2305,6 +2305,273 @@ def batch_sequence_prefill_kv(  # pylint: disable=too-many-branches
    return sch.mod["main"].with_attr("tir.is_scheduled", True)


def _attention_sequence_prefill_with_mask(h_kv, h_q, d, dtype, target: Target, sm_scale=1.0):  # pylint: disable=line-too-long
"""Tiled sequence prefill kernel with a per-batch right-padding mask.
2310+
2311+
This is the counterpart of :func:`_attention_sequence_prefill` for batched
2312+
encoder-style inputs where each sample in the batch is padded to a common
2313+
``seq_len`` but only the first ``valid_lens[b]`` tokens carry real content.
2314+
The kernel takes an extra ``valid_lens`` buffer of shape ``(batch_size,)``
2315+
and applies the padding mask inside the QKV load path and the online
2316+
softmax update, so no explicit mask tensor broadcast or additive bias is
2317+
needed on the host side.
2318+
2319+
Semantics: for batch ``b``, positions ``[0, valid_lens[b])`` are real and
2320+
positions ``[valid_lens[b], seq_len)`` are padding. Padding queries and
2321+
keys/values are zeroed at load time; padded ``(row, col)`` pairs are
2322+
excluded from the max/sum of the online softmax via a ``-inf`` slot.
2323+
"""
2324+
    (
        _,
        LOAD_VEC,
        group_size,
        bdx,
        num_warps,
        tile_x,
        tile_y,
        tile_z,
    ) = _get_prefill_kernel_config(h_kv, h_q, d, dtype, target)

    def _valid_length_mask(valid_len, row, col, qo_len):
        """Return True when both the query row and the key col are unpadded."""
        return tir.And(
            tir.And(row < qo_len, row < valid_len),
            col < valid_len,
        )

    # fmt: off
    @T.prim_func
    def batch_sequence_prefill_kv_masked(  # pylint: disable=too-many-branches
        var_q: T.handle,  # [batch_size, qo_len, h_q, d]
        var_k: T.handle,  # [batch_size, kv_len, h_kv, d]
        var_v: T.handle,  # [batch_size, kv_len, h_kv, d]
        var_valid_lens: T.handle,  # [batch_size], int32
        var_output: T.handle,  # [batch_size, qo_len, h_q, d]
        var_lse: T.handle,  # [batch_size, qo_len, h_q]
    ):
        batch_size = T.int32(is_size_var=True)
        qo_len = T.int32(is_size_var=True)
        kv_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (batch_size, qo_len, h_q, d), dtype)
        k = T.match_buffer(var_k, (batch_size, kv_len, h_kv, d), dtype)
        v = T.match_buffer(var_v, (batch_size, kv_len, h_kv, d), dtype)
        valid_lens = T.match_buffer(var_valid_lens, (batch_size,), "int32")
        output = T.match_buffer(var_output, (batch_size, qo_len, h_q, d), dtype)
        lse = T.match_buffer(var_lse, (batch_size, qo_len, h_q), dtype)

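        # blockIdx.x is a fused (batch, tile) index: the qo_len * group_size
        # fused row space is split into batch_tiles tiles of tile_x rows each.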
        batch_tiles: T.int32 = T.ceildiv(qo_len * group_size, tile_x)

        for lbx in T.thread_binding(T.cast(batch_size, "int32") * batch_tiles, thread="blockIdx.x"):
            for lby in T.thread_binding(h_kv, thread="blockIdx.y"):
                for lty in T.thread_binding(num_warps, thread="threadIdx.y"):
                    for ltx in T.thread_binding(bdx, thread="threadIdx.x"):
                        with T.block("attn"):
                            vbx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()

                            Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared")
                            K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
                            V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
                            S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared")

                            S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local")
                            O_local = T.alloc_buffer((tile_x, d), "float32", scope="local")

                            m_smem = T.alloc_buffer((tile_x,), "float32", scope="shared")
                            m_prev_smem = T.alloc_buffer((tile_x,), "float32", scope="shared")
                            d_smem = T.alloc_buffer((tile_x,), "float32", scope="shared")

                            m_new = T.alloc_buffer(
                                (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local"
                            )
                            m_prev = T.alloc_buffer(
                                (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local"
                            )
                            d_new = T.alloc_buffer(
                                (math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local"
                            )

                            b_idx: T.int32 = vbx // batch_tiles
                            valid_len: T.int32 = valid_lens[b_idx]
                            tile_id: T.int32 = vbx % batch_tiles
                            LH_start: T.int32 = tile_id * tile_x
                            T.tvm_storage_sync("shared")

                            # init states
                            for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
                                row: T.int32 = i * bdx * num_warps + ty * bdx + tx
                                if row < tile_x:
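                                    # -5e4 initializes the running max as a finite
                                    # stand-in for -inf: any valid score replaces it,
                                    # and exp2(-5e4 - m) then underflows to zero.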
                                    m_smem[row] = -5e4
                                    d_smem[row] = 1.0

                            for li, lj in T.grid(tile_x, tile_y):
                                with T.block("O_init"):
                                    i, j = T.axis.remap("SS", [li, lj])
                                    O_local[i, j] = 0.0
                            T.tvm_storage_sync("shared")

                            # Load Q; padded rows are zeroed so they contribute nothing downstream.
                            for li, lj in T.grid(tile_x, tile_y):
                                with T.block("Q_load"):
                                    i, j = T.axis.remap("SS", [li, lj])
                                    T.reads()
                                    T.writes()
                                    cur_L = (LH_start + i) // group_size
                                    cur_H_qo = by * group_size + (LH_start + i) % group_size
                                    if tir.And(cur_L < qo_len, cur_L < valid_len):
                                        Q_smem[i, j] = q[b_idx, cur_L, cur_H_qo, j]
                                    else:
                                        Q_smem[i, j] = 0.0
                            T.tvm_storage_sync("shared")

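                            # Stream over KV in tiles of size tile_z; padded K/V
                            # positions are zeroed at load so they cannot leak into S or O.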
                            for iterator in T.serial(T.ceildiv(kv_len, tile_z)):
                                L_kv_start: T.int32 = iterator * tile_z
                                L_kv_base: T.int32 = 0
                                for lz, ly in T.grid(tile_z, tile_y):
                                    with T.block("K_load"):
                                        i, j = T.axis.remap("SS", [lz, ly])
                                        T.reads()
                                        T.writes()
                                        cur_L = L_kv_start + i
                                        if tir.And(cur_L < kv_len, cur_L < valid_len):
                                            K_smem[i, j] = k[b_idx, L_kv_base + cur_L, by, j]
                                        else:
                                            K_smem[i, j] = 0.0
                                T.tvm_storage_sync("shared")
                                for lz, ly in T.grid(tile_z, tile_y):
                                    with T.block("V_load"):
                                        i, j = T.axis.remap("SS", [lz, ly])
                                        T.reads()
                                        T.writes()
                                        cur_L = L_kv_start + i
                                        if tir.And(cur_L < kv_len, cur_L < valid_len):
                                            V_smem[i, j] = v[b_idx, L_kv_base + cur_L, by, j]
                                        else:
                                            V_smem[i, j] = 0.0
                                T.tvm_storage_sync("shared")

                                # Compute S
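                                # Q @ K^T is pre-scaled by sm_scale * log2(e) so the online
                                # softmax below can use exp2/log2 instead of exp/log.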
                                with T.block():
                                    for li, lj, lk in T.grid(tile_x, tile_z, tile_y):
                                        with T.block("S_gemm"):
                                            i, j, k = T.axis.remap("SSR", [li, lj, lk])
                                            with T.init():
                                                S_local[i, j] = 0.0
                                            S_local[i, j] += (
                                                T.cast(Q_smem[i, k], "float32")
                                                * T.cast(K_smem[j, k], "float32")
                                                * sm_scale
                                                * math.log2(math.exp(1))
                                            )
                                T.tvm_storage_sync("shared")
                                for li, lj in T.grid(tile_x, tile_z):
                                    with T.block("S_store"):
                                        i, j = T.axis.remap("SS", [li, lj])
                                        S_smem[i, j] = S_local[i, j]
                                T.tvm_storage_sync("shared")

                                # Update S, m, d — use padding mask instead of causal.
                                for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
                                    row: T.int32 = i * bdx * num_warps + ty * bdx + tx
                                    if row < tile_x:
                                        with T.block("update1"):
                                            m_prev[i] = m_smem[row]
                                            m_new[i] = m_smem[row]
                                            row_: T.int32 = (LH_start + row) // group_size
                                            for j in T.serial(tile_z):
                                                if _valid_length_mask(
                                                    valid_len,
                                                    row=row_,
                                                    col=L_kv_start + j,
                                                    qo_len=qo_len,
                                                ):
                                                    m_new[i] = T.max(m_new[i], S_smem[row, j])
                                            d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])

                                for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
                                    row: T.int32 = i * bdx * num_warps + ty * bdx + tx
                                    with T.block("update"):
                                        for j in T.serial(tile_z):
                                            # sync is outside the branch, so the predicate is inside
                                            if row < tile_x:
                                                row_: T.int32 = (LH_start + row) // group_size
                                                if _valid_length_mask(
                                                    valid_len,
                                                    row=row_,
                                                    col=L_kv_start + j,
                                                    qo_len=qo_len,
                                                ):
                                                    S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                else:
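                                                    # Masked (row, col) slots get exp2(-5e4 - m_new),
                                                    # which underflows to ~0 once the row has any valid
                                                    # score; fully padded rows yield unused values.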
                                                    S_smem[row, j] = T.exp2(-5e4 - m_new[i])

                                for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
                                    row: T.int32 = i * bdx * num_warps + ty * bdx + tx
                                    if row < tile_x:
                                        with T.block("update"):
                                            for j in T.serial(tile_z):
                                                d_new[i] += S_smem[row, j]
                                            m_smem[row] = m_new[i]
                                            d_smem[row] = d_new[i]
                                            m_prev_smem[row] = m_prev[i]
                                T.tvm_storage_sync("shared")

                                # Update O
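                                # Rescale the O accumulator by exp2(m_prev - m) before adding
                                # the new tile: the standard online-softmax correction.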
                                with T.block():
                                    for li, lj, lk in T.grid(tile_x, tile_y, tile_z):
                                        with T.block("O_gemm"):
                                            i, j, k = T.axis.remap("SSR", [li, lj, lk])
                                            with T.init():
                                                O_local[i, j] *= T.exp2(
                                                    m_prev_smem[i] - m_smem[i]
                                                )
                                            O_local[i, j] += S_smem[i, k] * T.cast(
                                                V_smem[k, j], "float32"
                                            )

                            # Store O
                            for li, lj in T.grid(tile_x, tile_y):
                                with T.block("O_store"):
                                    i, j = T.axis.remap("SS", [li, lj])
                                    cur_L: T.int32 = 0 + (LH_start + i) // group_size
                                    cur_H_qo: T.int32 = (
                                        by * group_size + (LH_start + i) % group_size
                                    )
                                    if cur_L < qo_len:
                                        output[b_idx, cur_L, cur_H_qo, j] = (
                                            O_local[i, j] / d_smem[i]
                                        )

                            # Store LSE
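                            # m and d live in the exp2 domain, so m + log2(d) is the
                            # base-2 log-sum-exp of the (sm_scale-scaled) scores.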
                            for li in T.grid(tile_x):
                                with T.block("lse_store"):
                                    i = T.axis.remap("S", [li])
                                    cur_L: T.int32 = 0 + (LH_start + i) // group_size
                                    cur_H_qo: T.int32 = (
                                        by * group_size + (LH_start + i) % group_size
                                    )
                                    if cur_L < qo_len:
                                        lse[b_idx, cur_L, cur_H_qo] = m_smem[i] + T.log2(
                                            d_smem[i]
                                        )

    # fmt: on
    sch = tvm.tir.Schedule(batch_sequence_prefill_kv_masked)
    sch = _schedule_prefill_kernel(
        sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, False, False
    )
    return sch.mod["main"].with_attr("tir.is_scheduled", True)

def _attention_prefill_ragged_cpu(h_kv, h_q, d_qk, d_v, dtype, rope_scaling: dict[str, Any]):
    group_size = h_q // h_kv