Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Onnx Fix 6 MaskRCNN models (#20178)
Browse files Browse the repository at this point in the history
* fixes for maskrcnn: 1. topk issue in nms 2. where operator when condition tensor needs to be broadcast

* fix for roi_align

* add model tests

* remove print
  • Loading branch information
Zha0q1 authored Apr 19, 2021
1 parent 45b5c11 commit 4bd7ad5
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3405,6 +3405,9 @@ def convert_contrib_box_nms(node, **kwargs):

center_point_box = 0 if in_format == 'corner' else 1

if topk == -1:
topk = 2**31-1

if in_format != out_format:
raise NotImplementedError('box_nms does not currently support in_fomat != out_format')

Expand Down Expand Up @@ -3560,17 +3563,29 @@ def convert_equal_scalar(node, **kwargs):
return nodes


@mx_op.register("where")
@mx_op.register('where')
def convert_where(node, **kwargs):
"""Map MXNet's where operator attributes to onnx's Where
operator and return the created node.
"""
from onnx.helper import make_node
from onnx import TensorProto
name, input_nodes, _ = get_inputs(node, kwargs)
# note that in mxnet the condition tensor can either have the same shape as x and y OR
# have shape (first dim of x,)
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([1], name+'_1', kwargs['initializer'])
nodes = [
make_node("Cast", [input_nodes[0]], [name+"_bool"], to=int(TensorProto.BOOL)),
make_node("Where", [name+"_bool", input_nodes[1], input_nodes[2]], [name], name=name)
make_node('Shape', [input_nodes[0]], [name+'_cond_shape']),
make_node('Shape', [name+'_cond_shape'], [name+'_cond_dim']),
make_node('Shape', [input_nodes[1]], [name+'_x_shape']),
make_node('Shape', [name+'_x_shape'], [name+'_x_dim']),
make_node('Sub', [name+'_x_dim', name+'_cond_dim'], [name+'_sub']),
make_node('Concat', [name+'_0', name+'_sub'], [name+'_concat'], axis=0),
make_node('Pad', [name+'_cond_shape', name+'_concat', name+'_1'], [name+'_cond_new_shape']),
make_node('Reshape', [input_nodes[0], name+'_cond_new_shape'], [name+'_cond']),
make_node('Cast', [name+'_cond'], [name+'_bool'], to=int(TensorProto.BOOL)),
make_node('Where', [name+'_bool', input_nodes[1], input_nodes[2]], [name], name=name)
]
return nodes

Expand Down Expand Up @@ -4066,17 +4081,22 @@ def convert_contrib_roialign(node, **kwargs):
aligned!=False')

create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([0], name+'_0_s', kwargs['initializer'], dtype='float32')
create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([5], name+'_5', kwargs['initializer'])

nodes = [
make_node('Slice', [input_nodes[1], name+'_1', name+'_5', name+'_1'], [name+'_rois']),
make_node('Slice', [input_nodes[1], name+'_0', name+'_1', name+'_1'], [name+'_inds__']),
make_node('Squeeze', [name+'_inds__'], [name+'_inds_'], axes=[1]),
make_node('Slice', [input_nodes[1], name+'_0', name+'_1', name+'_1'], [name+'_inds___']),
make_node('Squeeze', [name+'_inds___'], [name+'_inds__'], axes=[1]),
make_node('Relu', [name+'_inds__'], [name+'_inds_']),
make_node('Cast', [name+'_inds_'], [name+'_inds'], to=int(TensorProto.INT64)),
make_node('RoiAlign', [input_nodes[0], name+'_rois', name+'_inds'], [name],
make_node('RoiAlign', [input_nodes[0], name+'_rois', name+'_inds'], [name+'_roi'],
mode='avg', output_height=pooled_size[0], output_width=pooled_size[1],
sampling_ratio=sample_ratio, spatial_scale=spatial_scale)
sampling_ratio=sample_ratio, spatial_scale=spatial_scale),
make_node('Unsqueeze', [name+'_inds___'], [name+'_unsq'], axes=(2, 3)),
make_node('Less', [name+'_unsq', name+'_0_s'], [name+'_less']),
make_node('Where', [name+'_less', name+'_0_s', name+'_roi'], [name])
]

return nodes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,9 @@ def convert_contrib_box_nms(node, **kwargs):

center_point_box = 0 if in_format == 'corner' else 1

if topk == -1:
topk = 2**31-1

if in_format != out_format:
raise NotImplementedError('box_nms does not currently support in_fomat != out_format')

Expand Down Expand Up @@ -943,17 +946,23 @@ def convert_contrib_roialign(node, **kwargs):
aligned!=False')

create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([0], name+'_0_s', kwargs['initializer'], dtype='float32')
create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([5], name+'_5', kwargs['initializer'])
create_tensor([2, 3], name+'_2_3', kwargs['initializer'])

nodes = [
make_node('Slice', [input_nodes[1], name+'_1', name+'_5', name+'_1'], [name+'_rois']),
make_node('Slice', [input_nodes[1], name+'_0', name+'_1', name+'_1'], [name+'_inds__']),
make_node('Squeeze', [name+'_inds__', name+'_1'], [name+'_inds_']),
make_node('Slice', [input_nodes[1], name+'_0', name+'_1', name+'_1'], [name+'_inds___']),
make_node('Squeeze', [name+'_inds___', name+'_1'], [name+'_inds__']),
make_node('Relu', [name+'_inds__'], [name+'_inds_']),
make_node('Cast', [name+'_inds_'], [name+'_inds'], to=int(TensorProto.INT64)),
make_node('RoiAlign', [input_nodes[0], name+'_rois', name+'_inds'], [name],
make_node('RoiAlign', [input_nodes[0], name+'_rois', name+'_inds'], [name+'_roi'],
mode='avg', output_height=pooled_size[0], output_width=pooled_size[1],
sampling_ratio=sample_ratio, spatial_scale=spatial_scale)
sampling_ratio=sample_ratio, spatial_scale=spatial_scale),
make_node('Unsqueeze', [name+'_inds___', name+'_2_3'], [name+'_unsq']),
make_node('Less', [name+'_unsq', name+'_0_s'], [name+'_less']),
make_node('Where', [name+'_less', name+'_0_s', name+'_roi'], [name])
]

return nodes
Expand Down
41 changes: 25 additions & 16 deletions tests/python-pytest/onnx/test_onnxruntime_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,6 @@ def obj_detection_test_images(tmpdir_factory):
tmpdir = tmpdir_factory.mktemp("obj_det_data")
from urllib.parse import urlparse
test_image_urls = [
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/car.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/duck.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/fieldhockey.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/flower.jpg',
'https://github.com/apache/incubator-mxnet-ci/raw/master/test-data/images/runners.jpg',
Expand Down Expand Up @@ -240,14 +238,20 @@ def obj_detection_test_images(tmpdir_factory):
'faster_rcnn_resnet101_v1d_coco',
'yolo3_darknet53_coco',
'yolo3_mobilenet1.0_coco',
'mask_rcnn_resnet18_v1b_coco',
'mask_rcnn_fpn_resnet18_v1b_coco',
'mask_rcnn_resnet50_v1b_coco',
'mask_rcnn_fpn_resnet50_v1b_coco',
'mask_rcnn_resnet101_v1d_coco',
'mask_rcnn_fpn_resnet101_v1d_coco',
])
def test_obj_detection_model_inference_onnxruntime(tmp_path, model, obj_detection_test_images):
def assert_obj_detetion_result(mx_ids, mx_scores, mx_boxes,
onnx_ids, onnx_scores, onnx_boxes,
score_thresh=0.6, score_tol=1e-4):
def assert_bbox(mx_boxe, onnx_boxe, box_tol=1e-2):
def assert_scalar(a, b, tol=box_tol):
return np.abs(a-b) <= tol
score_thresh=0.6, score_tol=0.0001, box_tol=0.01):
def assert_bbox(mx_boxe, onnx_boxe):
def assert_scalar(a, b):
return np.abs(a-b) <= box_tol
return assert_scalar(mx_boxe[0], onnx_boxe[0]) and assert_scalar(mx_boxe[1], onnx_boxe[1]) \
and assert_scalar(mx_boxe[2], onnx_boxe[2]) and assert_scalar(mx_boxe[3], onnx_boxe[3])

