Skip to content

Commit 78b5ed0

Browse files
authored
[Relax] Implement dynamic output trimming for NMS (#18676)
## Why

The NMS operator returns fixed-size output with trailing garbage data, wasting memory and requiring manual trimming for ONNX compatibility.

## How

- Add dynamic_strided_slice to trim NMS output to valid detections only
- Build slice parameters using TE compute to avoid legalization issues
1 parent 4533d31 commit 78b5ed0

3 files changed

Lines changed: 146 additions & 74 deletions

File tree

python/tvm/relax/op/vision/nms.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,10 @@ def all_class_non_max_suppression(
5454
`num_total_detection` of shape `(1,)` representing the total number of selected
5555
boxes. The three values in `indices` encode batch, class, and box indices.
5656
Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come
57-
first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of
58-
`batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection`
59-
rows are valid.
57+
first, in descending order of scores, followed by boxes from batch 0, class 1 etc.
58+
The output uses dynamic_strided_slice to trim to only valid detections,
59+
so the first tensor has shape (num_total_detection, 3) containing only valid rows.
6060
61-
TODO: Implement true dynamic output shapes to match ONNX Runtime behavior exactly.
62-
This would eliminate the need for manual trimming and improve memory efficiency.
6361
If `output_format` is "tensorflow", the output is three tensors, the first
6462
is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of
6563
size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size

python/tvm/relax/transform/legalize_ops/vision.py

Lines changed: 49 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -15,64 +15,27 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Default legalization function for vision network related operators."""
18-
from tvm import topi, te
19-
from tvm import relax
18+
from tvm import relax, te, tir, topi
19+
2020
from ...block_builder import BlockBuilder
21-
from ...expr import Call, Expr
21+
from ...expr import Call, Expr, TupleGetItem
2222
from .common import register_legalize
2323

2424

25-
def _create_onnx_nms_te(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
26-
"""Create a proper NMS implementation that follows the correct algorithm"""
27-
scores_shape = list(scores.shape)
28-
if len(scores_shape) == 3:
29-
batch, num_classes, _ = scores_shape
30-
elif len(scores_shape) == 2:
31-
num_classes, _ = scores_shape
32-
batch = 1
33-
else:
34-
raise ValueError(f"Unexpected scores shape: {scores_shape}")
35-
36-
if hasattr(max_output_boxes_per_class, "data"):
37-
max_boxes = int(max_output_boxes_per_class.data.numpy())
38-
else:
39-
max_boxes = 3 # Default value
40-
41-
expected_detections = batch * num_classes * max_boxes
42-
43-
selected_indices_full, _ = topi.vision.all_class_non_max_suppression(
44-
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
45-
)
46-
47-
def slice_to_onnx_shape(data, expected_size):
48-
def compute_element(i, j):
49-
return tvm.tir.if_then_else(i < expected_size, data[i, j], tvm.tir.Cast("int64", 0))
50-
51-
return te.compute((expected_size, 3), compute_element, name="sliced_indices")
52-
53-
sliced_indices = slice_to_onnx_shape(selected_indices_full, expected_detections)
54-
55-
actual_detections = te.compute(
56-
(1,), lambda i: tvm.tir.Cast("int64", expected_detections), name="actual_detections"
57-
)
58-
59-
return [sliced_indices, actual_detections]
60-
61-
6225
@register_legalize("relax.vision.all_class_non_max_suppression")
6326
def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr:
64-
"""Legalize all_class_non_max_suppression with fixed shape output.
65-
66-
Note: This implementation outputs fixed-size tensors with trailing garbage data.
67-
Only the first `num_total_detection` rows contain valid data. Users should use
68-
the `valid_count` tensor to determine how many rows are actually valid.
69-
70-
For complete ONNX compatibility, users can post-process the output:
71-
```python
72-
selected_indices, valid_count = nms_output
73-
actual_count = int(valid_count.numpy()[0])
74-
valid_indices = selected_indices.numpy()[:actual_count, :]
75-
```
27+
"""Legalize all_class_non_max_suppression with dynamic output trimming.
28+
29+
This implementation uses dynamic_strided_slice to trim the NMS output to only
30+
contain valid detections, improving memory efficiency and ONNX compatibility.
31+
32+
Returns
33+
-------
34+
result : Tuple[Tensor, Tensor]
35+
A tuple of (trimmed_indices, num_total_detections) where:
36+
- trimmed_indices: Tensor of shape (num_total_detections, 3) containing only
37+
valid detection indices (batch_id, class_id, box_id)
38+
- num_total_detections: Tensor of shape (1,) with the count of valid detections
7639
"""
7740
boxes = call.args[0]
7841
scores = call.args[1]
@@ -105,16 +68,37 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E
10568
output_format,
10669
)
10770

108-
# TODO: Implement dynamic output trimming for better memory efficiency
109-
# Current approach returns fixed-size output with trailing garbage data
110-
# Future improvements could include:
111-
# 1. Dynamic strided_slice based on num_total_detections
112-
# 2. Custom Relax operator with true dynamic shapes
113-
# 3. VM builtin functions for runtime shape adjustment
114-
# 4. Symbolic shape inference in Relax IR
115-
#
116-
# For now, users should trim manually:
117-
# actual_count = int(num_total_detections.numpy()[0])
118-
# valid_indices = selected_indices.numpy()[:actual_count, :]
119-
120-
return nms_result
71+
# Dynamic output trimming using dynamic_strided_slice
72+
# Extract selected_indices and num_total_detections from the NMS result
73+
selected_indices = block_builder.emit(TupleGetItem(nms_result, 0))
74+
num_total_detections = block_builder.emit(TupleGetItem(nms_result, 1))
75+
76+
# Build slicing parameters using TE to avoid high-level Relax ops during legalization
77+
def build_begin():
78+
return te.compute((2,), lambda i: tir.const(0, "int64"), name="begin")
79+
80+
def build_strides():
81+
return te.compute((2,), lambda i: tir.const(1, "int64"), name="strides")
82+
83+
def build_end(count_tensor):
84+
# end = [count_tensor[0], 3]
85+
def compute_end(i):
86+
return tir.if_then_else(
87+
i == 0,
88+
tir.Cast("int64", count_tensor[0]),
89+
tir.const(3, "int64"),
90+
)
91+
92+
return te.compute((2,), compute_end, name="end")
93+
94+
begin = block_builder.call_te(build_begin)
95+
strides = block_builder.call_te(build_strides)
96+
end = block_builder.call_te(build_end, num_total_detections)
97+
98+
# Apply dynamic strided slice to trim to valid detections only
99+
trimmed_indices = block_builder.emit(
100+
relax.op.dynamic_strided_slice(selected_indices, begin, end, strides)
101+
)
102+
103+
# Return trimmed indices along with num_total_detections for compatibility
104+
return relax.Tuple([trimmed_indices, num_total_detections])

tests/python/relax/test_op_vision.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import numpy as np
1819
import pytest
20+
1921
import tvm
2022
import tvm.testing
21-
from tvm import relax, tir
22-
from tvm import TVMError
23-
from tvm.ir import Op, VDevice
23+
from tvm import TVMError, relax, tir
24+
from tvm.relax.transform import LegalizeOps
2425
from tvm.script import relax as R
2526

2627

@@ -53,7 +54,6 @@ def test_all_class_non_max_suppression_infer_struct_info():
5354

5455

5556
def test_all_class_non_max_suppression_wrong_input_number():
56-
bb = relax.BlockBuilder()
5757
boxes = relax.Var("boxes", R.Tensor((1, 5, 4), "float32"))
5858
scores = relax.Var("scores", R.Tensor((1, 3, 5), "float32"))
5959

@@ -86,5 +86,95 @@ def test_all_class_non_max_suppression_infer_struct_info_shape_var():
8686
)
8787

