|
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | 17 | """Default legalization function for vision network related operators.""" |
18 | | -from tvm import topi, te |
19 | | -from tvm import relax |
| 18 | +from tvm import relax, te, tir, topi |
| 19 | + |
20 | 20 | from ...block_builder import BlockBuilder |
21 | | -from ...expr import Call, Expr |
| 21 | +from ...expr import Call, Expr, TupleGetItem |
22 | 22 | from .common import register_legalize |
23 | 23 |
|
24 | 24 |
|
25 | | -def _create_onnx_nms_te(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold): |
26 | | - """Create a proper NMS implementation that follows the correct algorithm""" |
27 | | - scores_shape = list(scores.shape) |
28 | | - if len(scores_shape) == 3: |
29 | | - batch, num_classes, _ = scores_shape |
30 | | - elif len(scores_shape) == 2: |
31 | | - num_classes, _ = scores_shape |
32 | | - batch = 1 |
33 | | - else: |
34 | | - raise ValueError(f"Unexpected scores shape: {scores_shape}") |
35 | | - |
36 | | - if hasattr(max_output_boxes_per_class, "data"): |
37 | | - max_boxes = int(max_output_boxes_per_class.data.numpy()) |
38 | | - else: |
39 | | - max_boxes = 3 # Default value |
40 | | - |
41 | | - expected_detections = batch * num_classes * max_boxes |
42 | | - |
43 | | - selected_indices_full, _ = topi.vision.all_class_non_max_suppression( |
44 | | - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" |
45 | | - ) |
46 | | - |
47 | | - def slice_to_onnx_shape(data, expected_size): |
48 | | - def compute_element(i, j): |
49 | | - return tvm.tir.if_then_else(i < expected_size, data[i, j], tvm.tir.Cast("int64", 0)) |
50 | | - |
51 | | - return te.compute((expected_size, 3), compute_element, name="sliced_indices") |
52 | | - |
53 | | - sliced_indices = slice_to_onnx_shape(selected_indices_full, expected_detections) |
54 | | - |
55 | | - actual_detections = te.compute( |
56 | | - (1,), lambda i: tvm.tir.Cast("int64", expected_detections), name="actual_detections" |
57 | | - ) |
58 | | - |
59 | | - return [sliced_indices, actual_detections] |
60 | | - |
61 | | - |
62 | 25 | @register_legalize("relax.vision.all_class_non_max_suppression") |
63 | 26 | def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr: |
64 | | - """Legalize all_class_non_max_suppression with fixed shape output. |
65 | | -
|
66 | | - Note: This implementation outputs fixed-size tensors with trailing garbage data. |
67 | | - Only the first `num_total_detection` rows contain valid data. Users should use |
68 | | - the `valid_count` tensor to determine how many rows are actually valid. |
69 | | -
|
70 | | - For complete ONNX compatibility, users can post-process the output: |
71 | | - ```python |
72 | | - selected_indices, valid_count = nms_output |
73 | | - actual_count = int(valid_count.numpy()[0]) |
74 | | - valid_indices = selected_indices.numpy()[:actual_count, :] |
75 | | - ``` |
| 27 | + """Legalize all_class_non_max_suppression with dynamic output trimming. |
| 28 | +
|
| 29 | + This implementation uses dynamic_strided_slice to trim the NMS output to only |
| 30 | + contain valid detections, improving memory efficiency and ONNX compatibility. |
| 31 | +
|
| 32 | + Returns |
| 33 | + ------- |
| 34 | + result : Tuple[Tensor, Tensor] |
| 35 | + A tuple of (trimmed_indices, num_total_detections) where: |
| 36 | + - trimmed_indices: Tensor of shape (num_total_detections, 3) containing only |
| 37 | + valid detection indices (batch_id, class_id, box_id) |
| 38 | + - num_total_detections: Tensor of shape (1,) with the count of valid detections |
76 | 39 | """ |
77 | 40 | boxes = call.args[0] |
78 | 41 | scores = call.args[1] |
@@ -105,16 +68,37 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E |
105 | 68 | output_format, |
106 | 69 | ) |
107 | 70 |
|
108 | | - # TODO: Implement dynamic output trimming for better memory efficiency |
109 | | - # Current approach returns fixed-size output with trailing garbage data |
110 | | - # Future improvements could include: |
111 | | - # 1. Dynamic strided_slice based on num_total_detections |
112 | | - # 2. Custom Relax operator with true dynamic shapes |
113 | | - # 3. VM builtin functions for runtime shape adjustment |
114 | | - # 4. Symbolic shape inference in Relax IR |
115 | | - # |
116 | | - # For now, users should trim manually: |
117 | | - # actual_count = int(num_total_detections.numpy()[0]) |
118 | | - # valid_indices = selected_indices.numpy()[:actual_count, :] |
119 | | - |
120 | | - return nms_result |
| 71 | + # Dynamic output trimming using dynamic_strided_slice |
| 72 | + # Extract selected_indices and num_total_detections from the NMS result |
| 73 | + selected_indices = block_builder.emit(TupleGetItem(nms_result, 0)) |
| 74 | + num_total_detections = block_builder.emit(TupleGetItem(nms_result, 1)) |
| 75 | + |
| 76 | + # Build slicing parameters using TE to avoid high-level Relax ops during legalization |
| 77 | + def build_begin(): |
| 78 | + return te.compute((2,), lambda i: tir.const(0, "int64"), name="begin") |
| 79 | + |
| 80 | + def build_strides(): |
| 81 | + return te.compute((2,), lambda i: tir.const(1, "int64"), name="strides") |
| 82 | + |
| 83 | + def build_end(count_tensor): |
| 84 | + # end = [count_tensor[0], 3] |
| 85 | + def compute_end(i): |
| 86 | + return tir.if_then_else( |
| 87 | + i == 0, |
| 88 | + tir.Cast("int64", count_tensor[0]), |
| 89 | + tir.const(3, "int64"), |
| 90 | + ) |
| 91 | + |
| 92 | + return te.compute((2,), compute_end, name="end") |
| 93 | + |
| 94 | + begin = block_builder.call_te(build_begin) |
| 95 | + strides = block_builder.call_te(build_strides) |
| 96 | + end = block_builder.call_te(build_end, num_total_detections) |
| 97 | + |
| 98 | + # Apply dynamic strided slice to trim to valid detections only |
| 99 | + trimmed_indices = block_builder.emit( |
| 100 | + relax.op.dynamic_strided_slice(selected_indices, begin, end, strides) |
| 101 | + ) |
| 102 | + |
| 103 | + # Return trimmed indices along with num_total_detections for compatibility |
| 104 | + return relax.Tuple([trimmed_indices, num_total_detections]) |
0 commit comments