# Patch summary (reviewer notes). Everything from the first "diff --git" header
# onwards is unchanged.
#
# python/tvm/relay/frontend/pytorch.py
#   * Imports AttrCvt alongside get_relay_op from .common.
#   * Adds the converter factory _nms(prelude) and registers it in
#     _get_convert_map under the key "torchvision::nms". The `prelude` argument
#     is not referenced inside _nms in this patch.
#   * _nms._impl reads boxes, scores and iou_threshold from inputs[0..2], appends
#     a trailing axis to scores, concatenates [scores, boxes] on the last axis and
#     adds a leading axis. It then calls get_valid_counts (score_threshold=-1.0,
#     id_index=-1, score_index=0) followed by non_max_suppression
#     (force_suppress=True, top_k=-1, max_output_size=-1, coord_start=1,
#     return_indices=True, invalid_to_bottom=False). Finally it squeezes
#     nms_ret[1] on axis 1 and nms_ret[0] on axis 0 and returns a strided_slice of
#     the latter with begin=const([0]), end=<squeezed nms_ret[1]>,
#     slice_mode="size".
#
# tests/python/frontend/pytorch/test_forward.py
#   * Adds test_forward_nms, which traces a module wrapping torchvision.ops.nms
#     for (num_boxes, iou_threshold) = (10, 0.3), (100, 0.5), (500, 0.9), and adds
#     a test_forward_nms() call next to the other test invocations.
#   * Moves the old verify_script_model body into the new verify_model_vm and adds
#     verify_trace_model (torch.jit.trace, then verify_model_vm with the same
#     input tensors passed as idata).
#
# NOTE(review):
#   * verify_model_vm computes the PyTorch baseline with the scripted/traced
#     module (input_model), where the removed code called the original pt_model.
#   * `idata if idata else ...` tests truthiness, so an empty idata also falls
#     back to random inputs; `is not None` may be what was intended - confirm.
#   * test_forward_nms calls torch.set_grad_enabled(False) and does not re-enable
#     it within the lines shown.
#   * The test class is spelled "NonMaxSupression" (single "p").
#   * `ish` in the input_names comprehension is unused.
#   * In _nms, `score_index = 0` is assigned only after get_valid_counts has
#     already been called with a literal score_index=0.
#   * The comment "Generate data with shape (1, num_anchors, 5)" sits above the
#     expand_dims on scores; the concatenate/expand_dims that follow presumably
#     produce that shape - TODO confirm.
#   * The line breaks of this diff appear to have been collapsed into spaces; it
#     presumably has to be regenerated before it can be applied - TODO confirm.
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 723740377cde5..21cf9c3a1b972 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -32,7 +32,7 @@ from ..ty import TupleType, TensorType, Any from ..loops import while_loop from .. import transform -from .common import get_relay_op +from .common import AttrCvt, get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value from .common import infer_value_simulated as _infer_value_simulated @@ -1811,6 +1811,53 @@ def _impl(inputs, input_types): return _op.meshgrid(data, indexing="ij") return _impl + +def _nms(prelude): + def _impl(inputs, input_types): + boxes = inputs[0] + scores = inputs[1] + iou_threshold = inputs[2] + + # Generate data with shape (1, num_anchors, 5) + scores = AttrCvt(op_name="expand_dims", + extras={'axis': -1, 'num_newaxis': 1})([scores], {}) + + # Prepare input data for get_valid_counts + data = _op.concatenate([scores, boxes], -1) + data = _op.expand_dims(data, 0, 1) + # Leverage get_valid_counts to sort the data and clear invalid boxes + ct, data, indices = get_relay_op('get_valid_counts')(data, + score_threshold=-1.0, + id_index=-1, + score_index=0) + + # Perform Non-Maximum Suppression, + # PyTorch NMS doesn't have parameter top_k and max_output_size + score_index = 0 + top_k = max_out_size = -1 + nms_ret = get_relay_op('non_max_suppression')(data=data, + valid_count=ct, + indices=indices, + max_output_size=max_out_size, + iou_threshold=iou_threshold, + force_suppress=True, + top_k=top_k, + coord_start=1, + score_index=score_index, + id_index=-1, + return_indices=True, + invalid_to_bottom=False) + + # squeeze the two outputs of nms for strided_slice + size = get_relay_op("squeeze")(nms_ret[1], axis=[1]) + data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0]) + + # strided slice to get the dynamic result + return get_relay_op("strided_slice")(data_slice, 
begin=_expr.const([0]), + end=size, slice_mode="size") + return _impl + + def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" import torch @@ -2111,6 +2158,7 @@ def _get_convert_map(prelude): "aten::gather" : _gather(), "aten::index_select" : _select(), "aten::index" : _index(), + "torchvision::nms" : _nms(prelude), } return convert_map diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index ab0a4b03cafa0..946712df50861 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1428,6 +1428,31 @@ def test_forward_upsample3d(): verify_model(torch.nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True).eval(), inp) +def test_forward_nms(): + """dynamic Non-Maximum Suppression""" + torch.set_grad_enabled(False) + class NonMaxSupression(Module): + def __init__(self, iou_thres): + super().__init__() + self.iou_threshold = iou_thres + + def forward(self, *args): + return torchvision.ops.nms(args[0], args[1], self.iou_threshold) + + # Generate random input data + def _gen_rand_inputs(num_boxes): + box_len = 4 + boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5 + boxes[:, 2] += boxes[:, 0] + boxes[:, 3] += boxes[:, 1] + scores = torch.rand(num_boxes, dtype=torch.float) + return boxes, scores + + for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]: + in_boxes, in_scores = _gen_rand_inputs(num_boxes) + verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores]) + + def test_conv3d(): for ishape in [(1, 32, 16, 16, 16), (1, 32, 9, 15, 15), @@ -1577,32 +1602,43 @@ def test_3d_models(): def verify_script_model(pt_model, ishapes): script_module = torch.jit.script(pt_model) + verify_model_vm(script_module, ishapes) - input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)] - input_shapes = list(zip(input_names, ishapes)) - inputs = [torch.randn(shape, 
dtype=torch.float) - for shape in ishapes] +def verify_trace_model(pt_model, idata): + traced_model = torch.jit.trace(pt_model, idata) + ishapes = [data.shape for data in idata] + verify_model_vm(traced_model, ishapes, idata=idata) - mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + +def verify_model_vm(imodel, ishapes, idtype=torch.float, idata=None): + input_model = imodel + input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)] + input_shapes = list(zip(input_names, ishapes)) + input_data = idata if idata else [torch.randn(shape, dtype=idtype) + for shape in ishapes] + # Compile via VM + mod, params = relay.frontend.from_pytorch(input_model, input_shapes) executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm") evaluator = executor.evaluate() - for name, inp in zip(input_names, inputs): + # Inference + for name, inp in zip(input_names, input_data): params[name] = inp.numpy() + vm_res = evaluator(**params) - op_res = evaluator(**params) - + # Baseline result with torch.no_grad(): - pt_result = pt_model(*inputs) + pt_result = input_model(*input_data) + # Verify the accuracy if not isinstance(pt_result, torch.Tensor): - tvm_res = op_res.asnumpy().item() + tvm_res = vm_res.asnumpy().item() assert pt_result == tvm_res else: - tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(), + tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(), rtol=1e-5, atol=1e-5) @@ -2863,6 +2899,7 @@ def test_forward_pretrained_bert_base_uncased(): test_forward_gather() test_upsample() test_forward_upsample3d() + test_forward_nms() test_to() test_type_as() test_forward_functional_pad()