Skip to content

Commit 95a816c

Browse files
authored
[TFLITE] Match TFLite shape for SSD custom op (apache#5473)
This patch ensures that the output shape from TVM's Detection_PostProcess is the same as TFLite's and expands the unit test to confirm this. Change-Id: If5db95741533f131241dfebbaa7708dbd528fe70
1 parent 063ba63 commit 95a816c

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2257,6 +2257,7 @@ def convert_detection_postprocess(self, op):
22572257
assert len(inputs) == 3, "inputs length should be 3"
22582258
cls_pred = self.get_expr(inputs[1].tensor_idx)
22592259
loc_prob = self.get_expr(inputs[0].tensor_idx)
2260+
batch_size = inputs[1].tensor.Shape(0)
22602261
anchor_values = self.get_tensor_value(inputs[2])
22612262
anchor_boxes = len(anchor_values)
22622263
anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
@@ -2284,7 +2285,7 @@ def convert_detection_postprocess(self, op):
22842285
loc_prob = _op.concatenate(
22852286
[loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
22862287
)
2287-
loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])
2288+
loc_prob = _op.reshape(loc_prob, [batch_size, anchor_boxes*4])
22882289

22892290
# anchor coords are in yxhw format
22902291
# need to convert to ltrb
@@ -2327,10 +2328,14 @@ def convert_detection_postprocess(self, op):
23272328
ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs)
23282329
ret = _op.vision.get_valid_counts(ret, 0)
23292330
valid_count = ret[0]
2331+
# keep only the top 'max_detections' rows
2332+
ret = _op.strided_slice(ret[1],
2333+
[0, 0, 0],
2334+
[batch_size, custom_options["max_detections"], anchor_boxes])
23302335
# the output needs some reshaping to match tflite
2331-
ret = _op.split(ret[1], 6, axis=2)
2332-
cls_ids = ret[0]
2333-
scores = ret[1]
2336+
ret = _op.split(ret, 6, axis=2)
2337+
cls_ids = _op.reshape(ret[0], [batch_size, -1])
2338+
scores = _op.reshape(ret[1], [batch_size, -1])
23342339
boxes = _op.concatenate([ret[3], ret[2], ret[5], ret[4]], axis=2)
23352340
ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
23362341
return ret

tests/python/frontend/tflite/test_forward.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,7 +1731,14 @@ def test_detection_postprocess():
17311731
["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4)
17321732
# check valid count is the same
17331733
assert tvm_output[3] == tflite_output[3]
1734+
# check all the output shapes are the same
1735+
assert tvm_output[0].shape == tflite_output[0].shape
1736+
assert tvm_output[1].shape == tflite_output[1].shape
1737+
assert tvm_output[2].shape == tflite_output[2].shape
17341738
valid_count = tvm_output[3][0]
1739+
# only check the valid detections are the same
1740+
# tvm has a different convention to tflite for invalid detections, it uses all -1s whereas
1741+
# tflite appears to put in nonsense data instead
17351742
tvm_boxes = tvm_output[0][0][:valid_count]
17361743
tvm_classes = tvm_output[1][0][:valid_count]
17371744
tvm_scores = tvm_output[2][0][:valid_count]

0 commit comments

Comments
 (0)