Skip to content

Commit 9d13fc0

Browse files
authored
[S-TIR] Fix Segfault when applying Parallel during TIR schedule rewriting (#19403)
Hi Commiters, This PR is trying to fix issues #18424. Any suggestions would be appreciated if you are available. ### Root Cause Unsafe dynamic-shape dereferences in `AdjustParallelVectorize` The code assumed IntImm for buffer shape / loop extent and dereferenced directly. With dynamic shapes, as<IntImmNode>() can be null, which can segfault before any try/catch handles it. ### Solution Replaced unsafe `IntImm` assumptions with null checks and GetLoopIntExtent(...); if contiguous analysis is not possible, conservatively disables that path instead of dereferencing null. --------- Co-authored-by: cchung100m <cchung100m@users.noreply.github.com>
1 parent b343943 commit 9d13fc0

2 files changed

Lines changed: 43 additions & 4 deletions

File tree

src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv,
213213
// (vectorizable) axes
214214
for (const BufferRegion& access : buffer_access) {
215215
int fusible = 0;
216+
bool can_analyze_contiguous_access = true;
216217
std::vector<int64_t> strides;
217218
// get strides for each loop var
218219
for (const StmtSRef& loop_sref : loop_srefs) {
@@ -226,10 +227,22 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv,
226227
stride = coef * buffer_stride;
227228
break;
228229
}
229-
buffer_stride *= access->buffer->shape[i].as<IntImmNode>()->value;
230+
const auto* shape = access->buffer->shape[i].as<IntImmNode>();
231+
if (shape == nullptr) {
232+
can_analyze_contiguous_access = false;
233+
break;
234+
}
235+
buffer_stride *= shape->value;
236+
}
237+
if (!can_analyze_contiguous_access) {
238+
break;
230239
}
231240
strides.push_back(stride);
232241
}
242+
if (!can_analyze_contiguous_access) {
243+
max_fusible = 0;
244+
break;
245+
}
233246
int prev_used_iter = -1;
234247
// check the number of fusible loops
235248
for (int i = strides.size() - 1; i >= 0; i--) {
@@ -246,9 +259,11 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv,
246259
prev_used_iter = i;
247260
} else {
248261
// contiguous memory access
249-
const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs<ForNode>();
250-
int64_t prev_used_iter_extent = prev_loop->extent.as<IntImmNode>()->value;
251-
if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) {
262+
const int64_t* prev_used_iter_extent = GetLoopIntExtent(loop_srefs[prev_used_iter]);
263+
if (prev_used_iter_extent == nullptr) {
264+
break;
265+
}
266+
if (strides[i] == strides[prev_used_iter] * (*prev_used_iter_extent)) {
252267
fusible++;
253268
prev_used_iter = i;
254269
} else {

tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,24 @@ def after_postproc_add(
181181
add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4]
182182

183183

184+
@T.prim_func
185+
def before_postproc_dynamic_shape_vectorize(
186+
a: T.handle,
187+
b: T.handle,
188+
) -> None:
189+
n = T.int64()
190+
A = T.match_buffer(a, (n,), dtype="float32")
191+
B = T.match_buffer(b, (n,), dtype="float32")
192+
with T.block("root"):
193+
T.block_attr({"meta_schedule.vectorize": 64})
194+
for i in T.serial(0, n):
195+
with T.block("copy"):
196+
vi = T.axis.spatial(n, i)
197+
T.reads(A[vi])
198+
T.writes(B[vi])
199+
B[vi] = A[vi]
200+
201+
184202
# fmt: on
185203
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable
186204

@@ -269,5 +287,11 @@ def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "flo
269287
assert_structural_equal_ignore_global_symbol(mod["main"], expected)
270288

271289

290+
def test_rewrite_parallel_vectorize_unroll_dynamic_shape_no_crash():
291+
sch = Schedule(before_postproc_dynamic_shape_vectorize)
292+
rule = RewriteParallelVectorizeUnroll()
293+
assert rule.apply(sch)
294+
295+
272296
if __name__ == "__main__":
273297
tvm.testing.main()

0 commit comments

Comments
 (0)