Skip to content

Commit 0a79095

Browse files
authored
[S-TIR] Fix cache_read/cache_write region when inner block has T.whe… (#19406)
…re predicate When the actual buffer access is gated by T.where on a nested (inner) sblock, the outer block's own predicate is trivially true. Both cache_write and cache_read were computing cache regions based only on that outer predicate, producing allocations as large as the full loop extent instead of the guarded region Fix: - Add CollectNestedBlockPredicates(), a single helper parameterised by BufferIndexType (kRead / kWrite) that walks the outer block's body, finds nested sblocks accessing the target buffer, and AND-combines their predicates after substituting iter-var bindings into the outer scope. - Add extra_predicate parameter to RelaxBufferRegion() and AND it with the block's own predicate before region relaxation. - cache_write: pass the collected nested-write predicate to RelaxBufferRegion so the cache allocation is tightened. - cache_read (Case 2 — input buffer): when a non-trivial nested-read predicate exists, relax the consumer block's declared read region under that predicate; otherwise fall back to the original scope_block->reads path (preserves non-int32 dtypes in extents).
1 parent 9d13fc0 commit 0a79095

2 files changed

Lines changed: 273 additions & 8 deletions

File tree

src/s_tir/schedule/primitive/cache_read_write.cc

Lines changed: 103 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -542,26 +542,94 @@ bool AllConsumersUnderStmt(ScheduleState self, Buffer buffer, StmtSRef scope_sre
542542
return true;
543543
}
544544

545+
/*!
546+
* \brief Collect OR-combined predicates from all nested BlockRealize nodes within
547+
* the given statement that access the specified buffer (read or write, controlled by
548+
* \p index_type). Each nested block's predicate is expressed in the enclosing block's
549+
* scope by substituting the nested block's iter var bindings. This is needed when the
550+
* actual access is gated by a predicate (T.where) on a nested block while the outer
551+
* block has a trivially-true predicate. Sibling blocks that each access the buffer under
552+
* different predicates are OR-ed together so the result covers the union of their access
553+
* regions.
554+
* \param body The body statement of the outer block to search within.
555+
* \param buffer The buffer being accessed.
556+
* \param index_type Whether to look for reads (kRead) or writes (kWrite).
557+
* \return The OR-combination of all nested block predicates found.
558+
*/
559+
static PrimExpr CollectNestedBlockPredicates(const Stmt& body, const Buffer& buffer,
560+
BufferIndexType index_type) {
561+
struct Collector : public StmtVisitor {
562+
Collector(const Buffer& buf, BufferIndexType idx_type)
563+
: buffer_(buf), index_type_(idx_type), result_(Bool(false)), found_(false) {}
564+
565+
void VisitStmt_(const SBlockRealizeNode* realize) final {
566+
const SBlockNode* block = realize->block.get();
567+
const auto& regions =
568+
(index_type_ == BufferIndexType::kRead) ? block->reads : block->writes;
569+
bool accesses_buffer = false;
570+
for (const BufferRegion& region : regions) {
571+
if (region->buffer.same_as(buffer_)) {
572+
accesses_buffer = true;
573+
break;
574+
}
575+
}
576+
if (accesses_buffer) {
577+
// Build substitution: nested block iter vars -> their binding values
578+
// (which are already expressed in terms of the outer scope).
579+
ffi::Map<Var, PrimExpr> subst;
580+
for (size_t i = 0; i < block->iter_vars.size(); ++i) {
581+
subst.Set(block->iter_vars[i]->var, realize->iter_values[i]);
582+
}
583+
PrimExpr pred =
584+
subst.empty() ? realize->predicate : Substitute(realize->predicate, subst);
585+
// OR the predicates across all accessing nested blocks: each such block is an
586+
// independent alternative access path (sibling blocks in a SeqStmt), so the
587+
// cache must cover the *union* of their access regions, not the intersection.
588+
// Using AND (the previous behaviour) underestimates the required region when
589+
// sibling blocks have non-overlapping predicates.
590+
result_ = found_ ? (result_ || pred) : pred;
591+
found_ = true;
592+
}
593+
// Continue recursing into deeper nested blocks.
594+
StmtVisitor::VisitStmt_(realize);
595+
}
596+
597+
const Buffer& buffer_;
598+
BufferIndexType index_type_;
599+
PrimExpr result_;
600+
bool found_;
601+
};
602+
603+
Collector collector(buffer, index_type);
604+
collector(body);
605+
// If no nested block accessed the buffer, return true (no restriction — the caller
606+
// will fall back to the original scope-block reads / FullRegion path).
607+
return collector.found_ ? collector.result_ : Bool(true);
608+
}
609+
545610
/*!
546611
* \brief Get the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
547612
* \param self The state of the schedule.
548613
* \param buffer_region The buffer region to be analyzed.
549614
* \param block_sref The sref of the block related to the region.
550615
* \param dom_low_inclusive The lowest node in the sref tree path.
551616
* \param dom_high_exclusive The highest node in the sref tree path.
617+
* \param extra_predicate An additional predicate (e.g. collected from nested blocks) to AND
618+
* with the block's own predicate before relaxation. Defaults to true (no effect).
552619
* \return The relaxed buffer region.
553620
*/
554621
BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_region,
555622
const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive,
556-
const StmtSRef& dom_high_exclusive) {
623+
const StmtSRef& dom_high_exclusive,
624+
PrimExpr extra_predicate = Bool(true)) {
557625
SBlockRealize realize = GetSBlockRealize(self, block_sref);
558626
ffi::Map<Var, PrimExpr> binding = GetBindings(realize);
559627
const Buffer& buffer = buffer_region->buffer;
560628
arith::Analyzer analyzer;
561629
BufferRegion subst_region = BufferRegion(buffer, Substitute(buffer_region->region, binding));
562630
ffi::Array<arith::IntSet> int_sets = AnalyzeRegionUpperBound(
563631
/*region=*/subst_region,
564-
/*predicate=*/realize->predicate,
632+
/*predicate=*/Substitute(realize->predicate && extra_predicate, binding),
565633
/*dom_low_inclusive=*/dom_low_inclusive,
566634
/*dom_high_exclusive=*/dom_high_exclusive,
567635
/*analyzer=*/&analyzer);
@@ -1703,9 +1771,25 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
17031771
// Case 2. The buffer is the input block for the scope.
17041772
info.loc_sref = scope_sref;
17051773
info.loc_pos = 0;
1706-
if (ffi::Optional<BufferRegion> region =
1707-
GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) {
1708-
cache_region = region.value();
1774+
// When a nested block gates the actual read with T.where, the consumer block's own
1775+
// predicate is trivially true, so the scope-block read annotation covers the full loop
1776+
// range. Collect nested-read predicates and, if any are non-trivial, relax the consumer
1777+
// block's read region under that predicate to get a tighter cache allocation.
1778+
// Without a nested predicate we fall back to scope_block->reads (which preserves the
1779+
// original buffer's dtype in its extents, e.g. int64 shapes).
1780+
ffi::Optional<BufferRegion> read_region_opt =
1781+
GetBufferRegionFromBuffer(block->reads, read_buffer);
1782+
PrimExpr nested_pred =
1783+
read_region_opt
1784+
? CollectNestedBlockPredicates(block->body, read_buffer, BufferIndexType::kRead)
1785+
: Bool(true);
1786+
if (read_region_opt && !is_one(nested_pred) && block_sref->parent != nullptr) {
1787+
StmtSRef parent_sref = ffi::GetRef<StmtSRef>(block_sref->parent);
1788+
cache_region = RelaxBufferRegion(self, read_region_opt.value(), block_sref, parent_sref,
1789+
scope_sref, nested_pred);
1790+
} else if (ffi::Optional<BufferRegion> scope_region =
1791+
GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) {
1792+
cache_region = scope_region.value();
17091793
} else {
17101794
cache_region = BufferRegion::FullRegion(read_buffer);
17111795
}
@@ -1782,11 +1866,22 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
17821866

17831867
// Step 4. Find the producing region and insert position
17841868
BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value();
1785-
StmtSRef parent_sref = ffi::GetRef<StmtSRef>(block_sref->parent);
17861869
// Detect insert position
17871870
CacheLocDetector::Detect</*is_cache_read=*/false>(self, block_sref, scope_sref, &info);
1788-
BufferRegion cache_region =
1789-
RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref);
1871+
// Collect predicates from any nested blocks that gate the actual write (e.g. T.where on an
// inner block). The outer block's own predicate may be trivially true even though the write
// is restricted by a nested predicate; the collected predicate (siblings OR-ed together) is
// AND-ed with the block's own predicate inside RelaxBufferRegion for a tighter region estimate.
1874+
PrimExpr nested_write_pred =
1875+
CollectNestedBlockPredicates(block->body, write_buffer, BufferIndexType::kWrite);
1876+
BufferRegion cache_region;
1877+
if (block_sref->parent != nullptr) {
1878+
StmtSRef parent_sref = ffi::GetRef<StmtSRef>(block_sref->parent);
1879+
cache_region =
1880+
RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref, nested_write_pred);
1881+
} else {
1882+
// Root block: no enclosing loops to relax over, use the write region directly.
1883+
cache_region = region;
1884+
}
17901885

