From 9e9c71f0290d60dfac98d2e5164478d39452003d Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Thu, 20 Aug 2020 17:12:19 +0800 Subject: [PATCH 1/3] [Relay] Support for PyTorch Non-Maximum Suppression --- python/tvm/relay/frontend/pytorch.py | 49 ++++++++++++++++- tests/python/frontend/pytorch/test_forward.py | 55 +++++++++++++++++-- 2 files changed, 98 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 723740377cde..f57f7e35972a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -32,7 +32,7 @@ from ..ty import TupleType, TensorType, Any from ..loops import while_loop from .. import transform -from .common import get_relay_op +from .common import AttrCvt, get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value from .common import infer_value_simulated as _infer_value_simulated @@ -2111,10 +2111,57 @@ def _get_convert_map(prelude): "aten::gather" : _gather(), "aten::index_select" : _select(), "aten::index" : _index(), + "torchvision::nms" : _nms(prelude), } return convert_map +def _nms(prelude): + def _impl(inputs, input_types): + boxes = inputs[0] + scores = inputs[1] + iou_threshold = inputs[2] + + # Generate data with shape (1, num_anchors, 5) + scores = AttrCvt(op_name="expand_dims", + extras={'axis': -1, 'num_newaxis': 1})([scores], {}) + + # Prepare input data for get_valid_counts + data = _op.concatenate([scores, boxes], -1) + data = _op.expand_dims(data, 0, 1) + # Leverage get_valid_counts to sort the data and clear invalid boxes + ct, data, indices = get_relay_op('get_valid_counts')(data, + score_threshold=-1.0, + id_index=-1, + score_index=0) + + # Perform Non-Maximum Suppression, + # PyTorch NMS doesn't have parameter top_k and max_output_size + score_index = 0 + top_k = max_out_size = -1 + nms_ret = get_relay_op('non_max_suppression')(data=data, + valid_count=ct, + indices=indices, + max_output_size=max_out_size, + iou_threshold=iou_threshold, + force_suppress=True, + top_k=top_k, + coord_start=1, + score_index=score_index, + id_index=-1, + return_indices=True, + invalid_to_bottom=False) + + # squeeze the two outputs of nms for strided_slice + size = get_relay_op("squeeze")(nms_ret[1], axis=[1]) + data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0]) + + # strided slice to get the dynamic result + return get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]), + end=size, slice_mode="size") + return _impl + + def _run_jit_passes(graph): """ The inline pass is necessary to unwrap prim::CallMethod """ import torch diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index ab0a4b03cafa..ca3e944d1d60 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1428,6 +1428,48 @@ def test_forward_upsample3d(): verify_model(torch.nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True).eval(), inp) +def test_forward_nms(): + """dynamic Non-Maximum Suppression""" + torch.set_grad_enabled(False) + class NonMaxSupression1(Module): + def forward(self, *args): + return torchvision.ops.nms(args[0], args[1], 0.3) + + class NonMaxSupression2(Module): + def forward(self, *args): + from torchvision.ops import nms + return torchvision.ops.nms(args[0], args[1], 0.5) + + class NonMaxSupression3(Module): + def forward(self, *args): + from torchvision.ops import nms + return torchvision.ops.nms(args[0], args[1], 0.9) + + # Generate random input data + def _gen_rand_inputs(num_boxes): + box_len = 4 + boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5 + boxes[:, 2] += boxes[:, 0] + boxes[:, 3] += boxes[:, 1] + scores = torch.rand(num_boxes, dtype=torch.float) + return boxes, scores + + in_boxes, in_scores = _gen_rand_inputs(10) + scripted_model1 = torch.jit.trace(NonMaxSupression1(), [in_boxes, in_scores]) + verify_script_model(scripted_model1, [in_boxes.shape, in_scores.shape], + idata=[in_boxes, in_scores]) + + in_boxes, in_scores = _gen_rand_inputs(100) + scripted_model2 = torch.jit.trace(NonMaxSupression2(), [in_boxes, in_scores]) + verify_script_model(scripted_model2, [in_boxes.shape, in_scores.shape], + idata=[in_boxes, in_scores]) + + in_boxes, in_scores = _gen_rand_inputs(500) + scripted_model3 = torch.jit.trace(NonMaxSupression3(), [in_boxes, in_scores]) + verify_script_model(scripted_model3, [in_boxes.shape, in_scores.shape], + idata=[in_boxes, in_scores]) + + def test_conv3d(): for ishape in [(1, 32, 16, 16, 16), (1, 32, 9, 15, 15), @@ -1575,14 +1617,16 @@ def test_3d_models(): verify_model(resnet3d, [torch.rand(input_shape)], atol=1e-4, rtol=1e-4) -def verify_script_model(pt_model, ishapes): +def verify_script_model(pt_model, ishapes, idata=None): script_module = torch.jit.script(pt_model) input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)] input_shapes = list(zip(input_names, ishapes)) - inputs = [torch.randn(shape, dtype=torch.float) - for shape in ishapes] + if not idata: + input_data = [torch.randn(shape, dtype=torch.float) for shape in ishapes] + else: + input_data = idata mod, params = relay.frontend.from_pytorch(script_module, input_shapes) @@ -1590,13 +1634,13 @@ def verify_script_model(pt_model, ishapes): target="llvm") evaluator = executor.evaluate() - for name, inp in zip(input_names, inputs): + for name, inp in zip(input_names, input_data): params[name] = inp.numpy() op_res = evaluator(**params) with torch.no_grad(): - pt_result = pt_model(*inputs) + pt_result = pt_model(*input_data) if not isinstance(pt_result, torch.Tensor): tvm_res = op_res.asnumpy().item() @@ -2863,6 +2907,7 @@ def test_forward_pretrained_bert_base_uncased(): test_forward_gather() test_upsample() test_forward_upsample3d() + test_forward_nms() test_to() test_type_as() test_forward_functional_pad() From 2f41107f710ac0b93192b542777b3f220f18fe3b Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Thu, 20 Aug 2020 20:27:26 +0800 Subject: [PATCH 2/3] fix comment --- python/tvm/relay/frontend/pytorch.py | 93 ++++++++++--------- tests/python/frontend/pytorch/test_forward.py | 78 +++++++++------- 2 files changed, 94 insertions(+), 77 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index f57f7e35972a..21cf9c3a1b97 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1811,6 +1811,53 @@ def _impl(inputs, input_types): return _op.meshgrid(data, indexing="ij") return _impl + +def _nms(prelude): + def _impl(inputs, input_types): + boxes = inputs[0] + scores = inputs[1] + iou_threshold = inputs[2] + + # Generate data with shape (1, num_anchors, 5) + scores = AttrCvt(op_name="expand_dims", + extras={'axis': -1, 'num_newaxis': 1})([scores], {}) + + # Prepare input data for get_valid_counts + data = _op.concatenate([scores, boxes], -1) + data = _op.expand_dims(data, 0, 1) + # Leverage get_valid_counts to sort the data and clear invalid boxes + ct, data, indices = get_relay_op('get_valid_counts')(data, + score_threshold=-1.0, + id_index=-1, + score_index=0) + + # Perform Non-Maximum Suppression, + # PyTorch NMS doesn't have parameter top_k and max_output_size + score_index = 0 + top_k = max_out_size = -1 + nms_ret = get_relay_op('non_max_suppression')(data=data, + valid_count=ct, + indices=indices, + max_output_size=max_out_size, + iou_threshold=iou_threshold, + force_suppress=True, + top_k=top_k, + coord_start=1, + score_index=score_index, + id_index=-1, + return_indices=True, + invalid_to_bottom=False) + + # squeeze the two outputs of nms for strided_slice + size = get_relay_op("squeeze")(nms_ret[1], axis=[1]) + data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0]) + + # strided slice to get the dynamic result + return get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]), + end=size, slice_mode="size") + return _impl + + def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" import torch @@ -2116,52 +2163,6 @@ def _get_convert_map(prelude): return convert_map -def _nms(prelude): - def _impl(inputs, input_types): - boxes = inputs[0] - scores = inputs[1] - iou_threshold = inputs[2] - - # Generate data with shape (1, num_anchors, 5) - scores = AttrCvt(op_name="expand_dims", - extras={'axis': -1, 'num_newaxis': 1})([scores], {}) - - # Prepare input data for get_valid_counts - data = _op.concatenate([scores, boxes], -1) - data = _op.expand_dims(data, 0, 1) - # Leverage get_valid_counts to sort the data and clear invalid boxes - ct, data, indices = get_relay_op('get_valid_counts')(data, - score_threshold=-1.0, - id_index=-1, - score_index=0) - - # Perform Non-Maximum Suppression, - # PyTorch NMS doesn't have parameter top_k and max_output_size - score_index = 0 - top_k = max_out_size = -1 - nms_ret = get_relay_op('non_max_suppression')(data=data, - valid_count=ct, - indices=indices, - max_output_size=max_out_size, - iou_threshold=iou_threshold, - force_suppress=True, - top_k=top_k, - coord_start=1, - score_index=score_index, - id_index=-1, - return_indices=True, - invalid_to_bottom=False) - - # squeeze the two outputs of nms for strided_slice - size = get_relay_op("squeeze")(nms_ret[1], axis=[1]) - data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0]) - - # strided slice to get the dynamic result - return get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]), - end=size, slice_mode="size") - return _impl - - def _run_jit_passes(graph): """ The inline pass is necessary to unwrap prim::CallMethod """ import torch diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index ca3e944d1d60..ed3d2da11c76 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1431,19 +1431,13 @@ def test_forward_upsample3d(): def test_forward_nms(): """dynamic Non-Maximum Suppression""" torch.set_grad_enabled(False) - class NonMaxSupression1(Module): - def forward(self, *args): - return torchvision.ops.nms(args[0], args[1], 0.3) - - class NonMaxSupression2(Module): - def forward(self, *args): - from torchvision.ops import nms - return torchvision.ops.nms(args[0], args[1], 0.5) + class NonMaxSupression(Module): + def __init__(self, iou_thres): + super().__init__() + self.iou_threshold = iou_thres - class NonMaxSupression3(Module): def forward(self, *args): - from torchvision.ops import nms - return torchvision.ops.nms(args[0], args[1], 0.9) + return torchvision.ops.nms(args[0], args[1], self.iou_threshold) # Generate random input data def _gen_rand_inputs(num_boxes): @@ -1454,20 +1448,11 @@ def _gen_rand_inputs(num_boxes): scores = torch.rand(num_boxes, dtype=torch.float) return boxes, scores - in_boxes, in_scores = _gen_rand_inputs(10) - scripted_model1 = torch.jit.trace(NonMaxSupression1(), [in_boxes, in_scores]) - verify_script_model(scripted_model1, [in_boxes.shape, in_scores.shape], - idata=[in_boxes, in_scores]) - - in_boxes, in_scores = _gen_rand_inputs(100) - scripted_model2 = torch.jit.trace(NonMaxSupression2(), [in_boxes, in_scores]) - verify_script_model(scripted_model2, [in_boxes.shape, in_scores.shape], - idata=[in_boxes, in_scores]) - - in_boxes, in_scores = _gen_rand_inputs(500) - scripted_model3 = torch.jit.trace(NonMaxSupression3(), [in_boxes, in_scores]) - verify_script_model(scripted_model3, [in_boxes.shape, in_scores.shape], - idata=[in_boxes, in_scores]) + for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]: + in_boxes, in_scores = _gen_rand_inputs(num_boxes) + traced_model = torch.jit.trace(NonMaxSupression(iou_thres), [in_boxes, in_scores]) + verify_trace_model(traced_model, [in_boxes.shape, in_scores.shape], + idata=[in_boxes, in_scores]) def test_conv3d(): @@ -1617,30 +1602,61 @@ def test_3d_models(): verify_model(resnet3d, [torch.rand(input_shape)], atol=1e-4, rtol=1e-4) -def verify_script_model(pt_model, ishapes, idata=None): +def verify_script_model(pt_model, ishapes): script_module = torch.jit.script(pt_model) input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)] input_shapes = list(zip(input_names, ishapes)) + inputs = [torch.randn(shape, dtype=torch.float) + for shape in ishapes] + + mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + + executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), + target="llvm") + evaluator = executor.evaluate() + + for name, inp in zip(input_names, inputs): + params[name] = inp.numpy() + + op_res = evaluator(**params) + + with torch.no_grad(): + pt_result = pt_model(*inputs) + + if not isinstance(pt_result, torch.Tensor): + tvm_res = op_res.asnumpy().item() + assert pt_result == tvm_res + else: + tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(), + rtol=1e-5, atol=1e-5) + + +def verify_trace_model(traced_model, ishapes, idata=None): + trace_module = traced_model + + input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)] + input_shapes = list(zip(input_names, ishapes)) if not idata: - input_data = [torch.randn(shape, dtype=torch.float) for shape in ishapes] + inputs = [torch.randn(shape, dtype=torch.float) + for shape in ishapes] else: - input_data = idata + inputs = idata - mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + mod, params = relay.frontend.from_pytorch(trace_module, input_shapes) executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm") evaluator = executor.evaluate() - for name, inp in zip(input_names, input_data): + for name, inp in zip(input_names, inputs): params[name] = inp.numpy() op_res = evaluator(**params) with torch.no_grad(): - pt_result = pt_model(*input_data) + pt_result = trace_module(*inputs) if not isinstance(pt_result, torch.Tensor): tvm_res = op_res.asnumpy().item() From 7a953559d5ad9878844801be8cb213e5770c4eed Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Mon, 24 Aug 2020 15:55:14 +0800 Subject: [PATCH 3/3] add verify_model_vm --- tests/python/frontend/pytorch/test_forward.py | 64 ++++++------------- 1 file changed, 20 insertions(+), 44 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index ed3d2da11c76..946712df5086 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1450,9 +1450,7 @@ def _gen_rand_inputs(num_boxes): for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]: in_boxes, in_scores = _gen_rand_inputs(num_boxes) - traced_model = torch.jit.trace(NonMaxSupression(iou_thres), [in_boxes, in_scores]) - verify_trace_model(traced_model, [in_boxes.shape, in_scores.shape], - idata=[in_boxes, in_scores]) + verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores]) def test_conv3d(): @@ -1604,65 +1602,43 @@ def test_3d_models(): def verify_script_model(pt_model, ishapes): script_module = torch.jit.script(pt_model) + verify_model_vm(script_module, ishapes) - input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)] - input_shapes = list(zip(input_names, ishapes)) - - inputs = [torch.randn(shape, dtype=torch.float) - for shape in ishapes] - mod, params = relay.frontend.from_pytorch(script_module, input_shapes) +def verify_trace_model(pt_model, idata): + traced_model = torch.jit.trace(pt_model, idata) + ishapes = [data.shape for data in idata] + verify_model_vm(traced_model, ishapes, idata=idata) - executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), - target="llvm") - evaluator = executor.evaluate() - - for name, inp in zip(input_names, inputs): - params[name] = inp.numpy() - - op_res = evaluator(**params) - - with torch.no_grad(): - pt_result = pt_model(*inputs) - - if not isinstance(pt_result, torch.Tensor): - tvm_res = op_res.asnumpy().item() - assert pt_result == tvm_res - else: - tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(), - rtol=1e-5, atol=1e-5) - - -def verify_trace_model(traced_model, ishapes, idata=None): - trace_module = traced_model +def verify_model_vm(imodel, ishapes, idtype=torch.float, idata=None): + input_model = imodel input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)] input_shapes = list(zip(input_names, ishapes)) - if not idata: - inputs = [torch.randn(shape, dtype=torch.float) - for shape in ishapes] - else: - inputs = idata - - mod, params = relay.frontend.from_pytorch(trace_module, input_shapes) + input_data = idata if idata else [torch.randn(shape, dtype=idtype) + for shape in ishapes] + # Compile via VM + mod, params = relay.frontend.from_pytorch(input_model, input_shapes) executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm") evaluator = executor.evaluate() - for name, inp in zip(input_names, inputs): + # Inference + for name, inp in zip(input_names, input_data): params[name] = inp.numpy() + vm_res = evaluator(**params) - op_res = evaluator(**params) - + # Baseline result with torch.no_grad(): - pt_result = trace_module(*inputs) + pt_result = input_model(*input_data) + # Verify the accuracy if not isinstance(pt_result, torch.Tensor): - tvm_res = op_res.asnumpy().item() + tvm_res = vm_res.asnumpy().item() assert pt_result == tvm_res else: - tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(), + tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(), rtol=1e-5, atol=1e-5)