From abc52aae75bf12a8839cc509fe2547d1b4629bd0 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Mon, 27 Jul 2020 10:38:52 -0700 Subject: [PATCH] [Relay][OP] Support NMSv4 ingestion from TF. (#6085) --- python/tvm/relay/frontend/tensorflow.py | 6 +++- .../frontend/tensorflow/test_forward.py | 33 +++++++++++++++---- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 5f52553cfd77..aa62702b2214 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -637,10 +637,11 @@ def _impl(inputs, attr, params, mod): iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0] # score_threshold was introduced from V3 score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else 0.0 + pad_output = 'pad_to_max_output_size' # Generate data with shape (1, num_anchors, 5) scores = AttrCvt(op_name="expand_dims", - ignores=['T_threshold'], + ignores=['T_threshold', pad_output], extras={'axis': -1, 'num_newaxis': 1})([inputs[1]], attr) data = get_relay_op('concatenate')([scores, inputs[0]], -1) data = get_relay_op('expand_dims')(data, 0, 1) @@ -667,6 +668,8 @@ def _impl(inputs, attr, params, mod): return_indices=True, invalid_to_bottom=False) + if pad_output in attr and attr[pad_output]: + return nms_ret # squeeze it, TF NMS is not batched size = get_relay_op("squeeze")(nms_ret[1], axis=[1]) data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0]) @@ -2152,6 +2155,7 @@ def _impl(inputs, attr, params, mod): 'Neg' : AttrCvt('negative'), 'NonMaxSuppressionV2' : _nms(), 'NonMaxSuppressionV3' : _nms(), + 'NonMaxSuppressionV4' : _nms(), 'NoOp' : _no_op(), 'NotEqual' : _broadcast('not_equal'), 'OneHot' : _one_hot(), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 5c6bd6f12cb4..62829df24d8a 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2031,12 +2031,31 @@ def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold, compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'], 'nms/NonMaxSuppressionV3:0', mode='debug') -def test_forward_nms_v3(): - """ NonMaxSuppressionV3 """ - _test_forward_nms_v3((5, 4), (5,), 0.7, 0.5, 5) - _test_forward_nms_v3((20, 4), (20,), 0.5, 0.6, 10) - _test_forward_nms_v3((1000, 4), (1000,), 0.3, 0.7, 1000) - _test_forward_nms_v3((2000, 4), (2000,), 0.4, 0.6, 7) +def _test_forward_nms_v4(bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"): + boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype) + scores = np.random.uniform(size=score_shape).astype(dtype) + max_output_size = np.int32(out_size) + tf.reset_default_graph() + in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1") + in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2") + in_data_3 = tf.placeholder(tf.int32, name="in_data_3") + indices_padded, num_valid = tf.image.non_max_suppression_padded(boxes=in_data_1, scores=in_data_2, max_output_size=in_data_3, + iou_threshold=iou_threshold, score_threshold=score_threshold, name="nms", pad_to_max_output_size=True) + num_valid = tf.reshape(num_valid,shape=(-1,)) + indices_padded = tf.reshape(indices_padded, shape=(-1,)) + tf.slice(indices_padded, tf.constant([0]), num_valid, name="SlicedIndices") + compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'], + ['nms/NonMaxSuppressionV4:1', "SlicedIndices:0"], mode='vm') + compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'], + ['nms/NonMaxSuppressionV4:1', "SlicedIndices:0"], mode='debug') + +def test_forward_nms(): + """ NonMaxSuppressionV3,4 """ + for _test_forward_nms in [_test_forward_nms_v3, _test_forward_nms_v4]: + _test_forward_nms((5, 4), (5,), 0.7, 0.5, 5) + _test_forward_nms((20, 4), (20,), 0.5, 0.6, 10) + _test_forward_nms((1000, 4), (1000,), 0.3, 0.7, 1000) + _test_forward_nms((2000, 4), (2000,), 0.4, 0.6, 7) ####################################################################### @@ -3867,7 +3886,7 @@ def lstm_cell(): test_forward_truncatemod() test_forward_one_hot() test_forward_atan2() - test_forward_nms_v3() + test_forward_nms() # Activations test_forward_sigmoid()