Expand All @@ -256,7 +260,6 @@ def assert_scalar(a, b, tol=box_tol):
onnx_id = onnx_ids[i][0]
onnx_score = onnx_scores[i][0]
onnx_boxe = onnx_boxes[i]

if onnx_score < score_thresh:
break
for j in range(len(mx_ids)):
Expand All @@ -267,18 +270,16 @@ def assert_scalar(a, b, tol=box_tol):
if onnx_score < mx_score - score_tol:
continue
if onnx_score > mx_score + score_tol:
return False
assert found_match, 'match not found'
# check id
if onnx_id != mx_id:
continue
# check bounding box
if assert_bbox(mx_boxe, onnx_boxe):
found_match = True
break
if not found_match:
return False
assert found_match, 'match not found'
found_match = False
return True

def normalize_image(imgfile):
img = mx.image.imread(imgfile)
Expand All @@ -298,18 +299,26 @@ def normalize_image(imgfile):

for img in obj_detection_test_images:
img_data = normalize_image(img)
mx_class_ids, mx_scores, mx_boxes = M.predict(img_data)
if model.startswith('mask_rcnn'):
mx_class_ids, mx_scores, mx_boxes, _ = M.predict(img_data)
else:
mx_class_ids, mx_scores, mx_boxes = M.predict(img_data)
# center_net_resnet models have different output format
if 'center_net_resnet' in model:
onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
assert_almost_equal(mx_class_ids, onnx_class_ids)
assert_almost_equal(mx_scores, onnx_scores)
assert_almost_equal(mx_boxes, onnx_boxes)
else:
onnx_class_ids, onnx_scores, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
if not assert_obj_detetion_result(mx_class_ids[0], mx_scores[0], mx_boxes[0], \
onnx_class_ids[0], onnx_scores[0], onnx_boxes[0]):
raise AssertionError("Assertion error on model: " + model)
if model.startswith('mask_rcnn'):
onnx_class_ids, onnx_scores, onnx_boxes, _ = session.run([], {input_name: img_data.asnumpy()})
assert_obj_detetion_result(mx_class_ids[0], mx_scores[0], mx_boxes[0],
onnx_class_ids[0], onnx_scores[0], onnx_boxes[0],
score_thresh=0.8, score_tol=0.05, box_tol=15)
else:
onnx_class_ids, onnx_scores, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
assert_obj_detetion_result(mx_class_ids[0], mx_scores[0], mx_boxes[0],
onnx_class_ids[0], onnx_scores[0], onnx_boxes[0])

