Skip to content

Commit 3d1e402

Browse files
[Frontend][TFLite] Add test coverage for SHAPE and RANGE operators (#19401)
Initial goal was to add SHAPE and RANGE tests, solving part of #18971 This PR achieves that and includes the minimum necessary frontend fixes discovered during implementation so those tests reflect real supported behavior instead of xfail/workarounds. so this PR includes both: **1. New SHAPE/RANGE tests 2. Targeted frontend fixes required to make those tests pass correctly** ## Why These Changes Were Needed - SHAPE conversion previously produced symbolic shape info instead of a tensor output aligned with TFLite SHAPE semantics. - RANGE conversion passed tensor expressions into arange instead of scalar values for constant scalar bounds. - Zero-input TFLite subgraphs (valid for constant-only models such as RANGE without inputs) were blocked by a strict assertion. - Model output collection was brittle for constant/prefetched outputs and could fail when output expressions were not already in the expr table. - As a result, i could not add meaningful SHAPE/RANGE coverage without fixing frontend behavior. ## **Modifications** ### **Frontend Changes** (In tflite_frontend.py): - Updated convert_shape: SHAPE now materializes shape output as a tensor using shape_to_tensor(shape_of(...)) - Applies output dtype casting based on ShapeOptions OutType (int32/int64) - Updated convert_range: Extracts scalar values for start/limit/delta from scalar constants - Calls arange with scalar-like values - Keeps dynamic scalar RANGE explicit as unsupported (raises OpNotImplemented with clear message) - Updated _input_type: Removed assumption that every subgraph must have at least one input - Supports valid zero-input subgraphs - Updated from_tflite output assembly: Resolves outputs via tensor wrappers and get_tensor_expr instead of direct expr-table lookup by name --- **Main functional changes are localized to SHAPE/RANGE conversion and model output/input handling.** --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 5efa4b7 commit 3d1e402

2 files changed

Lines changed: 103 additions & 5 deletions

File tree

python/tvm/relax/frontend/tflite/tflite_frontend.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -925,15 +925,35 @@ def convert_range(self, op):
925925

926926
start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]
927927

928-
expressions = [self.get_tensor_expr(t) for t in [start, limit, delta]]
928+
def get_scalar_value(tensor):
929+
if self.has_expr(tensor.tensor_idx):
930+
expr = self.get_expr(tensor.tensor_idx)
931+
if isinstance(expr, relax.Constant):
932+
value = expr.data.numpy()
933+
else:
934+
# relax.op.arange currently expects scalar-like values here.
935+
# Keep dynamic scalar RANGE explicit until frontend support is added.
936+
raise tvm.error.OpNotImplemented(
937+
"TFLite RANGE with dynamic scalar inputs is not supported in Relax frontend yet."
938+
)
939+
else:
940+
value = self.get_tensor_value(tensor)
929941

942+
# TFLite RANGE operands are scalar tensors in the flatbuffer.
943+
assert value.size == 1, "RANGE scalar input must have exactly one element"
944+
return value.item()
945+
946+
start_value = get_scalar_value(start)
947+
limit_value = get_scalar_value(limit)
948+
delta_value = get_scalar_value(delta)
949+
930950
# out type inference
931951
if delta.tensor.Type() == TensorType.FLOAT32:
932952
out_type = self.get_tensor_type_str(delta.tensor.Type())
933953
else:
934954
out_type = self.get_tensor_type_str(start.tensor.Type())
935955

936-
out = relax.op.arange(expressions[0], expressions[1], expressions[2], out_type)
956+
out = relax.op.arange(start_value, limit_value, delta_value, out_type)
937957

938958
return out
939959

@@ -942,6 +962,7 @@ def convert_shape(self, op):
942962

943963
from tflite.BuiltinOptions import BuiltinOptions
944964
from tflite.ShapeOptions import ShapeOptions
965+
from tflite.TensorType import TensorType
945966

946967
input_tensors = self.get_input_tensors(op)
947968
assert len(input_tensors) == 1, "input tensors length should be 1"
@@ -951,7 +972,10 @@ def convert_shape(self, op):
951972
shape_options = ShapeOptions()
952973
shape_options.Init(op_options.Bytes, op_options.Pos)
953974

