Skip to content

Commit 75a6b30

Browse files
authored
[Relax][Frontend][TFLite] Fix and test MATRIX_DIAG, MATRIX_SET_DIAG, SPARSE_TO_DENSE (#19408)
This PR partially implements test coverage requested in issue #18971 for Relax TFLite frontend operator tests. ## Bug Fix The TFLite frontend converters for `MATRIX_DIAG`, `MATRIX_SET_DIAG`, and `SPARSE_TO_DENSE` were broken due to calling non-existent Relax ops: - `relax.op.matrix_set_diag` - never registered in Relax - `relax.op.sparse_to_dense` - never registered in Relax These ops only exist as TOPI packed functions (`topi.matrix_set_diag`, `topi.sparse_to_dense`). **Fix:** Replace direct op calls with `call_dps_packed` to invoke the TOPI packed functions: - `convert_matrix_diag`: zeros + call_dps_packed("topi.matrix_set_diag", ...) - `convert_matrix_set_diag`: call_dps_packed("topi.matrix_set_diag", ...) - `convert_sparse_to_dense`: call_dps_packed("topi.sparse_to_dense", ...) Refs: #18971
1 parent 5c17111 commit 75a6b30

File tree

2 files changed

+153
-12
lines changed

2 files changed

+153
-12
lines changed

python/tvm/relax/frontend/tflite/tflite_frontend.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3015,11 +3015,20 @@ def convert_sparse_to_dense(self, op):
30153015
t_type = t.tensor.Type()
30163016
assert t_type in (TensorType.INT32, TensorType.INT64)
30173017

3018-
out = relax.op.sparse_to_dense(
3019-
self.get_tensor_expr(indices),
3020-
list(self.get_tensor_value(output_shape)),
3021-
self.get_tensor_expr(values),
3022-
self.get_tensor_expr(default_value),
3018+
output_tensors = self.get_output_tensors(op)
3019+
output_tensor = output_tensors[0]
3020+
output_shape_val = to_int_list(self.get_tensor_shape(output_tensor))
3021+
output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
3022+
3023+
indices_expr = self.get_tensor_expr(indices)
3024+
values_expr = self.get_tensor_expr(values)
3025+
default_value_expr = self.get_tensor_expr(default_value)
3026+
output_shape_expr = relax.const(list(self.get_tensor_value(output_shape)), "int32")
3027+
3028+
out = relax.op.call_dps_packed(
3029+
"topi.sparse_to_dense",
3030+
(indices_expr, output_shape_expr, values_expr, default_value_expr),
3031+
out_sinfo=relax.TensorStructInfo(output_shape_val, output_dtype),
30233032
)
30243033

30253034
return out
@@ -3700,7 +3709,18 @@ def convert_matrix_set_diag(self, op):
37003709
input_expr = self.get_tensor_expr(input_tensors[0])
37013710
diagonal_expr = self.get_tensor_expr(input_tensors[1])
37023711

3703-
out = relax.op.matrix_set_diag(input_expr, diagonal_expr)
3712+
output_tensors = self.get_output_tensors(op)
3713+
output_tensor = output_tensors[0]
3714+
output_shape = to_int_list(self.get_tensor_shape(output_tensor))
3715+
output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
3716+
3717+
# topi.matrix_set_diag(input, diagonal, k1, k2, super_diag_right_align, sub_diag_right_align)
3718+
# TFLite MATRIX_SET_DIAG only sets the main diagonal, so k1=0, k2=0
3719+
out = relax.op.call_dps_packed(
3720+
"topi.matrix_set_diag",
3721+
(input_expr, diagonal_expr, relax.const(0), relax.const(0), relax.const(False), relax.const(False)),
3722+
out_sinfo=relax.TensorStructInfo(output_shape, output_dtype),
3723+
)
37043724
return out
37053725

37063726
def convert_matrix_diag(self, op):
@@ -3718,14 +3738,21 @@ def convert_matrix_diag(self, op):
37183738
scale and zero points to be equal"
37193739
)
37203740

3721-
shape = to_int_list(self.get_tensor_shape(diagonal))
3722-
shape = np.append(shape, shape[-1])
3723-
dtype = self.get_tensor_type_str(diagonal.tensor.Type())
3741+
output_tensors = self.get_output_tensors(op)
3742+
output_tensor = output_tensors[0]
3743+
output_shape = to_int_list(self.get_tensor_shape(output_tensor))
3744+
output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
37243745

3725-
input_expr = relax.op.zeros(tuple(shape), dtype)
37263746
diagonal_expr = self.get_tensor_expr(diagonal)
3727-
3728-
out = relax.op.matrix_set_diag(input_expr, diagonal_expr)
3747+
zeros_expr = relax.op.zeros(output_shape, output_dtype)
3748+
3749+
# topi.matrix_set_diag(input, diagonal, k1, k2, super_diag_right_align, sub_diag_right_align)
3750+
# TFLite MATRIX_DIAG only sets the main diagonal, so k1=0, k2=0
3751+
out = relax.op.call_dps_packed(
3752+
"topi.matrix_set_diag",
3753+
(zeros_expr, diagonal_expr, relax.const(0), relax.const(0), relax.const(False), relax.const(False)),
3754+
out_sinfo=relax.TensorStructInfo(output_shape, output_dtype),
3755+
)
37293756
return out
37303757

37313758
def convert_densify(self, op):

tests/python/relax/test_frontend_tflite.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2292,5 +2292,119 @@ def main(x: R.Tensor((1, 4, 4, 3), dtype="float32")) -> R.Tensor(
22922292
tvm.ir.assert_structural_equal(mod, Expected)
22932293

22942294

2295+
def test_matrix_diag():
2296+
"""Test TFLite MATRIX_DIAG operator."""
2297+
2298+
class MatrixDiag(tf.Module):
2299+
@tf.function(input_signature=[tf.TensorSpec(shape=(3,), dtype=tf.float32)])
2300+
def func(self, diagonal):
2301+
return tf.raw_ops.MatrixDiag(diagonal=diagonal)
2302+
2303+
@I.ir_module
2304+
class Expected:
2305+
@R.function
2306+
def main(diagonal: R.Tensor((3,), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"):
2307+
R.func_attr({"num_input": 1})
2308+
with R.dataflow():
2309+
lv: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32")
2310+
gv = R.call_dps_packed(
2311+
"topi.matrix_set_diag",
2312+
(
2313+
lv,
2314+
diagonal,
2315+
R.const(0, "int32"),
2316+
R.const(0, "int32"),
2317+
R.const(False, "bool"),
2318+
R.const(False, "bool"),
2319+
),
2320+
out_sinfo=R.Tensor((3, 3), dtype="float32"),
2321+
)
2322+
R.output(gv)
2323+
return gv
2324+
2325+
verify(MatrixDiag, Expected)
2326+
2327+
2328+
def test_matrix_set_diag():
2329+
"""Test TFLite MATRIX_SET_DIAG operator."""
2330+
2331+
class MatrixSetDiag(tf.Module):
2332+
@tf.function(
2333+
input_signature=[
2334+
tf.TensorSpec(shape=(3, 3), dtype=tf.float32),
2335+
tf.TensorSpec(shape=(3,), dtype=tf.float32),
2336+
]
2337+
)
2338+
def func(self, input, diagonal):
2339+
return tf.raw_ops.MatrixSetDiag(input=input, diagonal=diagonal)
2340+
2341+
@I.ir_module
2342+
class Expected:
2343+
@R.function
2344+
def main(
2345+
input: R.Tensor((3, 3), dtype="float32"),
2346+
diagonal: R.Tensor((3,), dtype="float32"),
2347+
) -> R.Tensor((3, 3), dtype="float32"):
2348+
R.func_attr({"num_input": 2})
2349+
with R.dataflow():
2350+
gv = R.call_dps_packed(
2351+
"topi.matrix_set_diag",
2352+
(
2353+
input,
2354+
diagonal,
2355+
R.const(0, "int32"),
2356+
R.const(0, "int32"),
2357+
R.const(False, "bool"),
2358+
R.const(False, "bool"),
2359+
),
2360+
out_sinfo=R.Tensor((3, 3), dtype="float32"),
2361+
)
2362+
R.output(gv)
2363+
return gv
2364+
2365+
verify(MatrixSetDiag, Expected)
2366+
2367+
2368+
def test_sparse_to_dense():
2369+
"""Test TFLite SPARSE_TO_DENSE operator."""
2370+
2371+
class SparseToDense(tf.Module):
2372+
@tf.function(
2373+
input_signature=[
2374+
tf.TensorSpec(shape=(2,), dtype=tf.int32),
2375+
tf.TensorSpec(shape=(2,), dtype=tf.float32),
2376+
tf.TensorSpec(shape=(), dtype=tf.float32),
2377+
]
2378+
)
2379+
def func(self, indices, values, default_value):
2380+
# output_shape is provided as a constant, not an input
2381+
return tf.raw_ops.SparseToDense(
2382+
sparse_indices=indices,
2383+
output_shape=tf.constant([3], dtype=tf.int32),
2384+
sparse_values=values,
2385+
default_value=default_value,
2386+
)
2387+
2388+
@I.ir_module
2389+
class Expected:
2390+
@R.function
2391+
def main(
2392+
indices: R.Tensor((2,), dtype="int32"),
2393+
values: R.Tensor((2,), dtype="float32"),
2394+
default_value: R.Tensor((), dtype="float32"),
2395+
) -> R.Tensor((3,), dtype="float32"):
2396+
R.func_attr({"num_input": 3})
2397+
with R.dataflow():
2398+
gv = R.call_dps_packed(
2399+
"topi.sparse_to_dense",
2400+
(indices, R.const([3], "int32"), values, default_value),
2401+
out_sinfo=R.Tensor((3,), dtype="float32"),
2402+
)
2403+
R.output(gv)
2404+
return gv
2405+
2406+
verify(SparseToDense, Expected)
2407+
2408+
22952409
if __name__ == "__main__":
22962410
pytest.main(["-s", __file__])

0 commit comments

Comments
 (0)