Skip to content

Commit

Permalink
add verify_model_vm
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Aug 24, 2020
1 parent 2f41107 commit 7a95355
Showing 1 changed file with 20 additions and 44 deletions.
64 changes: 20 additions & 44 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,9 +1450,7 @@ def _gen_rand_inputs(num_boxes):

for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]:
in_boxes, in_scores = _gen_rand_inputs(num_boxes)
traced_model = torch.jit.trace(NonMaxSupression(iou_thres), [in_boxes, in_scores])
verify_trace_model(traced_model, [in_boxes.shape, in_scores.shape],
idata=[in_boxes, in_scores])
verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores])


def test_conv3d():
Expand Down Expand Up @@ -1604,65 +1602,43 @@ def test_3d_models():

def verify_script_model(pt_model, ishapes):
script_module = torch.jit.script(pt_model)
verify_model_vm(script_module, ishapes)

input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
input_shapes = list(zip(input_names, ishapes))

inputs = [torch.randn(shape, dtype=torch.float)
for shape in ishapes]

mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
def verify_trace_model(pt_model, idata):
traced_model = torch.jit.trace(pt_model, idata)
ishapes = [data.shape for data in idata]
verify_model_vm(traced_model, ishapes, idata=idata)

executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
target="llvm")
evaluator = executor.evaluate()

for name, inp in zip(input_names, inputs):
params[name] = inp.numpy()

op_res = evaluator(**params)

with torch.no_grad():
pt_result = pt_model(*inputs)

if not isinstance(pt_result, torch.Tensor):
tvm_res = op_res.asnumpy().item()
assert pt_result == tvm_res
else:
tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(),
rtol=1e-5, atol=1e-5)


def verify_trace_model(traced_model, ishapes, idata=None):
trace_module = traced_model

def verify_model_vm(imodel, ishapes, idtype=torch.float, idata=None):
input_model = imodel
input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
input_shapes = list(zip(input_names, ishapes))
if not idata:
inputs = [torch.randn(shape, dtype=torch.float)
for shape in ishapes]
else:
inputs = idata

mod, params = relay.frontend.from_pytorch(trace_module, input_shapes)
input_data = idata if idata else [torch.randn(shape, dtype=idtype)
for shape in ishapes]
# Compile via VM
mod, params = relay.frontend.from_pytorch(input_model, input_shapes)

executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
target="llvm")
evaluator = executor.evaluate()

for name, inp in zip(input_names, inputs):
# Inference
for name, inp in zip(input_names, input_data):
params[name] = inp.numpy()
vm_res = evaluator(**params)

op_res = evaluator(**params)

# Baseline result
with torch.no_grad():
pt_result = trace_module(*inputs)
pt_result = input_model(*input_data)

# Verify the accuracy
if not isinstance(pt_result, torch.Tensor):
tvm_res = op_res.asnumpy().item()
tvm_res = vm_res.asnumpy().item()
assert pt_result == tvm_res
else:
tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(),
tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(),
rtol=1e-5, atol=1e-5)


Expand Down

0 comments on commit 7a95355

Please sign in to comment.