Skip to content

Commit f6a16e1

Browse files
authored
Allow compact block sparse index tensors (#2417)
* Allow compact block sparse index tensors

Relax validation in block_sparsity.py to allow idx.shape[3] <= expected_n_blocks instead of requiring exact equality. FA4 only accesses indices 0..cnt-1 per query tile, so the index tensor's last dimension does not need to be as large as ceil(seqlen_k / block_size_n). This enables memory-efficient compact index tensors that avoid O(N^2) memory at long sequence lengths (e.g., 1M+ tokens for sparse attention / NSA workloads).

Changes:
- _check_and_expand_block: accept a compact n-block dimension and expand only the batch/head/m-block dimensions
- infer_block_sparse_expected_shapes: change the strict equality check to an upper-bound check (error only when n-blocks exceeds expected, not when smaller)

Backward compatible: existing code that passes full-sized tensors is unaffected.

* Add test for compact block sparse index tensors

Verify that truncating block sparse index tensors to idx.shape[3] = max(cnt) (instead of the full ceil(seqlen_k / block_size_n)) produces bit-identical output to full-sized tensors. This validates the relaxed validation from the previous commit.
1 parent 29e40cf commit f6a16e1

File tree

2 files changed

+96
-2
lines changed

2 files changed

+96
-2
lines changed

flash_attn/cute/block_sparsity.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ def _check_and_expand_block(
9898
expanded_cnt = _expand_sparsity_tensor(
9999
cnt, expected_count_shape, f"{name}_block_cnt", context, hint
100100
)
101+
# [Note] Allow Compact block sparse indices
102+
# Allow the last dimension (n_blocks) of idx to be <= expected, since
103+
# FA4 only accesses indices 0..cnt-1 per query tile. This enables compact
104+
# index tensors that avoid O(N^2) memory at long sequence lengths.
105+
if idx.ndim == 4 and idx.shape[3] <= expected_index_shape[3]:
106+
expected_index_shape = (*expected_index_shape[:3], idx.shape[3])
101107
expanded_idx = _expand_sparsity_tensor(
102108
idx, expected_index_shape, f"{name}_block_idx", context, hint
103109
)
@@ -200,9 +206,11 @@ def infer_block_sparse_expected_shapes(
200206
raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
201207
if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:
202208
raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.")
203-
if mask_block_idx.shape[3] != expected_n_blocks:
209+
# [Note] Allow Compact block sparse indices: FA4 only accesses indices 0..cnt-1
210+
# per query tile, so idx.shape[3] can be <= expected_n_blocks.
211+
if mask_block_idx.shape[3] > expected_n_blocks:
204212
raise ValueError(
205-
f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}."
213+
f"Block sparse tensors{context} n-block dimension must be <= {expected_n_blocks}."
206214
)
207215
if expected_m_blocks != num_m_blocks:
208216
raise ValueError(

tests/cute/test_mask_mod.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,5 +1712,91 @@ def test_persistent_blocksparse_empty_tiles():
17121712

17131713

17141714

1715+
def test_compact_block_sparse_indices():
1716+
"""Test that compact block sparse index tensors (idx.shape[3] < n_blocks) work correctly.
1717+
1718+
FA4 only accesses indices 0..cnt-1 per query tile, so the index tensor's last
1719+
dimension does not need to be as large as ceil(seqlen_k / block_size_n). This
1720+
test verifies that truncated (compact) index tensors produce identical output
1721+
to full-sized ones.
1722+
"""
1723+
torch.manual_seed(42)
1724+
batch_size = 1
1725+
nheads = 4
1726+
seqlen_q = 1024
1727+
seqlen_k = 1024
1728+
headdim = 128
1729+
tile_m = 128
1730+
tile_n = 128
1731+
dtype = torch.bfloat16
1732+
1733+
sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m
1734+
1735+
mask_mod_cute, mask_mod_flex = get_mask_pair(
1736+
"block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=None
1737+
)
1738+
tensors = create_tensors(
1739+
batch_size, seqlen_q, seqlen_k, nheads, nheads, headdim, headdim, dtype
1740+
)
1741+
1742+
bm = create_block_mask(
1743+
mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k,
1744+
device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n),
1745+
)
1746+
(_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple()
1747+
1748+
# Determine the max count across all query tiles — this is the compact last dim
1749+
max_mask_k = kv_mask_cnt.max().item() if kv_mask_cnt is not None else 0
1750+
max_full_k = full_kv_cnt.max().item() if full_kv_cnt is not None else 0
1751+
max_k = max(max_mask_k, max_full_k, 1)
1752+
1753+
# Truncate index tensors to compact size
1754+
kv_mask_idx_compact = kv_mask_idx[:, :, :, :max_k].contiguous()
1755+
full_kv_idx_compact = full_kv_idx[:, :, :, :max_k].contiguous() if full_kv_idx is not None else None
1756+
1757+
block_sparse_compact = BlockSparseTensorsTorch(
1758+
mask_block_cnt=kv_mask_cnt,
1759+
mask_block_idx=kv_mask_idx_compact,
1760+
full_block_cnt=full_kv_cnt,
1761+
full_block_idx=full_kv_idx_compact,
1762+
block_size=(sparse_tile_m, tile_n),
1763+
)
1764+
1765+
out_compact, _ = _flash_attn_fwd(
1766+
q=tensors["q"], k=tensors["k"], v=tensors["v"],
1767+
out=tensors["out"].clone(), lse=tensors["lse"].clone(),
1768+
softmax_scale=1.0 / math.sqrt(headdim),
1769+
causal=False, mask_mod=mask_mod_cute,
1770+
block_sparse_tensors=block_sparse_compact,
1771+
return_lse=True,
1772+
)
1773+
1774+
# Reference: use full-sized index tensors
1775+
block_sparse_full = BlockSparseTensorsTorch(
1776+
mask_block_cnt=kv_mask_cnt,
1777+
mask_block_idx=kv_mask_idx,
1778+
full_block_cnt=full_kv_cnt,
1779+
full_block_idx=full_kv_idx,
1780+
block_size=(sparse_tile_m, tile_n),
1781+
)
1782+
1783+
out_full, _ = _flash_attn_fwd(
1784+
q=tensors["q"], k=tensors["k"], v=tensors["v"],
1785+
out=tensors["out"].clone(), lse=tensors["lse"].clone(),
1786+
softmax_scale=1.0 / math.sqrt(headdim),
1787+
causal=False, mask_mod=mask_mod_cute,
1788+
block_sparse_tensors=block_sparse_full,
1789+
return_lse=True,
1790+
)
1791+
1792+
assert not torch.isnan(out_compact).any(), "Compact output has NaN"
1793+
assert torch.isfinite(out_compact).all(), "Compact output has Inf"
1794+
# Compact and full should produce bit-identical results
1795+
assert torch.equal(out_compact, out_full), (
1796+
f"Compact and full block sparse outputs differ: "
1797+
f"max diff = {(out_compact - out_full).abs().max().item():.2e}"
1798+
)
1799+
1800+
17151801
if __name__ == "__main__":
17161802
pytest.main([__file__, "-v", "-s"])

0 commit comments

Comments
 (0)