Skip to content

Commit

Permalink
fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Aug 24, 2020
1 parent 9e9c71f commit 2f41107
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 77 deletions.
93 changes: 47 additions & 46 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,6 +1811,53 @@ def _impl(inputs, input_types):
return _op.meshgrid(data, indexing="ij")
return _impl


def _nms(prelude):
def _impl(inputs, input_types):
boxes = inputs[0]
scores = inputs[1]
iou_threshold = inputs[2]

# Generate data with shape (1, num_anchors, 5)
scores = AttrCvt(op_name="expand_dims",
extras={'axis': -1, 'num_newaxis': 1})([scores], {})

# Prepare input data for get_valid_counts
data = _op.concatenate([scores, boxes], -1)
data = _op.expand_dims(data, 0, 1)
# Leverage get_valid_counts to sort the data and clear invalid boxes
ct, data, indices = get_relay_op('get_valid_counts')(data,
score_threshold=-1.0,
id_index=-1,
score_index=0)

# Perform Non-Maximum Suppression,
# PyTorch NMS doesn't have parameter top_k and max_output_size
score_index = 0
top_k = max_out_size = -1
nms_ret = get_relay_op('non_max_suppression')(data=data,
valid_count=ct,
indices=indices,
max_output_size=max_out_size,
iou_threshold=iou_threshold,
force_suppress=True,
top_k=top_k,
coord_start=1,
score_index=score_index,
id_index=-1,
return_indices=True,
invalid_to_bottom=False)

# squeeze the two outputs of nms for strided_slice
size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])

# strided slice to get the dynamic result
return get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
end=size, slice_mode="size")
return _impl


def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
import torch
Expand Down Expand Up @@ -2116,52 +2163,6 @@ def _get_convert_map(prelude):
return convert_map


def _nms(prelude):
def _impl(inputs, input_types):
boxes = inputs[0]
scores = inputs[1]
iou_threshold = inputs[2]

# Generate data with shape (1, num_anchors, 5)
scores = AttrCvt(op_name="expand_dims",
extras={'axis': -1, 'num_newaxis': 1})([scores], {})

# Prepare input data for get_valid_counts
data = _op.concatenate([scores, boxes], -1)
data = _op.expand_dims(data, 0, 1)
# Leverage get_valid_counts to sort the data and clear invalid boxes
ct, data, indices = get_relay_op('get_valid_counts')(data,
score_threshold=-1.0,
id_index=-1,
score_index=0)

# Perform Non-Maximum Suppression,
# PyTorch NMS doesn't have parameter top_k and max_output_size
score_index = 0
top_k = max_out_size = -1
nms_ret = get_relay_op('non_max_suppression')(data=data,
valid_count=ct,
indices=indices,
max_output_size=max_out_size,
iou_threshold=iou_threshold,
force_suppress=True,
top_k=top_k,
coord_start=1,
score_index=score_index,
id_index=-1,
return_indices=True,
invalid_to_bottom=False)

# squeeze the two outputs of nms for strided_slice
size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])

# strided slice to get the dynamic result
return get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
end=size, slice_mode="size")
return _impl


def _run_jit_passes(graph):
""" The inline pass is necessary to unwrap prim::CallMethod """
import torch
Expand Down
78 changes: 47 additions & 31 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,19 +1431,13 @@ def test_forward_upsample3d():
def test_forward_nms():
"""dynamic Non-Maximum Suppression"""
torch.set_grad_enabled(False)
class NonMaxSupression1(Module):
def forward(self, *args):
return torchvision.ops.nms(args[0], args[1], 0.3)

class NonMaxSupression2(Module):
def forward(self, *args):
from torchvision.ops import nms
return torchvision.ops.nms(args[0], args[1], 0.5)
class NonMaxSupression(Module):
def __init__(self, iou_thres):
super().__init__()
self.iou_threshold = iou_thres

class NonMaxSupression3(Module):
def forward(self, *args):
from torchvision.ops import nms
return torchvision.ops.nms(args[0], args[1], 0.9)
return torchvision.ops.nms(args[0], args[1], self.iou_threshold)

# Generate random input data
def _gen_rand_inputs(num_boxes):
Expand All @@ -1454,20 +1448,11 @@ def _gen_rand_inputs(num_boxes):
scores = torch.rand(num_boxes, dtype=torch.float)
return boxes, scores

in_boxes, in_scores = _gen_rand_inputs(10)
scripted_model1 = torch.jit.trace(NonMaxSupression1(), [in_boxes, in_scores])
verify_script_model(scripted_model1, [in_boxes.shape, in_scores.shape],
idata=[in_boxes, in_scores])

in_boxes, in_scores = _gen_rand_inputs(100)
scripted_model2 = torch.jit.trace(NonMaxSupression2(), [in_boxes, in_scores])
verify_script_model(scripted_model2, [in_boxes.shape, in_scores.shape],
idata=[in_boxes, in_scores])

in_boxes, in_scores = _gen_rand_inputs(500)
scripted_model3 = torch.jit.trace(NonMaxSupression3(), [in_boxes, in_scores])
verify_script_model(scripted_model3, [in_boxes.shape, in_scores.shape],
idata=[in_boxes, in_scores])
for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]:
in_boxes, in_scores = _gen_rand_inputs(num_boxes)
traced_model = torch.jit.trace(NonMaxSupression(iou_thres), [in_boxes, in_scores])
verify_trace_model(traced_model, [in_boxes.shape, in_scores.shape],
idata=[in_boxes, in_scores])


def test_conv3d():
Expand Down Expand Up @@ -1617,30 +1602,61 @@ def test_3d_models():
verify_model(resnet3d, [torch.rand(input_shape)], atol=1e-4, rtol=1e-4)


def verify_script_model(pt_model, ishapes, idata=None):
def verify_script_model(pt_model, ishapes):
script_module = torch.jit.script(pt_model)

input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
input_shapes = list(zip(input_names, ishapes))

inputs = [torch.randn(shape, dtype=torch.float)
for shape in ishapes]

mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
target="llvm")
evaluator = executor.evaluate()

for name, inp in zip(input_names, inputs):
params[name] = inp.numpy()

op_res = evaluator(**params)

with torch.no_grad():
pt_result = pt_model(*inputs)

if not isinstance(pt_result, torch.Tensor):
tvm_res = op_res.asnumpy().item()
assert pt_result == tvm_res
else:
tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(),
rtol=1e-5, atol=1e-5)


def verify_trace_model(traced_model, ishapes, idata=None):
trace_module = traced_model

input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
input_shapes = list(zip(input_names, ishapes))
if not idata:
input_data = [torch.randn(shape, dtype=torch.float) for shape in ishapes]
inputs = [torch.randn(shape, dtype=torch.float)
for shape in ishapes]
else:
input_data = idata
inputs = idata

mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
mod, params = relay.frontend.from_pytorch(trace_module, input_shapes)

executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
target="llvm")
evaluator = executor.evaluate()

for name, inp in zip(input_names, input_data):
for name, inp in zip(input_names, inputs):
params[name] = inp.numpy()

op_res = evaluator(**params)

with torch.no_grad():
pt_result = pt_model(*input_data)
pt_result = trace_module(*inputs)

if not isinstance(pt_result, torch.Tensor):
tvm_res = op_res.asnumpy().item()
Expand Down

0 comments on commit 2f41107

Please sign in to comment.