From 15d8e246f66cf25f6c79e107614289c90d44f287 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 15 Apr 2021 19:08:24 +0000 Subject: [PATCH 1/4] fixes for maskrcnn: 1. topk issue in nms 2. where operator when condition tensor needs to be broadcast --- .../_op_translations_opset12.py | 21 ++++++++++++++++--- .../_op_translations_opset13.py | 3 +++ tests/python-pytest/onnx/test_operators.py | 11 ++++++---- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py index d683aad7000c..efa27fe69a26 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -3315,6 +3315,9 @@ def convert_contrib_box_nms(node, **kwargs): center_point_box = 0 if in_format == 'corner' else 1 + if topk == -1: + topk = 2**31-1 + if in_format != out_format: raise NotImplementedError('box_nms does not currently support in_fomat != out_format') @@ -3470,7 +3473,7 @@ def convert_equal_scalar(node, **kwargs): return nodes -@mx_op.register("where") +@mx_op.register('where') def convert_where(node, **kwargs): """Map MXNet's where operator attributes to onnx's Where operator and return the created node. @@ -3478,9 +3481,21 @@ def convert_where(node, **kwargs): from onnx.helper import make_node from onnx import TensorProto name, input_nodes, _ = get_inputs(node, kwargs) + # note that in mxnet the condition tensor can either have the same shape as x and y OR + # have shape (first dim of x,) + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([1], name+'_1', kwargs['initializer']) nodes = [ - make_node("Cast", [input_nodes[0]], [name+"_bool"], to=int(TensorProto.BOOL)), - make_node("Where", [name+"_bool", input_nodes[1], input_nodes[2]], [name], name=name) + make_node('Shape', [input_nodes[0]], [name+'_cond_shape']), + make_node('Shape', [name+'_cond_shape'], [name+'_cond_dim']), + make_node('Shape', [input_nodes[1]], [name+'_x_shape']), + make_node('Shape', [name+'_x_shape'], [name+'_x_dim']), + make_node('Sub', [name+'_x_dim', name+'_cond_dim'], [name+'_sub']), + make_node('Concat', [name+'_0', name+'_sub'], [name+'_concat'], axis=0), + make_node('Pad', [name+'_cond_shape', name+'_concat', name+'_1'], [name+'_cond_new_shape']), + make_node('Reshape', [input_nodes[0], name+'_cond_new_shape'], [name+'_cond']), + make_node('Cast', [name+'_cond'], [name+'_bool'], to=int(TensorProto.BOOL)), + make_node('Where', [name+'_bool', input_nodes[1], input_nodes[2]], [name], name=name) ] return nodes diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index df137644b151..d86a00679d9e 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -843,6 +843,9 @@ def convert_contrib_box_nms(node, **kwargs): center_point_box = 0 if in_format == 'corner' else 1 + if topk == -1: + topk = 2**31-1 + if in_format != out_format: raise NotImplementedError('box_nms does not currently support in_fomat != out_format') diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index b032fa7fc1bd..3ec2c9cb1aba 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -483,7 +483,7 @@ def test_onnx_export_contrib_BilinearResize2D(tmp_path, dtype, params): op_export_test('contrib_BilinearResize2D', M, [x], tmp_path) -@pytest.mark.parametrize('topk', [2, 3, 4]) +@pytest.mark.parametrize('topk', [-1, 2, 3, 4]) @pytest.mark.parametrize('valid_thresh', [0.3, 0.4, 0.8]) @pytest.mark.parametrize('overlap_thresh', [0.4, 0.7, 1.0]) def test_onnx_export_contrib_box_nms(tmp_path, topk, valid_thresh, overlap_thresh): @@ -574,12 +574,15 @@ def test_onnx_export_equal_scalar(tmp_path, dtype, scalar): op_export_test('_internal._equal_scalar', M, [x], tmp_path) -@pytest.mark.parametrize("dtype", ["float16", "float32", "int32", "int64"]) -@pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)]) -def test_onnx_export_where(tmp_path, dtype, shape): +@pytest.mark.parametrize('dtype', ["float16", "float32", "int32", "int64"]) +@pytest.mark.parametrize('shape', [(5,), (3,3), (10,2), (20,30,40)]) +@pytest.mark.parametrize('broadcast', [True, False]) +def test_onnx_export_where(tmp_path, dtype, shape, broadcast): M = def_model('where') x = mx.nd.zeros(shape, dtype=dtype) y = mx.nd.ones(shape, dtype=dtype) + if broadcast: + shape = shape[0:1] cond = mx.nd.random.randint(low=0, high=1, shape=shape, dtype='int32') op_export_test('where', M, [cond, x, y], tmp_path) From 55c7682d27fa6b09aceeae3a3d2d572cf1db05c9 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 15 Apr 2021 23:16:04 +0000 Subject: [PATCH 2/4] fix for roi_align --- .../_op_translations/_op_translations_opset12.py | 13 +++++++++---- .../_op_translations/_op_translations_opset13.py | 14 ++++++++++---- tests/python-pytest/onnx/test_operators.py | 13 ++++++++++--- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py index efa27fe69a26..e7563ec0a9e8 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -3991,17 +3991,22 @@ def convert_contrib_roialign(node, **kwargs): aligned!=False') create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([0], name+'_0_s', kwargs['initializer'], dtype='float32') create_tensor([1], name+'_1', kwargs['initializer']) create_tensor([5], name+'_5', kwargs['initializer']) nodes = [ make_node('Slice', [input_nodes[1], name+'_1', name+'_5', name+'_1'], [name+'_rois']), - make_node('Slice', [input_nodes[1], name+'_0', name+'_1', name+'_1'], [name+'_inds__']), - make_node('Squeeze', [name+'_inds__'], [name+'_inds_'], axes=[1]), + make_node('Slice', [input_nodes[1], name+'_0', name+'_1', name+'_1'], [name+'_inds___']), + make_node('Squeeze', [name+'_inds___'], [name+'_inds__'], axes=[1]), + make_node('Relu', [name+'_inds__'], [name+'_inds_']), make_node('Cast', [name+'_inds_'], [name+'_inds'], to=int(TensorProto.INT64)), - make_node('RoiAlign', [input_nodes[0], name+'_rois', name+'_inds'], [name], + make_node('RoiAlign', [input_nodes[0], name+'_rois', name+'_inds'], [name+'_roi'], mode='avg', output_height=pooled_size[0], output_width=pooled_size[1], - sampling_ratio=sample_ratio, spatial_scale=spatial_scale) + sampling_ratio=sample_ratio, spatial_scale=spatial_scale), + make_node('Unsqueeze', [name+'_inds___'], [name+'_unsq'], axes=(2, 3)), + make_node('Less', [name+'_unsq', name+'_0_s'], [name+'_less']), + make_node('Where', [name+'_less', name+'_0_s', name+'_roi'], [name]) ] return nodes diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index d86a00679d9e..9beddde43b98 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -938,17 +938,23 @@ def convert_contrib_roialign(node, **kwargs): aligned!=False') create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([0], name+'_0_s', kwargs['initializer'], dtype='float32') create_tensor([1], name+'_1', kwargs['initializer']) create_tensor([5], name+'_5', kwargs['initializer']) + create_tensor([2, 3], name+'_2_3', kwargs['initializer']) nodes = [ make_node('Slice', [input_nodes[1], name+'_1', name+'_5', name+'_1'], [name+'_rois']), - make_node('Slice', [input_nodes[1], name+'_0', name+'_1', name+'_1'], [name+'_inds__']), - make_node('Squeeze', [name+'_inds__', name+'_1'], [name+'_inds_']), + make_node('Slice', [input_nodes[1], name+'_0', name+'_1', name+'_1'], [name+'_inds___']), + make_node('Squeeze', [name+'_inds___', name+'_1'], [name+'_inds__']), + make_node('Relu', [name+'_inds__'], [name+'_inds_']), make_node('Cast', [name+'_inds_'], [name+'_inds'], to=int(TensorProto.INT64)), - make_node('RoiAlign', [input_nodes[0], name+'_rois', name+'_inds'], [name], + make_node('RoiAlign', [input_nodes[0], name+'_rois', name+'_inds'], [name+'_roi'], mode='avg', output_height=pooled_size[0], output_width=pooled_size[1], - sampling_ratio=sample_ratio, spatial_scale=spatial_scale) + sampling_ratio=sample_ratio, spatial_scale=spatial_scale), + make_node('Unsqueeze', [name+'_inds___', name+'_2_3'], [name+'_unsq']), + make_node('Less', [name+'_unsq', name+'_0_s'], [name+'_less']), + make_node('Where', [name+'_less', name+'_0_s', name+'_roi'], [name]) ] return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 3ec2c9cb1aba..7cbd10249f9a 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -795,8 +795,9 @@ def test_onnx_export_broadcast_like(tmp_path, dtype, lhs_axes, rhs_axes): @pytest.mark.parametrize('spatial_scale', [1, 0.5, 0.0625]) @pytest.mark.parametrize('spatial_ratio', [1, 2, 3, 5]) def test_onnx_export_contrib_ROIAlign(tmp_path, dtype, pooled_size, spatial_scale, spatial_ratio): - data = mx.random.uniform(0, 1, (5, 3, 128, 128)).astype(dtype) - rois = mx.nd.array([[0, 0, 0, 63, 63], + data = mx.random.uniform(0, 1, (5, 3, 512, 512)).astype(dtype) + rois = mx.nd.array([[-1, 0, 0, 0, 0], + [0, 0, 0, 63, 63], [1, 34, 52, 25, 85], [2, 50, 50, 100, 100], [3, 0, 0, 127, 127], @@ -804,7 +805,13 @@ def test_onnx_export_contrib_ROIAlign(tmp_path, dtype, pooled_size, spatial_scal [0, 0, 0, 1, 1]]).astype(dtype) M = def_model('contrib.ROIAlign', pooled_size=pooled_size, spatial_scale=spatial_scale, sample_ratio=spatial_ratio) - op_export_test('_contrib_ROIAlign', M, [data, rois], tmp_path) + # according to https://mxnet.apache.org/versions/1.7.0/api/python/docs/api/contrib/symbol/index.html#mxnet.contrib.symbol.ROIAlign + # the returned value for when batch_id < 0 should be all 0's + # however mxnet 1.8 does always behave this way so we set the first roi to 0's manually + def mx_map(x): + x[0] = 0 + return x + op_export_test('_contrib_ROIAlign', M, [data, rois], tmp_path, mx_map=mx_map) @pytest.mark.parametrize('dtype', ['float32', 'float64']) From f702577601636733d4630202848a3a9d4f09ea24 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 16 Apr 2021 02:33:59 +0000 Subject: [PATCH 3/4] add model tests --- .../python-pytest/onnx/test_onnxruntime_cv.py | 42 ++++++++++++------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/tests/python-pytest/onnx/test_onnxruntime_cv.py b/tests/python-pytest/onnx/test_onnxruntime_cv.py index f8c2f227c8b8..61d449a5c73e 100644 --- a/tests/python-pytest/onnx/test_onnxruntime_cv.py +++ b/tests/python-pytest/onnx/test_onnxruntime_cv.py @@ -199,8 +199,6 @@ def obj_detection_test_images(tmpdir_factory): tmpdir = tmpdir_factory.mktemp("obj_det_data") from urllib.parse import urlparse test_image_urls = [ - 'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/car.jpg', - 'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/duck.jpg', 'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/fieldhockey.jpg', 'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/flower.jpg', 'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/runners.jpg', @@ -240,14 +238,20 @@ def obj_detection_test_images(tmpdir_factory): 'faster_rcnn_resnet101_v1d_coco', 'yolo3_darknet53_coco', 'yolo3_mobilenet1.0_coco', + 'mask_rcnn_resnet18_v1b_coco', + 'mask_rcnn_fpn_resnet18_v1b_coco', + 'mask_rcnn_resnet50_v1b_coco', + 'mask_rcnn_fpn_resnet50_v1b_coco', + 'mask_rcnn_resnet101_v1d_coco', + 'mask_rcnn_fpn_resnet101_v1d_coco', ]) def test_obj_detection_model_inference_onnxruntime(tmp_path, model, obj_detection_test_images): def assert_obj_detetion_result(mx_ids, mx_scores, mx_boxes, onnx_ids, onnx_scores, onnx_boxes, - score_thresh=0.6, score_tol=1e-4): - def assert_bbox(mx_boxe, onnx_boxe, box_tol=1e-2): - def assert_scalar(a, b, tol=box_tol): - return np.abs(a-b) <= tol + score_thresh=0.6, score_tol=0.0001, box_tol=0.01): + def assert_bbox(mx_boxe, onnx_boxe): + def assert_scalar(a, b): + return np.abs(a-b) <= box_tol return assert_scalar(mx_boxe[0], onnx_boxe[0]) and assert_scalar(mx_boxe[1], onnx_boxe[1]) \ and assert_scalar(mx_boxe[2], onnx_boxe[2]) and assert_scalar(mx_boxe[3], onnx_boxe[3]) @@ -256,7 +260,7 @@ def assert_scalar(a, b, tol=box_tol): onnx_id = onnx_ids[i][0] onnx_score = onnx_scores[i][0] onnx_boxe = onnx_boxes[i] - + print('onnx id', onnx_id) if onnx_score < score_thresh: break for j in range(len(mx_ids)): @@ -267,7 +271,7 @@ def assert_scalar(a, b, tol=box_tol): if onnx_score < mx_score - score_tol: continue if onnx_score > mx_score + score_tol: - return False + assert found_match, 'match not found' # check id if onnx_id != mx_id: continue @@ -275,10 +279,8 @@ def assert_scalar(a, b, tol=box_tol): if assert_bbox(mx_boxe, onnx_boxe): found_match = True break - if not found_match: - return False + assert found_match, 'match not found' found_match = False - return True def normalize_image(imgfile): img = mx.image.imread(imgfile) @@ -298,7 +300,10 @@ def normalize_image(imgfile): for img in obj_detection_test_images: img_data = normalize_image(img) - mx_class_ids, mx_scores, mx_boxes = M.predict(img_data) + if model.startswith('mask_rcnn'): + mx_class_ids, mx_scores, mx_boxes, _ = M.predict(img_data) + else: + mx_class_ids, mx_scores, mx_boxes = M.predict(img_data) # center_net_resnet models have different output format if 'center_net_resnet' in model: onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()}) @@ -306,10 +311,15 @@ def normalize_image(imgfile): assert_almost_equal(mx_scores, onnx_scores) assert_almost_equal(mx_boxes, onnx_boxes) else: - onnx_class_ids, onnx_scores, onnx_boxes = session.run([], {input_name: img_data.asnumpy()}) - if not assert_obj_detetion_result(mx_class_ids[0], mx_scores[0], mx_boxes[0], \ - onnx_class_ids[0], onnx_scores[0], onnx_boxes[0]): - raise AssertionError("Assertion error on model: " + model) + if model.startswith('mask_rcnn'): + onnx_class_ids, onnx_scores, onnx_boxes, _ = session.run([], {input_name: img_data.asnumpy()}) + assert_obj_detetion_result(mx_class_ids[0], mx_scores[0], mx_boxes[0], + onnx_class_ids[0], onnx_scores[0], onnx_boxes[0], + score_thresh=0.8, score_tol=0.05, box_tol=15) + else: + onnx_class_ids, onnx_scores, onnx_boxes = session.run([], {input_name: img_data.asnumpy()}) + assert_obj_detetion_result(mx_class_ids[0], mx_scores[0], mx_boxes[0], + onnx_class_ids[0], onnx_scores[0], onnx_boxes[0]) finally: shutil.rmtree(tmp_path) From 3fc131bca8909fc81d35c45c15cdd6bc7da82dae Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 16 Apr 2021 02:35:15 +0000 Subject: [PATCH 4/4] remove print --- tests/python-pytest/onnx/test_onnxruntime_cv.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python-pytest/onnx/test_onnxruntime_cv.py b/tests/python-pytest/onnx/test_onnxruntime_cv.py index 61d449a5c73e..e03923b39aa5 100644 --- a/tests/python-pytest/onnx/test_onnxruntime_cv.py +++ b/tests/python-pytest/onnx/test_onnxruntime_cv.py @@ -260,7 +260,6 @@ def assert_scalar(a, b): onnx_id = onnx_ids[i][0] onnx_score = onnx_scores[i][0] onnx_boxe = onnx_boxes[i] - print('onnx id', onnx_id) if onnx_score < score_thresh: break for j in range(len(mx_ids)):