diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index f57f7e35972a..21cf9c3a1b97 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1811,6 +1811,53 @@ def _impl(inputs, input_types):
         return _op.meshgrid(data, indexing="ij")
     return _impl
 
+
+def _nms(prelude):
+    def _impl(inputs, input_types):
+        boxes = inputs[0]
+        scores = inputs[1]
+        iou_threshold = inputs[2]
+
+        # Generate data with shape (1, num_anchors, 5)
+        scores = AttrCvt(op_name="expand_dims",
+                         extras={'axis': -1, 'num_newaxis': 1})([scores], {})
+
+        # Prepare input data for get_valid_counts
+        data = _op.concatenate([scores, boxes], -1)
+        data = _op.expand_dims(data, 0, 1)
+        # Leverage get_valid_counts to sort the data and clear invalid boxes
+        ct, data, indices = get_relay_op('get_valid_counts')(data,
+                                                             score_threshold=-1.0,
+                                                             id_index=-1,
+                                                             score_index=0)
+
+        # Perform Non-Maximum Suppression,
+        # PyTorch NMS doesn't have parameter top_k and max_output_size
+        score_index = 0
+        top_k = max_out_size = -1
+        nms_ret = get_relay_op('non_max_suppression')(data=data,
+                                                      valid_count=ct,
+                                                      indices=indices,
+                                                      max_output_size=max_out_size,
+                                                      iou_threshold=iou_threshold,
+                                                      force_suppress=True,
+                                                      top_k=top_k,
+                                                      coord_start=1,
+                                                      score_index=score_index,
+                                                      id_index=-1,
+                                                      return_indices=True,
+                                                      invalid_to_bottom=False)
+
+        # squeeze the two outputs of nms for strided_slice
+        size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
+        data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])
+
+        # strided slice to get the dynamic result
+        return get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
+                                             end=size, slice_mode="size")
+    return _impl
+
+
 def _pytorch_result_type(dtypes, non_tensor_inputs):
     """This promotes TVM dtypes like PyTorch would"""
     import torch
@@ -2116,52 +2163,6 @@ def _get_convert_map(prelude):
     return convert_map
 
 
-def _nms(prelude):
-    def _impl(inputs, input_types):
-        boxes = inputs[0]
-        scores = inputs[1]
-        iou_threshold = inputs[2]
-
-        # Generate data with shape (1, num_anchors, 5)
-        scores = AttrCvt(op_name="expand_dims",
-                         extras={'axis': -1, 'num_newaxis': 1})([scores], {})
-
-        # Prepare input data for get_valid_counts
-        data = _op.concatenate([scores, boxes], -1)
-        data = _op.expand_dims(data, 0, 1)
-        # Leverage get_valid_counts to sort the data and clear invalid boxes
-        ct, data, indices = get_relay_op('get_valid_counts')(data,
-                                                             score_threshold=-1.0,
-                                                             id_index=-1,
-                                                             score_index=0)
-
-        # Perform Non-Maximum Suppression,
-        # PyTorch NMS doesn't have parameter top_k and max_output_size
-        score_index = 0
-        top_k = max_out_size = -1
-        nms_ret = get_relay_op('non_max_suppression')(data=data,
-                                                      valid_count=ct,
-                                                      indices=indices,
-                                                      max_output_size=max_out_size,
-                                                      iou_threshold=iou_threshold,
-                                                      force_suppress=True,
-                                                      top_k=top_k,
-                                                      coord_start=1,
-                                                      score_index=score_index,
-                                                      id_index=-1,
-                                                      return_indices=True,
-                                                      invalid_to_bottom=False)
-
-        # squeeze the two outputs of nms for strided_slice
-        size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
-        data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])
-
-        # strided slice to get the dynamic result
-        return get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
-                                             end=size, slice_mode="size")
-    return _impl
-
-
 def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     import torch
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index ca3e944d1d60..ed3d2da11c76 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1431,19 +1431,13 @@ def test_forward_upsample3d():
 def test_forward_nms():
     """dynamic Non-Maximum Suppression"""
     torch.set_grad_enabled(False)
-    class NonMaxSupression1(Module):
-        def forward(self, *args):
-            return torchvision.ops.nms(args[0], args[1], 0.3)
-
-    class NonMaxSupression2(Module):
-        def forward(self, *args):
-            from torchvision.ops import nms
-            return torchvision.ops.nms(args[0], args[1], 0.5)
+    class NonMaxSupression(Module):
+        def __init__(self, iou_thres):
+            super().__init__()
+            self.iou_threshold = iou_thres
 
-    class NonMaxSupression3(Module):
         def forward(self, *args):
-            from torchvision.ops import nms
-            return torchvision.ops.nms(args[0], args[1], 0.9)
+            return torchvision.ops.nms(args[0], args[1], self.iou_threshold)
 
     # Generate random input data
     def _gen_rand_inputs(num_boxes):
@@ -1454,20 +1448,11 @@ def _gen_rand_inputs(num_boxes):
         scores = torch.rand(num_boxes, dtype=torch.float)
         return boxes, scores
 
-    in_boxes, in_scores = _gen_rand_inputs(10)
-    scripted_model1 = torch.jit.trace(NonMaxSupression1(), [in_boxes, in_scores])
-    verify_script_model(scripted_model1, [in_boxes.shape, in_scores.shape],
-                        idata=[in_boxes, in_scores])
-
-    in_boxes, in_scores = _gen_rand_inputs(100)
-    scripted_model2 = torch.jit.trace(NonMaxSupression2(), [in_boxes, in_scores])
-    verify_script_model(scripted_model2, [in_boxes.shape, in_scores.shape],
-                        idata=[in_boxes, in_scores])
-
-    in_boxes, in_scores = _gen_rand_inputs(500)
-    scripted_model3 = torch.jit.trace(NonMaxSupression3(), [in_boxes, in_scores])
-    verify_script_model(scripted_model3, [in_boxes.shape, in_scores.shape],
-                        idata=[in_boxes, in_scores])
+    for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]:
+        in_boxes, in_scores = _gen_rand_inputs(num_boxes)
+        traced_model = torch.jit.trace(NonMaxSupression(iou_thres), [in_boxes, in_scores])
+        verify_trace_model(traced_model, [in_boxes.shape, in_scores.shape],
+                           idata=[in_boxes, in_scores])
 
 
 def test_conv3d():
@@ -1617,30 +1602,61 @@ def test_3d_models():
     verify_model(resnet3d, [torch.rand(input_shape)], atol=1e-4, rtol=1e-4)
 
 
-def verify_script_model(pt_model, ishapes, idata=None):
+def verify_script_model(pt_model, ishapes):
     script_module = torch.jit.script(pt_model)
 
     input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
     input_shapes = list(zip(input_names, ishapes))
+    inputs = [torch.randn(shape, dtype=torch.float)
+              for shape in ishapes]
+
+    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+
+    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
+                                     target="llvm")
+    evaluator = executor.evaluate()
+
+    for name, inp in zip(input_names, inputs):
+        params[name] = inp.numpy()
+
+    op_res = evaluator(**params)
+
+    with torch.no_grad():
+        pt_result = pt_model(*inputs)
+
+    if not isinstance(pt_result, torch.Tensor):
+        tvm_res = op_res.asnumpy().item()
+        assert pt_result == tvm_res
+    else:
+        tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(),
+                                    rtol=1e-5, atol=1e-5)
+
+
+def verify_trace_model(traced_model, ishapes, idata=None):
+    trace_module = traced_model
+
+    input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
+    input_shapes = list(zip(input_names, ishapes))
     if not idata:
-        input_data = [torch.randn(shape, dtype=torch.float) for shape in ishapes]
+        inputs = [torch.randn(shape, dtype=torch.float)
+                  for shape in ishapes]
     else:
-        input_data = idata
+        inputs = idata
 
-    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+    mod, params = relay.frontend.from_pytorch(trace_module, input_shapes)
 
     executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
                                      target="llvm")
     evaluator = executor.evaluate()
 
-    for name, inp in zip(input_names, input_data):
+    for name, inp in zip(input_names, inputs):
         params[name] = inp.numpy()
 
     op_res = evaluator(**params)
 
     with torch.no_grad():
-        pt_result = pt_model(*input_data)
+        pt_result = trace_module(*inputs)
 
     if not isinstance(pt_result, torch.Tensor):
         tvm_res = op_res.asnumpy().item()
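For reference, here is a minimal usage sketch (not part of the patch) showing how the new torchvision.ops.nms conversion path is exercised end to end, following the same trace-then-run-on-the-Relay-VM flow that the verify_trace_model helper above uses. The wrapper class NMSWrapper, the input names "i0"/"i1", and the random box generation are illustrative only; the TVM and torchvision calls are the ones that already appear in the diff.

# Minimal sketch (assumes torch, torchvision, and tvm are installed; names are illustrative)
import torch
import torchvision
import tvm
from tvm import relay


class NMSWrapper(torch.nn.Module):
    """Hypothetical wrapper around torchvision.ops.nms, mirroring the test's module."""
    def __init__(self, iou_threshold):
        super().__init__()
        self.iou_threshold = iou_threshold

    def forward(self, boxes, scores):
        return torchvision.ops.nms(boxes, scores, self.iou_threshold)


# Random but well-formed boxes: (x1, y1) plus a positive offset for (x2, y2)
boxes = torch.rand(100, 4) * 100
boxes[:, 2:] += boxes[:, :2]
scores = torch.rand(100)

# Trace the module (as the updated test does), convert it, and run it on the VM,
# which can handle the dynamic number of boxes NMS returns.
traced = torch.jit.trace(NMSWrapper(0.5), [boxes, scores])
input_shapes = [("i0", boxes.shape), ("i1", scores.shape)]
mod, params = relay.frontend.from_pytorch(traced, input_shapes)

executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
evaluator = executor.evaluate()
params["i0"] = boxes.numpy()
params["i1"] = scores.numpy()
tvm_indices = evaluator(**params).asnumpy()

# The updated test compares the converted result against PyTorch's own indices
torch_indices = torchvision.ops.nms(boxes, scores, 0.5).numpy()
tvm.testing.assert_allclose(tvm_indices, torch_indices, rtol=1e-5, atol=1e-5)

The VM executor is used here (rather than the graph runtime) because the converter's strided_slice with slice_mode="size" produces an output whose length depends on how many boxes survive suppression, so the output shape is only known at run time.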