
Commit 5efa4b7

[Test][TFLite] Add unit tests for PRELU (#19402)

This PR adds unit test coverage for the `PRELU` activation in the Relax TFLite frontend, as part of #18971. It adds a unit test for `PRELU` and enables the converter to handle alpha broadcasting more cleanly across both constant and expression-backed alpha inputs.

1 parent a6e2ea8 commit 5efa4b7

2 files changed: 70 additions & 24 deletions


python/tvm/relax/frontend/tflite/tflite_frontend.py
Lines changed: 6 additions & 18 deletions
@@ -2259,7 +2259,7 @@ def convert_slice(self, op):
         # Create axes list for all dimensions being sliced
         axes = list(range(input_tensor_rank))
         begin = [int(v) for v in begin]
-        end = [int(v) for v in end]
+        end = [int(v) for v in end]
         out = relax.op.strided_slice(in_expr, axes=axes, begin=begin, end=end)
         return out

@@ -2840,9 +2840,7 @@ def convert_batch_matmul(self, op):
         new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in shape_b]
         max_rank = max(rank_a, rank_b)
 
-        batch_shape = [
-            max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)
-        ]
+        batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)]
 
         a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])]
         b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])]
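
The reflowed comprehension above computes the broadcast batch shape: each operand's shape is left-padded with 1s to a common rank, then the elementwise maximum is taken over the leading (batch) dimensions. A self-contained sketch of that logic, with invented example shapes (not frontend code):

# Standalone illustration of the batch_shape computation above; the shapes
# (2, 1, 4, 8) and (3, 8, 5) are made up for demonstration.
shape_a, shape_b = (2, 1, 4, 8), (3, 8, 5)
rank_a, rank_b = len(shape_a), len(shape_b)
new_a_shape = [1] * max(0, rank_b - rank_a) + [int(s) for s in shape_a]  # [2, 1, 4, 8]
new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in shape_b]  # [1, 3, 8, 5]
max_rank = max(rank_a, rank_b)

batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)]

a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])]  # [2, 3, 4, 8]
b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])]  # [2, 3, 8, 5]
print(batch_shape)  # [2, 3]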
@@ -2987,21 +2985,11 @@ def convert_prelu(self, op):
 
         input_tensor = input_tensors[0]
         alpha_tensor = input_tensors[1]
-        if self.has_expr(alpha_tensor.tensor_idx):
-            alpha_expr = self.get_expr(alpha_tensor.tensor_idx)
-        else:
-            alpha_tensor_type = alpha_tensor.tensor.Type()
-            alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
-            alpha_expr = self.exp_tab.new_const(
-                self.get_tensor_value(alpha_tensor),
-                dtype=alpha_tensor_type_str,
-                source_name=alpha_tensor.tensor.Name(),
-            )
-        in_expr = self.get_expr(input_tensor.tensor_idx)
         data_shape = to_int_list(self.get_tensor_shape(input_tensor))
-
-        alpha_expr = relax.op.broadcast_to(alpha_expr, data_shape)
-        alpha_expr = relax.op.reshape(alpha_expr, [-1])
+        alpha_expr = self.get_tensor_expr(alpha_tensor)
+        alpha_expr = self.bb.normalize(relax.op.broadcast_to(alpha_expr, data_shape))
+        alpha_expr = self.bb.normalize(relax.op.reshape(alpha_expr, [-1]))
+        in_expr = self.get_tensor_expr(input_tensor)
         out = relax.op.nn.prelu(_op.reshape(in_expr, [-1]), alpha_expr, axis=0)
         out = relax.op.reshape(out, data_shape)
         return out
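
The `get_tensor_expr` call is what replaces the inline constant-vs-expression branch: the helper is expected to return the already-converted Relax expression for a tensor when one exists, and otherwise materialize the tensor's data as a constant. A hedged sketch of that dispatch, assembled from the names visible in the removed branch (the actual helper is defined elsewhere in the frontend and may differ in detail):

# Sketch only: the real helper lives elsewhere in the frontend.
def get_tensor_expr(self, tensor):
    if self.has_expr(tensor.tensor_idx):
        # Expression-backed input: reuse the converted Relax expression.
        return self.get_expr(tensor.tensor_idx)
    # Constant input: build a Relax constant from the tensor's raw data.
    dtype = self.get_tensor_type_str(tensor.tensor.Type())
    return self.exp_tab.new_const(
        self.get_tensor_value(tensor),
        dtype=dtype,
        source_name=tensor.tensor.Name(),
    )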

tests/python/relax/test_frontend_tflite.py
Lines changed: 64 additions & 6 deletions
@@ -322,6 +322,7 @@ def func(self, x):
 
     verify(Tile)
 
+
 def test_concat_v2():
     class ConcatV2(tf.Module):
         @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)])
@@ -804,6 +805,7 @@ def func(self, data, kernel):
 
     verify(TransposeConv)
 
+
 def test_l2_pool2d():
     class L2Pool2D(tf.Module):
         @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32)])
@@ -815,9 +817,9 @@ def func(self, data):
     @I.ir_module
     class Expected:
         @R.function
-        def main(
-            data: R.Tensor((1, 8, 8, 2), dtype="float32")
-        ) -> R.Tensor((1, 8, 8, 2), dtype="float32"):
+        def main(data: R.Tensor((1, 8, 8, 2), dtype="float32")) -> R.Tensor(
+            (1, 8, 8, 2), dtype="float32"
+        ):
             R.func_attr({"num_input": 1})
             with R.dataflow():
                 squared = R.power(data, R.const(2.0, "float32"))
@@ -883,6 +885,7 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3
 
     verify(ReverseV2, Expected)
 
+
 def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
     class Conv2DModule(tf.Module):
         @tf.function(
@@ -1590,9 +1593,7 @@ def test_nms_v5_ir():
     "build_kwargs,expected_topk_count,expected_keep_background",
     _DETECTION_POSTPROCESS_SMOKE_CASES,
 )
-def test_detection_postprocess_smoke(
-    build_kwargs, expected_topk_count, expected_keep_background
-):
+def test_detection_postprocess_smoke(build_kwargs, expected_topk_count, expected_keep_background):
     mod = _build_detection_postprocess_mod(**build_kwargs)
     ir = mod.script()
 
@@ -1649,6 +1650,7 @@ def test_detection_postprocess_shape_variations(build_kwargs):
     ),
 )
 
+
 def _make_resize_expected(
     input_shape, output_size, method, coordinate_transformation_mode, rounding_method
 ):
@@ -2109,5 +2111,61 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="floa
     verify(ReLU_N1_to_1, Expected)
 
 
+@pytest.mark.parametrize(
+    "shared_axes",
+    [
+        pytest.param([1, 2], id="channelwise_shared_axes"),
+        pytest.param([1, 2, 3], id="scalar_shared_axes"),
+        pytest.param(None, id="elementwise_no_shared_axes"),
+    ],
+)
+def test_prelu(shared_axes):
+    inputs = tf.keras.Input(shape=(4, 4, 3), batch_size=1, dtype=tf.float32)
+    prelu_kwargs = {
+        "alpha_initializer": tf.initializers.constant(0.25),
+    }
+    if shared_axes is not None:
+        prelu_kwargs["shared_axes"] = shared_axes
+    outputs = tf.keras.layers.PReLU(**prelu_kwargs)(inputs)
+    keras_model = tf.keras.Model(inputs, outputs)
+
+    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+    tflite_model_buf = converter.convert()
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+
+    mod = from_tflite(tflite_model)
+    mod["main"] = mod["main"].without_attr("params")
+
+    if shared_axes == [1, 2]:
+        alpha_const = np.full((1, 1, 3), 0.25, dtype=np.float32)
+    elif shared_axes == [1, 2, 3]:
+        alpha_const = np.full((1, 1, 1), 0.25, dtype=np.float32)
+    else:
+        alpha_const = np.full((4, 4, 3), 0.25, dtype=np.float32)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((1, 4, 4, 3), dtype="float32")) -> R.Tensor(
+            (1, 4, 4, 3), dtype="float32"
+        ):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.broadcast_to(
+                    R.const(alpha_const), R.shape([1, 4, 4, 3])
+                )
+                lv1: R.Tensor((48,), dtype="float32") = R.reshape(x, R.shape([48]))
+                lv2: R.Tensor((48,), dtype="float32") = R.reshape(lv, R.shape([48]))
+                lv3: R.Tensor((48,), dtype="float32") = R.nn.prelu(lv1, lv2, axis=0)
+                gv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.reshape(lv3, R.shape([1, 4, 4, 3]))
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     pytest.main(["-s", __file__])
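
As a sanity check on what the expected IR computes, here is a hedged NumPy sketch of the same lowering: broadcast alpha to the input shape, flatten both tensors, apply PReLU along the flat axis, then restore the shape. The `shared_axes`-to-alpha-shape mapping follows Keras's documented behavior (alpha takes the input's non-batch shape, with each shared axis collapsed to 1); NumPy stands in for the Relax ops, so this is an illustration, not test code.

import numpy as np

def expected_alpha_shape(input_shape, shared_axes):
    # input_shape excludes the batch dim, e.g. (4, 4, 3); Keras shared_axes
    # are 1-based over these dims, and each shared axis collapses to 1.
    shape = list(input_shape)
    for axis in shared_axes or []:
        shape[axis - 1] = 1
    return tuple(shape)

assert expected_alpha_shape((4, 4, 3), [1, 2]) == (1, 1, 3)       # channel-wise
assert expected_alpha_shape((4, 4, 3), [1, 2, 3]) == (1, 1, 1)    # scalar
assert expected_alpha_shape((4, 4, 3), None) == (4, 4, 3)         # element-wise

def prelu_reference(x, alpha):
    # Mirrors the Expected module: broadcast_to -> reshape([-1]) -> prelu -> reshape.
    alpha_flat = np.broadcast_to(alpha, x.shape).reshape(-1)
    x_flat = x.reshape(-1)
    out = np.where(x_flat >= 0, x_flat, alpha_flat * x_flat)
    return out.reshape(x.shape)

x = np.random.randn(1, 4, 4, 3).astype("float32")
alpha = np.full((1, 1, 3), 0.25, dtype=np.float32)  # channel-wise case
assert prelu_reference(x, alpha).shape == (1, 4, 4, 3)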
