@@ -322,6 +322,7 @@ def func(self, x):
322322
323323 verify (Tile )
324324
325+
325326def test_concat_v2 ():
326327 class ConcatV2 (tf .Module ):
327328 @tf .function (input_signature = [tf .TensorSpec (shape = (1 , 30 ), dtype = tf .float32 )])
@@ -804,6 +805,7 @@ def func(self, data, kernel):
804805
805806 verify (TransposeConv )
806807
808+
807809def test_l2_pool2d ():
808810 class L2Pool2D (tf .Module ):
809811 @tf .function (input_signature = [tf .TensorSpec (shape = (1 , 8 , 8 , 2 ), dtype = tf .float32 )])
@@ -815,9 +817,9 @@ def func(self, data):
815817 @I .ir_module
816818 class Expected :
817819 @R .function
818- def main (
819- data : R . Tensor (( 1 , 8 , 8 , 2 ), dtype = "float32" )
820- ) -> R . Tensor (( 1 , 8 , 8 , 2 ), dtype = "float32" ) :
820+ def main (data : R . Tensor (( 1 , 8 , 8 , 2 ), dtype = "float32" )) -> R . Tensor (
821+ ( 1 , 8 , 8 , 2 ), dtype = "float32"
822+ ):
821823 R .func_attr ({"num_input" : 1 })
822824 with R .dataflow ():
823825 squared = R .power (data , R .const (2.0 , "float32" ))
@@ -883,6 +885,7 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3
883885
884886 verify (ReverseV2 , Expected )
885887
888+
886889def _make_conv2d_module (data_shape , kernel_shape , data_format , strides , padding ):
887890 class Conv2DModule (tf .Module ):
888891 @tf .function (
@@ -1590,9 +1593,7 @@ def test_nms_v5_ir():
15901593 "build_kwargs,expected_topk_count,expected_keep_background" ,
15911594 _DETECTION_POSTPROCESS_SMOKE_CASES ,
15921595)
1593- def test_detection_postprocess_smoke (
1594- build_kwargs , expected_topk_count , expected_keep_background
1595- ):
1596+ def test_detection_postprocess_smoke (build_kwargs , expected_topk_count , expected_keep_background ):
15961597 mod = _build_detection_postprocess_mod (** build_kwargs )
15971598 ir = mod .script ()
15981599
@@ -1649,6 +1650,7 @@ def test_detection_postprocess_shape_variations(build_kwargs):
16491650 ),
16501651 )
16511652
1653+
16521654def _make_resize_expected (
16531655 input_shape , output_size , method , coordinate_transformation_mode , rounding_method
16541656):
@@ -2109,5 +2111,61 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="floa
21092111 verify (ReLU_N1_to_1 , Expected )
21102112
21112113
@pytest.mark.parametrize(
    "shared_axes",
    [
        pytest.param([1, 2], id="channelwise_shared_axes"),
        pytest.param([1, 2, 3], id="scalar_shared_axes"),
        pytest.param(None, id="elementwise_no_shared_axes"),
    ],
)
def test_prelu(shared_axes):
    """Import a Keras PReLU layer through TFLite and check the Relax lowering.

    Builds a (1, 4, 4, 3) Keras model with a single PReLU layer (alpha
    initialized to 0.25), converts it to a TFLite flatbuffer, imports it with
    ``from_tflite``, and asserts structural equality against a hand-written
    Relax module in which PReLU is lowered to broadcast + flatten +
    ``R.nn.prelu`` over the flattened axis + reshape back.

    Parameters
    ----------
    shared_axes : list[int] | None
        Value forwarded to ``tf.keras.layers.PReLU(shared_axes=...)`` when not
        None; controls the shape of the learned alpha tensor and hence the
        expected broadcast constant below.
    """
    inputs = tf.keras.Input(shape=(4, 4, 3), batch_size=1, dtype=tf.float32)
    prelu_kwargs = {
        "alpha_initializer": tf.initializers.constant(0.25),
    }
    # Only pass shared_axes when the parametrized case sets it, so the
    # None case exercises Keras' default (fully elementwise) alpha.
    if shared_axes is not None:
        prelu_kwargs["shared_axes"] = shared_axes
    outputs = tf.keras.layers.PReLU(**prelu_kwargs)(inputs)
    keras_model = tf.keras.Model(inputs, outputs)

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    tflite_model_buf = converter.convert()
    # The flatbuffer root accessor moved between tflite package versions:
    # newer releases expose tflite.Model.Model, older ones tflite.Model.
    if hasattr(tflite.Model, "Model"):
        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
    else:
        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)

    mod = from_tflite(tflite_model)
    # Drop the imported "params" attribute so the structural comparison
    # below only has to match the function body, not bound weights.
    mod["main"] = mod["main"].without_attr("params")

    # Shape of the alpha constant the importer is expected to emit:
    # sharing over spatial axes [1, 2] leaves a per-channel (1, 1, 3) alpha,
    # sharing over all of [1, 2, 3] collapses it to a scalar-like (1, 1, 1),
    # and no sharing keeps a full elementwise (4, 4, 3) alpha.
    if shared_axes == [1, 2]:
        alpha_const = np.full((1, 1, 3), 0.25, dtype=np.float32)
    elif shared_axes == [1, 2, 3]:
        alpha_const = np.full((1, 1, 1), 0.25, dtype=np.float32)
    else:
        alpha_const = np.full((4, 4, 3), 0.25, dtype=np.float32)

    @I.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((1, 4, 4, 3), dtype="float32")) -> R.Tensor(
            (1, 4, 4, 3), dtype="float32"
        ):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # Broadcast alpha to the input shape, flatten both operands,
                # apply prelu on the flat axis, then restore the input shape.
                lv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.broadcast_to(
                    R.const(alpha_const), R.shape([1, 4, 4, 3])
                )
                lv1: R.Tensor((48,), dtype="float32") = R.reshape(x, R.shape([48]))
                lv2: R.Tensor((48,), dtype="float32") = R.reshape(lv, R.shape([48]))
                lv3: R.Tensor((48,), dtype="float32") = R.nn.prelu(lv1, lv2, axis=0)
                gv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.reshape(lv3, R.shape([1, 4, 4, 3]))
                R.output(gv)
            return gv

    tvm.ir.assert_structural_equal(mod, Expected)
2169+
if __name__ == "__main__":
    # Allow running this test file directly; -s disables pytest's
    # output capturing so TF/TVM logs are visible.
    pytest_args = ["-s", __file__]
    pytest.main(pytest_args)
0 commit comments