Skip to content

Commit 4533d31

Browse files
authored
[Relax] Add FDataDependent operator attribute for LegalizeOps (#18664)
## Why The LegalizeOps transform was using string matching to detect data-dependent operators by checking if "dynamic" appears in the operator name. This approach is fragile and doesn't scale well as new data-dependent operators are added. ## How - Add FDataDependent operator attribute to properly mark data-dependent operators - Set FDataDependent=true for relax.dynamic_strided_slice operator - Update LegalizeOps transform to check the FDataDependent attribute instead of string matching
1 parent 15ac9db commit 4533d31

3 files changed

Lines changed: 24 additions & 10 deletions

File tree

src/relax/op/tensor/index.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,8 @@ TVM_REGISTER_OP("relax.dynamic_strided_slice")
574574
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDynStridedSlice)
575575
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutDynStridedSlice)
576576
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
577-
.set_attr<Bool>("FPurity", Bool(true));
577+
.set_attr<Bool>("FPurity", Bool(true))
578+
.set_attr<Bool>("FDataDependent", Bool(true));
578579

579580
} // namespace relax
580581
} // namespace tvm

src/relax/transform/legalize_ops.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,15 @@ class LegalizeMutator : public ExprMutator {
287287
return false;
288288
}
289289

290-
std::string op_name(op->name);
291-
bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos);
290+
bool is_data_dependent_op = [&]() -> bool {
291+
if (Op::HasAttrMap("FDataDependent")) {
292+
auto op_map = Op::GetAttrMap<Bool>("FDataDependent");
293+
if (op_map.count(op)) {
294+
return op_map[op]->value;
295+
}
296+
}
297+
return false;
298+
}();
292299
bool ret_shape_defined = KnowAllShapeValues(GetStructInfo(visited_call));
293300
if (!is_data_dependent_op && !ret_shape_defined) {
294301
// This operator cannot be legalized, because legalization by
@@ -303,10 +310,6 @@ class LegalizeMutator : public ExprMutator {
303310
// data-dependent op, and match cast to define symbolic output
304311
// shapes. These symbolic output shapes at compile time can
305312
// be by later operations to refer to the runtime shape.
306-
//
307-
// TODO(Lunderberg): Make a new operator attribute
308-
// `.set_attr<Bool>("DataDependent")`, rather than relying on
309-
// the name of the operator.
310313
return false;
311314
}
312315

tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
# under the License.
1717

1818
import tvm
19-
from tvm.relax.transform import LegalizeOps
20-
from tvm.script import relax as R, tir as T, ir as I
2119
import tvm.testing
22-
20+
from tvm.ir import Op
21+
from tvm.relax.transform import LegalizeOps
22+
from tvm.script import ir as I
23+
from tvm.script import relax as R
24+
from tvm.script import tir as T
2325

2426
##################### Indexing #####################
2527

@@ -1197,5 +1199,13 @@ def einsum(
11971199
tvm.ir.assert_structural_equal(mod, Expected)
11981200

11991201

1202+
def test_data_dependent_attribute():
1203+
dynamic_strided_slice_op = Op.get("relax.dynamic_strided_slice")
1204+
assert dynamic_strided_slice_op.get_attr("FDataDependent")
1205+
1206+
strided_slice_op = Op.get("relax.strided_slice")
1207+
assert strided_slice_op.get_attr("FDataDependent") is None
1208+
1209+
12001210
if __name__ == "__main__":
12011211
tvm.testing.main()

0 commit comments

Comments
 (0)