Commit b343943

[Relax][Frontend][KVCache] Restructure kv_cache kernels (#19405)
Pure refactor: does not change any generated TIR / kernel behavior.

Dedupe the tiled prefill kernels in kv_cache.py by extracting the shared online-softmax pieces as T.macro helpers (init_states, compute_s_gemm, softmax_update_{causal,valid_length}, compute_o_gemm, advance_tile_batch, paged_store_output_lse), plus Python helpers for the common buffer allocations (softmax state, MHA/MLA Q/K/V/O, tile-walk scalars).

Split the kernel factories out of kv_cache.py into private sibling modules:

- _kernel_common.py: shared helpers + macros + schedule
- _page_kernels.py: append/debug/copy/compact
- _prefill_kernels.py: paged/ragged/MLA/dense/masked-sequence
- _decode_kernels.py: decode + state merge

kv_cache.py now holds only the PagedKVCache classes and re-exports every moved symbol so existing imports keep working. tree_attn.py also switches to the shared helpers.

kv_cache.py drops from 2815 to 668 lines; the package is ~2.4k lines smaller overall. No test files modified; GPU tests pass unchanged (72 passed, 4 pre-existing skips).
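For orientation, the online-softmax recurrence that the extracted macros share can be sketched in plain Python. This is only an illustration, not the TVM code: the function names echo the macro names from the message, but the signatures, the single-query-row simplification, the absence of masking, and the natural-log LSE are all assumptions made for the sketch.

```python
import math


def init_states(d_v):
    # Running max, running denominator, and unnormalized output accumulator.
    return -math.inf, 0.0, [0.0] * d_v


def compute_s_gemm(q, key_tile, sm_scale):
    # Scaled dot products between one query row and every key in the tile.
    return [sm_scale * sum(a * b for a, b in zip(q, k)) for k in key_tile]


def softmax_update(scores, m, d, o):
    # Rescale the previous state to the new running max, then add this tile.
    m_new = max(m, max(scores))
    scale = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first tile
    p = [math.exp(s - m_new) for s in scores]
    d_new = d * scale + sum(p)
    o_new = [x * scale for x in o]
    return p, m_new, d_new, o_new


def compute_o_gemm(p, value_tile, o):
    # Accumulate the probability-weighted values of this tile into o.
    return [o[j] + sum(p[i] * value_tile[i][j] for i in range(len(p)))
            for j in range(len(o))]


def tiled_attention_row(q, keys, values, tile=4):
    # Walk the keys/values tile by tile, never materializing all scores.
    m, d, o = init_states(len(values[0]))
    sm_scale = 1.0 / math.sqrt(len(q))
    for start in range(0, len(keys), tile):
        s = compute_s_gemm(q, keys[start:start + tile], sm_scale)
        p, m, d, o = softmax_update(s, m, d, o)
        o = compute_o_gemm(p, values[start:start + tile], o)
    lse = m + math.log(d)  # log-sum-exp of all scores (natural log assumed)
    return [x / d for x in o], lse


def reference_attention_row(q, keys, values):
    # Untiled softmax attention for the same row, used only for comparison.
    sm_scale = 1.0 / math.sqrt(len(q))
    s = [sm_scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(s)
    p = [math.exp(x - top) for x in s]
    total = sum(p)
    return [sum(p[i] * values[i][j] for i in range(len(p))) / total
            for j in range(len(values[0]))]
```

Under these assumptions, the tiled walk produces the same row as the untiled reference for any tile size, which is the property that lets the per-tile steps be factored into shared helpers without changing results.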
1 parent 1b94055 commit b343943

6 files changed

Lines changed: 2464 additions & 2958 deletions

File tree

python/tvm/relax/frontend/nn/llm/_decode_kernels.py

Lines changed: 526 additions & 0 deletions