17911886
bool cache_full_region = info.loc_sref->StmtAs<SBlockNode>() == nullptr ||
17921887
!AllConsumersUnderStmt(self, write_buffer, scope_sref, info.loc_sref);

tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1670,5 +1670,175 @@ def test_symbolic_matmul_blocked_cache_write(use_block_name):
16701670
verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked)
16711671

16721672

1673+
def test_cache_write_with_nested_block_predicate():
    """cache_write must size the cache by the nested T.where region, not the full loop."""

    @T.prim_func
    def main(A: T.handle, C: T.handle) -> None:
        A_buf = T.match_buffer(A, (12, 24), "float32")
        C_buf = T.match_buffer(C, (10, 20), "float32")

        for i, j in T.grid(12, 24):
            with T.sblock("compute"):
                vi, vj = T.axis.remap("SS", [i, j])

                with T.sblock("inner"):
                    T.where(vi < 10 and vj < 20)
                    C_buf[vi, vj] = A_buf[vi, vj] * 2.0

    @T.prim_func
    def expected(A_buf: T.Buffer((12, 24), "float32"), C_buf: T.Buffer((10, 20), "float32")):
        with T.sblock("root"):
            C_buf_local = T.sblock_alloc_buffer((10, 20), scope="local")
            for i, j in T.grid(12, 24):
                with T.sblock("compute"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.reads(A_buf[vi, vj])
                    T.writes(C_buf_local[vi, vj])
                    with T.sblock("inner"):
                        T.where(vi < 10 and vj < 20)
                        T.reads(A_buf[vi, vj])
                        T.writes(C_buf_local[vi, vj])
                        C_buf_local[vi, vj] = A_buf[vi, vj] * T.float32(2)
            for ax0, ax1 in T.grid(10, 20):
                with T.sblock("C_buf_local"):
                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(C_buf_local[v0, v1])
                    T.writes(C_buf[v0, v1])
                    C_buf[v0, v1] = C_buf_local[v0, v1]

    # Apply the primitive and compare against the hand-written expected IR.
    schedule = tvm.s_tir.Schedule(main)
    producer = schedule.get_sblock("compute")
    schedule.cache_write(producer, 0, "local")
    assert_structural_equal_ignore_global_symbol(expected, schedule.mod["main"])
1712+
1713+
1714+
def test_cache_read_with_nested_block_predicate():
    """cache_read must size the cache by the nested T.where region, not the full loop."""

    @T.prim_func
    def main(A: T.handle, C: T.handle) -> None:
        A_buf = T.match_buffer(A, (12, 24), "float32")
        C_buf = T.match_buffer(C, (10, 20), "float32")

        for i, j in T.grid(12, 24):
            with T.sblock("compute"):
                vi, vj = T.axis.remap("SS", [i, j])

                with T.sblock("inner"):
                    T.where(vi < 10 and vj < 20)
                    C_buf[vi, vj] = A_buf[vi, vj] * 2.0

    @T.prim_func
    def expected(A_buf: T.Buffer((12, 24), "float32"), C_buf: T.Buffer((10, 20), "float32")):
        with T.sblock("root"):
            A_buf_local = T.sblock_alloc_buffer((10, 20), scope="local")
            for ax0, ax1 in T.grid(10, 20):
                with T.sblock("A_buf_local"):
                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A_buf[v0, v1])
                    T.writes(A_buf_local[v0, v1])
                    A_buf_local[v0, v1] = A_buf[v0, v1]
            for i, j in T.grid(12, 24):
                with T.sblock("compute"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.reads(A_buf_local[vi, vj])
                    T.writes(C_buf[vi, vj])
                    with T.sblock("inner"):
                        T.where(vi < 10 and vj < 20)
                        T.reads(A_buf_local[vi, vj])
                        T.writes(C_buf[vi, vj])
                        C_buf[vi, vj] = A_buf_local[vi, vj] * T.float32(2)

    # Apply the primitive and compare against the hand-written expected IR.
    schedule = tvm.s_tir.Schedule(main)
    consumer = schedule.get_sblock("compute")
    schedule.cache_read(consumer, 0, "local")
    assert_structural_equal_ignore_global_symbol(expected, schedule.mod["main"])
1753+
1754+
1755+
def test_cache_write_sibling_nested_block_predicates_use_union():
    """Regression: cache_write must union sibling nested-block predicates.

    Two sibling nested sblocks write the same buffer under *different* predicates:
        "left" block: T.where(vi < 8)  -- rows 0-7, every column
        "top"  block: T.where(vj < 16) -- every row, columns 0-15

    The cache has to cover the UNION of the two write sets, whose bounding box is
    the full buffer shape (12, 24). AND-ing the predicates instead would relax
    under the intersection (vi < 8) and (vj < 16), shrinking the cache to (8, 16);
    the "left" block would then write C_buf_local rows [8, 12) that were never
    allocated, producing incorrect output.
    """

    @T.prim_func
    def main(A: T.handle, C: T.handle) -> None:
        A_buf = T.match_buffer(A, (12, 24), "float32")
        C_buf = T.match_buffer(C, (12, 24), "float32")
        for i, j in T.grid(12, 24):
            with T.sblock("compute"):
                vi, vj = T.axis.remap("SS", [i, j])
                with T.sblock("left"):
                    T.where(vi < 8)
                    C_buf[vi, vj] = A_buf[vi, vj] * 2.0
                with T.sblock("top"):
                    T.where(vj < 16)
                    C_buf[vi, vj] = A_buf[vi, vj] * 3.0

    schedule = tvm.s_tir.Schedule(main)
    producer = schedule.get_sblock("compute")
    schedule.cache_write(producer, 0, "local")

    # Inspect the resulting IR: the alloc buffer must span the union bounding
    # box (12, 24); the buggy AND behaviour would give (8, 16).
    result_script = schedule.mod["main"].script()
    assert "sblock_alloc_buffer((12, 24)" in result_script, (
        f"Expected cache shape (12, 24) covering the union of both write regions, "
        f"but got a smaller shape. Full IR:\n{result_script}"
    )
1798+
1799+
1800+
def test_cache_read_sibling_nested_block_predicates_use_union():
    """Regression: cache_read must union sibling nested-block predicates.

    Two sibling nested sblocks read the same input buffer under different
    predicates:
        "left" block: T.where(vi < 8)  -- rows 0-7, every column
        "top"  block: T.where(vj < 16) -- every row, columns 0-15

    The cache has to cover the UNION of the two read sets, whose bounding box is
    the full buffer shape (12, 24). AND-ing the predicates instead would relax
    under the intersection (vi < 8) and (vj < 16), shrinking the cache to (8, 16);
    the "left" block would then read A_buf_local rows [8, 12) that lie outside
    the cache, which is incorrect.
    """

    @T.prim_func
    def main(A: T.handle, C: T.handle) -> None:
        A_buf = T.match_buffer(A, (12, 24), "float32")
        C_buf = T.match_buffer(C, (12, 24), "float32")
        for i, j in T.grid(12, 24):
            with T.sblock("compute"):
                vi, vj = T.axis.remap("SS", [i, j])
                with T.sblock("left"):
                    T.where(vi < 8)
                    C_buf[vi, vj] = A_buf[vi, vj] * 2.0
                with T.sblock("top"):
                    T.where(vj < 16)
                    C_buf[vi, vj] = A_buf[vi, vj] * 3.0

    schedule = tvm.s_tir.Schedule(main)
    consumer = schedule.get_sblock("compute")
    schedule.cache_read(consumer, 0, "local")

    # Inspect the resulting IR: the alloc buffer must span the union bounding
    # box (12, 24); the buggy AND behaviour would give (8, 16).
    result_script = schedule.mod["main"].script()
    assert "sblock_alloc_buffer((12, 24)" in result_script, (
        f"Expected cache shape (12, 24) covering the union of both read regions, "
        f"but got a smaller shape. Full IR:\n{result_script}"
    )
1841+
1842+
16731843
if __name__ == "__main__":
16741844
tvm.testing.main()

0 commit comments

Comments
 (0)