954-
out = relax.op.shape_of(self.get_tensor_expr(input_tensors[0]))
975+
# SHAPE must materialize as a tensor output in Relax, not just symbolic shape info.
976+
out = relax.op.shape_to_tensor(relax.op.shape_of(self.get_tensor_expr(input_tensors[0])))
977+
if shape_options.OutType() == TensorType.INT32:
978+
out = relax.op.astype(out, "int32")
955979

956980
return out
957981

@@ -4055,7 +4079,7 @@ def _input_type(model):
40554079
for subgraph_index in range(subgraph_count):
40564080
subgraph = model.Subgraphs(subgraph_index)
40574081
inputs_count = subgraph.InputsLength()
4058-
assert inputs_count >= 1
4082+
# TFLite subgraphs can validly have zero inputs (e.g. constant-only RANGE models).
40594083
for input_index in range(inputs_count):
40604084
input_ = subgraph.Inputs(input_index)
40614085
assert subgraph.TensorsLength() > input_
@@ -4209,7 +4233,9 @@ def func(self, data):
42094233
op_converter.convert_op_to_relax()
42104234

42114235
# params and outputs
4212-
outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
4236+
# Resolve outputs through tensor wrappers so constant/prefetched outputs are handled.
4237+
output_tensors = op_converter.get_tensors(model_outputs)
4238+
outputs = [op_converter.get_tensor_expr(tensor) for tensor in output_tensors]
42134239
outputs = outputs[0] if len(outputs) == 1 else relax.Tuple(outputs)
42144240
output_var = bb.emit_output(outputs)
42154241

tests/python/relax/test_frontend_tflite.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,78 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 2, 15), dtype="f
279279
verify(Reshape, Expected)
280280

281281

282+
@pytest.mark.parametrize(
283+
"input_shape, out_type",
284+
[
285+
((2, 3, 4), tf.int32),
286+
((5,), tf.int64),
287+
((1, 1, 1, 1), tf.int32),
288+
((), tf.int32),
289+
((0, 3), tf.int64),
290+
],
291+
)
292+
def test_shape(input_shape, out_type):
293+
"""SHAPE conversion for static-rank non-quantized tensors."""
294+
295+
class Shape(tf.Module):
296+
@tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)])
297+
def func(self, x):
298+
return tf.shape(x, out_type=out_type)
299+
300+
verify(Shape)
301+
302+
303+
def test_shape_dynamic_dim():
304+
"""SHAPE conversion with a dynamic input dimension."""
305+
306+
class ShapeDynamic(tf.Module):
307+
@tf.function(input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)])
308+
def func(self, x):
309+
return tf.shape(x, out_type=tf.int32)
310+
311+
verify(ShapeDynamic)
312+
313+
314+
@pytest.mark.parametrize(
315+
"start, limit, delta, dtype",
316+
[
317+
(0, 8, 2, tf.int32),
318+
(1, 9, 2, tf.int64),
319+
(0.0, 1.0, 0.2, tf.float32),
320+
(8, 0, -2, tf.int32),
321+
(0, 0, 1, tf.int32),
322+
(0, 7, 2, tf.int32),
323+
(0.0, -1.0, -0.25, tf.float32),
324+
],
325+
)
326+
def test_range(start, limit, delta, dtype):
327+
"""RANGE conversion with non-quantized constant scalar bounds."""
328+
329+
class Range(tf.Module):
330+
@tf.function(input_signature=[])
331+
def func(self):
332+
return tf.range(start, limit, delta, dtype=dtype)
333+
334+
verify(Range)
335+
336+
337+
def test_range_dynamic_scalar_inputs_not_supported():
338+
"""RANGE conversion currently rejects dynamic scalar inputs."""
339+
340+
class RangeDynamic(tf.Module):
341+
@tf.function(
342+
input_signature=[
343+
tf.TensorSpec(shape=(), dtype=tf.int32),
344+
tf.TensorSpec(shape=(), dtype=tf.int32),
345+
tf.TensorSpec(shape=(), dtype=tf.int32),
346+
]
347+
)
348+
def func(self, start, limit, delta):
349+
return tf.range(start, limit, delta, dtype=tf.int32)
350+
351+
with pytest.raises(tvm.error.OpNotImplemented, match="dynamic scalar inputs"):
352+
verify(RangeDynamic)
353+
282354
def test_tile_ir():
283355
"""TILE conversion with explicit Relax IR structural check."""
284356

0 commit comments

Comments
 (0)