8888

89+
def test_all_class_non_max_suppression_legalize_dynamic_trim():
90+
@tvm.script.ir_module
91+
class NMSModule:
92+
@R.function
93+
def main(
94+
boxes: R.Tensor((1, 5, 4), "float32"),
95+
scores: R.Tensor((1, 2, 5), "float32"),
96+
) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")):
97+
max_output_boxes_per_class = R.const(3, "int64")
98+
iou_threshold = R.const(0.5, "float32")
99+
score_threshold = R.const(0.1, "float32")
100+
return R.vision.all_class_non_max_suppression(
101+
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
102+
)
103+
104+
mod = LegalizeOps()(NMSModule)
105+
106+
# Check legalized function has dynamic output (uses dynamic_strided_slice)
107+
assert "dynamic_strided_slice" in str(mod)
108+
109+
ret_sinfo = mod["main"].ret_struct_info
110+
tvm.ir.assert_structural_equal(
111+
ret_sinfo,
112+
relax.TupleStructInfo(
113+
[
114+
relax.TensorStructInfo(ndim=2, dtype="int64"),
115+
relax.TensorStructInfo((1,), "int64"),
116+
]
117+
),
118+
)
119+
120+
121+
def test_all_class_non_max_suppression_legalize_e2e():
122+
@tvm.script.ir_module
123+
class NMSModule:
124+
@R.function
125+
def main(
126+
boxes: R.Tensor((1, 5, 4), "float32"),
127+
scores: R.Tensor((1, 2, 5), "float32"),
128+
) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")):
129+
max_output_boxes_per_class = R.const(3, "int64")
130+
iou_threshold = R.const(0.5, "float32")
131+
score_threshold = R.const(0.1, "float32")
132+
return R.vision.all_class_non_max_suppression(
133+
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
134+
)
135+
136+
boxes_data = np.array(
137+
[
138+
[
139+
[0.0, 0.0, 1.0, 1.0],
140+
[0.1, 0.1, 1.1, 1.1],
141+
[2.0, 2.0, 3.0, 3.0],
142+
[4.0, 4.0, 5.0, 5.0],
143+
[6.0, 6.0, 7.0, 7.0],
144+
]
145+
],
146+
dtype=np.float32,
147+
)
148+
scores_data = np.array(
149+
[[[0.9, 0.8, 0.7, 0.6, 0.5], [0.85, 0.75, 0.65, 0.55, 0.45]]],
150+
dtype=np.float32,
151+
)
152+
153+
mod = LegalizeOps()(NMSModule)
154+
155+
# Check struct info
156+
tvm.ir.assert_structural_equal(
157+
mod["main"].ret_struct_info,
158+
relax.TupleStructInfo(
159+
[
160+
relax.TensorStructInfo(ndim=2, dtype="int64"),
161+
relax.TensorStructInfo((1,), "int64"),
162+
]
163+
),
164+
)
165+
166+
# Check runtime execution
167+
exe = tvm.compile(mod, target="llvm")
168+
vm = relax.VirtualMachine(exe, tvm.cpu())
169+
result = vm["main"](
170+
tvm.runtime.tensor(boxes_data, tvm.cpu()),
171+
tvm.runtime.tensor(scores_data, tvm.cpu()),
172+
)
173+
174+
selected_indices = result[0].numpy()
175+
num_total_detections = int(result[1].numpy()[0])
176+
tvm.testing.assert_allclose(selected_indices.shape, (num_total_detections, 3))
177+
178+
89179
if __name__ == "__main__":
90180
tvm.testing.main()

0 commit comments

Comments
 (0)