From 3907c103c2911164df19f458bceef81f98b32841 Mon Sep 17 00:00:00 2001 From: Wang Date: Thu, 17 Jan 2019 22:16:24 -0800 Subject: [PATCH] Fix ctx_list --- nnvm/tests/python/compiler/test_top_level4.py | 82 ++++++++++--------- tutorials/nnvm/deploy_ssd.py | 64 ++++++++------- 2 files changed, 75 insertions(+), 71 deletions(-) diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index 38646b01a4c96..87620c8b3acf0 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -528,14 +528,13 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), if clip: np_out = np.clip(np_out, 0, 1) - target = "llvm" - ctx = tvm.cpu() - graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape}) - m = graph_runtime.create(graph, lib, ctx) - m.set_input("data", np.random.uniform(size=dshape).astype(dtype)) - m.run() - out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype)) - tvm.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5) + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape}) + m = graph_runtime.create(graph, lib, ctx) + m.set_input("data", np.random.uniform(size=dshape).astype(dtype)) + m.run() + out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype)) + tvm.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5) def test_multibox_prior(): verify_multibox_prior((1, 3, 50, 50)) @@ -562,17 +561,18 @@ def test_multibox_transform_loc(): [0, 0.44999999, 1, 1, 1, 1], [0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]]) - target = "llvm" dtype = "float32" - ctx = tvm.cpu() - graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes), - "loc_preds": (batch_size, num_anchors * 4), - "anchors": (1, num_anchors, 4)}) - m = graph_runtime.create(graph, lib, ctx) - m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)}) - m.run() - out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype)) - tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5) + for target, ctx in ctx_list(): + if target == "cuda": + continue + graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes), + "loc_preds": (batch_size, num_anchors * 4), + "anchors": (1, num_anchors, 4)}) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)}) + m.run() + out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype)) + tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5) def verify_get_valid_counts(dshape, score_threshold): dtype = "float32" @@ -594,19 +594,20 @@ def verify_get_valid_counts(dshape, score_threshold): for k in range(elem_length): np_out2[i, j, k] = -1 - target = "llvm" - ctx = tvm.cpu() - data = sym.Variable("data", dtype=dtype) - valid_counts, inter_data = sym.get_valid_counts(data, score_threshold=score_threshold) - out = sym.Group([valid_counts, inter_data]) - graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape}) - m = graph_runtime.create(graph, lib, ctx) - m.set_input("data", np_data) - m.run() - out1 = m.get_output(0, tvm.nd.empty(np_out1.shape, "int32")) - out2 = m.get_output(1, tvm.nd.empty(dshape, dtype)) - tvm.testing.assert_allclose(out1.asnumpy(), np_out1, rtol=1e-3) - tvm.testing.assert_allclose(out2.asnumpy(), np_out2, rtol=1e-3) + for target, ctx in ctx_list(): + if target == "cuda": + continue + data = sym.Variable("data", dtype=dtype) + valid_counts, inter_data = sym.get_valid_counts(data, score_threshold=score_threshold) + out = sym.Group([valid_counts, inter_data]) + graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape}) + m = graph_runtime.create(graph, lib, ctx) + m.set_input("data", np_data) + m.run() + out1 = m.get_output(0, tvm.nd.empty(np_out1.shape, "int32")) + out2 = m.get_output(1, tvm.nd.empty(dshape, dtype)) + tvm.testing.assert_allclose(out1.asnumpy(), np_out1, rtol=1e-3) + tvm.testing.assert_allclose(out2.asnumpy(), np_out2, rtol=1e-3) def test_get_valid_counts(): @@ -633,15 +634,16 @@ def test_nms(): [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) - target = "llvm" - ctx = tvm.cpu() - graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)}, - dtype={"data": "float32", "valid_count": "int32"}) - m = graph_runtime.create(graph, lib, ctx) - m.set_input(**{"data": np_data, "valid_count": np_valid_count}) - m.run() - out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32")) - tvm.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5) + for target, ctx in ctx_list(): + if target == "cuda": + continue + graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)}, + dtype={"data": "float32", "valid_count": "int32"}) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**{"data": np_data, "valid_count": np_valid_count}) + m.run() + out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32")) + tvm.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5) def np_slice_like(np_data, np_shape_like, axis=[]): begin_idx = [0 for _ in np_data.shape] diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py index a9a279f672a0f..d83d1f86b75e7 100644 --- a/tutorials/nnvm/deploy_ssd.py +++ b/tutorials/nnvm/deploy_ssd.py @@ -11,6 +11,7 @@ from matplotlib import pyplot as plt from nnvm import compiler from nnvm.frontend import from_mxnet +from nnvm.testing.config import ctx_list from tvm import relay from tvm.contrib import graph_runtime from gluoncv import model_zoo, data, utils @@ -52,8 +53,8 @@ model_name = "ssd_512_resnet50_v1_voc" dshape = (1, 3, 512, 512) dtype = "float32" -target = "llvm" -ctx = tvm.cpu() +target_list = ctx_list() +frontend_list = ["nnvm", "relay"] ###################################################################### # Download and pre-process demo image @@ -62,45 +63,46 @@ 'gluoncv/detection/street_small.jpg?raw=true', path='street_small.jpg') x, img = data.transforms.presets.ssd.load_test(im_fname, short=512) -tvm_input = tvm.nd.array(x.asnumpy(), ctx=ctx) ###################################################################### # Convert and compile model with NNVM or Relay for CPU. block = model_zoo.get_model(model_name, pretrained=True) -import argparse -parser = argparse.ArgumentParser() -parser.add_argument( - "-f", "--frontend", - help="Frontend for compilation, nnvm or relay", - type=str, - default="nnvm") -args = parser.parse_args() -if args.frontend == "relay": - net, params = relay.frontend.from_mxnet(block, {"data": dshape}) - with relay.build_config(opt_level=3): - graph, lib, params = relay.build(net, target, params=params) -elif args.frontend == "nnvm": - net, params = from_mxnet(block) - with compiler.build_config(opt_level=3): - graph, lib, params = compiler.build( - net, target, {"data": dshape}, params=params) -else: - parser.print_help() - parser.exit() +def compile(frontend, target): + if frontend == "relay": + net, params = relay.frontend.from_mxnet(block, {"data": dshape}) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(net, target, params=params) + else: + net, params = from_mxnet(block) + with compiler.build_config(opt_level=3): + graph, lib, params = compiler.build( + net, target, {"data": dshape}, params=params) + return graph, lib, params ###################################################################### # Create TVM runtime and do inference -# Build TVM runtime -m = graph_runtime.create(graph, lib, ctx) -m.set_input('data', tvm_input) -m.set_input(**params) -# execute -m.run() -# get outputs -class_IDs, scores, bounding_boxs = m.get_output(0), m.get_output(1), m.get_output(2) +def run(graph, lib, params, ctx): + # Build TVM runtime + m = graph_runtime.create(graph, lib, ctx) + tvm_input = tvm.nd.array(x.asnumpy(), ctx=ctx) + m.set_input('data', tvm_input) + m.set_input(**params) + # execute + m.run() + # get outputs + class_IDs, scores, bounding_boxs = m.get_output(0), m.get_output(1), m.get_output(2) + return class_IDs, scores, bounding_boxs + +for target, ctx in target_list: + if target == "cuda": + print("GPU not supported yet, skip.") + continue + for frontend in frontend_list: + graph, lib, params = compile(frontend, target) + class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx) ###################################################################### # Display result