finally:
shutil.rmtree(tmp_path)
Expand Down
24 changes: 17 additions & 7 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def test_onnx_export_contrib_BilinearResize2D(tmp_path, dtype, params):
op_export_test('contrib_BilinearResize2D', M, [x], tmp_path)


@pytest.mark.parametrize('topk', [2, 3, 4])
@pytest.mark.parametrize('topk', [-1, 2, 3, 4])
@pytest.mark.parametrize('valid_thresh', [0.3, 0.4, 0.8])
@pytest.mark.parametrize('overlap_thresh', [0.4, 0.7, 1.0])
def test_onnx_export_contrib_box_nms(tmp_path, topk, valid_thresh, overlap_thresh):
Expand Down Expand Up @@ -577,12 +577,15 @@ def test_onnx_export_equal_scalar(tmp_path, dtype, scalar):
op_export_test('_internal._equal_scalar', M, [x], tmp_path)


@pytest.mark.parametrize("dtype", ["float16", "float32", "int32", "int64"])
@pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)])
def test_onnx_export_where(tmp_path, dtype, shape):
@pytest.mark.parametrize('dtype', ["float16", "float32", "int32", "int64"])
@pytest.mark.parametrize('shape', [(5,), (3,3), (10,2), (20,30,40)])
@pytest.mark.parametrize('broadcast', [True, False])
def test_onnx_export_where(tmp_path, dtype, shape, broadcast):
M = def_model('where')
x = mx.nd.zeros(shape, dtype=dtype)
y = mx.nd.ones(shape, dtype=dtype)
if broadcast:
shape = shape[0:1]
cond = mx.nd.random.randint(low=0, high=1, shape=shape, dtype='int32')
op_export_test('where', M, [cond, x, y], tmp_path)

Expand Down Expand Up @@ -795,16 +798,23 @@ def test_onnx_export_broadcast_like(tmp_path, dtype, lhs_axes, rhs_axes):
@pytest.mark.parametrize('spatial_scale', [1, 0.5, 0.0625])
@pytest.mark.parametrize('spatial_ratio', [1, 2, 3, 5])
def test_onnx_export_contrib_ROIAlign(tmp_path, dtype, pooled_size, spatial_scale, spatial_ratio):
data = mx.random.uniform(0, 1, (5, 3, 128, 128)).astype(dtype)
rois = mx.nd.array([[0, 0, 0, 63, 63],
data = mx.random.uniform(0, 1, (5, 3, 512, 512)).astype(dtype)
rois = mx.nd.array([[-1, 0, 0, 0, 0],
[0, 0, 0, 63, 63],
[1, 34, 52, 25, 85],
[2, 50, 50, 100, 100],
[3, 0, 0, 127, 127],
[4, 12, 84, 22, 94],
[0, 0, 0, 1, 1]]).astype(dtype)
M = def_model('contrib.ROIAlign', pooled_size=pooled_size, spatial_scale=spatial_scale,
sample_ratio=spatial_ratio)
op_export_test('_contrib_ROIAlign', M, [data, rois], tmp_path)
# according to https://mxnet.apache.org/versions/1.7.0/api/python/docs/api/contrib/symbol/index.html#mxnet.contrib.symbol.ROIAlign
# the returned value for when batch_id < 0 should be all 0's
# however mxnet 1.8 does always behave this way so we set the first roi to 0's manually
def mx_map(x):
x[0] = 0
return x
op_export_test('_contrib_ROIAlign', M, [data, rois], tmp_path, mx_map=mx_map)


@pytest.mark.parametrize('dtype', ['float32', 'float64'])
Expand Down

0 comments on commit 4bd7ad5

Please sign in to comment.