Commit b343943
[Relax][Frontend][KVCache] Restructure kv_cache kernels (#19405)
Pure refactor: does not change any generated TIR or kernel behavior.
Dedupe the tiled prefill kernels in kv_cache.py by extracting the shared
online-softmax pieces as T.macro helpers (init_states, compute_s_gemm,
softmax_update_{causal,valid_length}, compute_o_gemm,
advance_tile_batch, paged_store_output_lse), plus Python helpers for the
common buffer allocations (softmax state, MHA/MLA Q/K/V/O, tile-walk
scalars).
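
For readers unfamiliar with the online-softmax pattern these macros factor out, here is a minimal pure-Python sketch of the idea for a single query row. It is not the TIR from this commit: the function names, the list-based data layout, and the tile size are assumptions chosen for illustration, and the comments only loosely map each step to the macro names listed above.

```python
import math


def online_softmax_attention(scores, values, tile=2):
    """Attention output for one query row, visiting keys/values tile by tile.

    scores: list of float logits (one per key), assumed already computed.
    values: list of value vectors (lists of floats), one per key.
    Returns (output vector, log-sum-exp of the scores).
    """
    # Roughly "init_states": running max, running denominator, running output.
    m, d = -math.inf, 0.0
    o = [0.0] * len(values[0])

    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]

        # Roughly "softmax_update": rescale old state to the new running max.
        m_new = max(m, max(s_tile))
        scale = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first tile
        d *= scale
        o = [x * scale for x in o]

        # Roughly "compute_o_gemm": accumulate unnormalized weighted values.
        for s, v in zip(s_tile, v_tile):
            p = math.exp(s - m_new)
            d += p
            o = [acc + p * vj for acc, vj in zip(o, v)]
        m = m_new

    # Final normalization; the LSE is what a later state merge would consume.
    return [x / d for x in o], m + math.log(d)


def reference_attention(scores, values):
    """Single-pass softmax attention used only to check the tiled version."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    d = sum(p)
    out = [sum(pi * v[j] for pi, v in zip(p, values)) / d
           for j in range(len(values[0]))]
    return out, m + math.log(d)
```

The tiled and single-pass versions agree up to floating-point rounding for any tile size, which is why the per-tile update can be shared across the prefill kernel variants.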
Split the kernel factories out of kv_cache.py into private sibling
modules: _kernel_common.py (shared helpers + macros + schedule),
_page_kernels.py (append/debug/copy/compact), _prefill_kernels.py
(paged/ragged/MLA/dense/masked-sequence), _decode_kernels.py (decode +
state merge). kv_cache.py now holds only the PagedKVCache classes and
re-exports every moved symbol so existing imports keep working.
tree_attn.py also switches to the shared helpers.
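
The re-export arrangement described above can be sketched in a self-contained way as follows. Everything here is hypothetical: `demo_pkg`, `make_decode_kernel`, and the in-memory modules stand in for the real package and its symbols, which this commit message does not enumerate.

```python
import sys
import types


def _make_decode_kernel():
    # Hypothetical stand-in for a kernel factory that moved to a private module.
    return "decode-kernel"


# Private sibling module that now owns the implementation.
private = types.ModuleType("demo_pkg._decode_kernels")
private.make_decode_kernel = _make_decode_kernel

# Public module keeps the old name by re-exporting the moved symbol.
public = types.ModuleType("demo_pkg.kv_cache")
public.make_decode_kernel = private.make_decode_kernel
public.__all__ = ["make_decode_kernel"]

# Register an in-memory package so the import statement below resolves.
pkg = types.ModuleType("demo_pkg")
pkg.__path__ = []
sys.modules["demo_pkg"] = pkg
sys.modules["demo_pkg._decode_kernels"] = private
sys.modules["demo_pkg.kv_cache"] = public

# Existing caller code continues to import from the old location.
from demo_pkg.kv_cache import make_decode_kernel
```

In a real package the public module would simply contain a line such as `from ._decode_kernels import make_decode_kernel`; the in-memory modules above exist only so the sketch runs without files on disk.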
kv_cache.py drops from 2815 to 668 lines; the package is ~2.4k lines
smaller overall. No test files modified; GPU tests pass unchanged (72
passed, 4 pre-existing skips).
1 parent 1b94055 commit b343943
6 files changed
Lines changed: 2464 additions & 2958 deletions
File tree
- python/tvm/relax/frontend/nn/llm