diff --git a/nnvm/python/nnvm/compiler/__init__.py b/nnvm/python/nnvm/compiler/__init__.py
index 993d0ed7c8a2..1d8b9219b4f7 100644
--- a/nnvm/python/nnvm/compiler/__init__.py
+++ b/nnvm/python/nnvm/compiler/__init__.py
@@ -4,7 +4,7 @@
 import tvm

 from . import build_module
-from . build_module import build, precompute_prune, _run_graph
+from . build_module import build, optimize, build_config

 from .. import symbol as _symbol
 from .. import graph as _graph
diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py
index 3b053e2fd806..891a0b65729d 100644
--- a/nnvm/python/nnvm/compiler/build_module.py
+++ b/nnvm/python/nnvm/compiler/build_module.py
@@ -3,10 +3,74 @@
 from __future__ import absolute_import as _abs

 import tvm
-from . import graph_attr, graph_pass
+from . import graph_attr, graph_util
 from .. import graph as _graph
 from .. import runtime

+# Map from optimization pass name to the minimum opt_level that enables it.
+OPT_PASS_LEVEL = {
+    "SimplifyBatchNormInference": 2,
+    "PrecomputePrune": 2,
+    "OpFusion": 1
+}
+
+class BuildConfig(object):
+    """Configuration scope to set a build config option.
+
+    Parameters
+    ----------
+    kwargs
+        Keyword arguments of configurations to set.
+    """
+    current = None
+    defaults = {
+        "opt_level": 2,
+    }
+    def __init__(self, **kwargs):
+        self._old_scope = None
+        for k, _ in kwargs.items():
+            if k not in BuildConfig.defaults:
+                raise ValueError(
+                    "invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys()))
+        self._attr = kwargs
+
+    def __getattr__(self, name):
+        if name not in self._attr:
+            return BuildConfig.defaults[name]
+        return self._attr[name]
+
+    def __enter__(self):
+        # pylint: disable=protected-access
+        self._old_scope = BuildConfig.current
+        attr = BuildConfig.current._attr.copy()
+        attr.update(self._attr)
+        self._attr = attr
+        BuildConfig.current = self
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        assert self._old_scope
+        BuildConfig.current = self._old_scope
+
+
+BuildConfig.current = BuildConfig()
+
+def build_config(**kwargs):
+    """Configure the build behavior by setting config variables.
+
+    Parameters
+    ----------
+    opt_level: int, default=2
+        Optimization level. See OPT_PASS_LEVEL for the level of each pass.
+
+    Returns
+    -------
+    config: BuildConfig
+        The build configuration
+    """
+    return BuildConfig(**kwargs)
+
+
 @tvm.register_func("nnvm.compiler.lower")
 def _lower(sch, inputs, func_name):
     f = tvm.lower(sch, inputs, name=func_name)
@@ -19,23 +83,45 @@ def _build(funcs, target):
     return tvm.build(funcs, target=target)


-def optimize(graph):
-    """Perform graph optimization
+def _update_shape_dtype(shape, dtype, params):
+    """Update shape and dtype given the params information."""
+    if not params:
+        return shape, dtype
+    shape = shape.copy()
+    shape.update({k : v.shape for k, v in params.items()})
+    if isinstance(dtype, str):
+        for k, v in params.items():
+            if v.dtype != dtype:
+                raise ValueError(
+                    "%s: dtype not expected %s vs %s" % (k, dtype, v.dtype))
+    else:
+        dtype = dtype.copy()
+        dtype.update({k : str(v.dtype) for k, v in params.items()})
+    return shape, dtype
+
+
+def optimize(graph, shape, dtype="float32"):
+    """Perform target- and parameter-invariant graph optimization.

     Parameters
     ----------
     graph : Graph
-        The graph to be used in lowering.
+        The graph to be optimized.

     Returns
     -------
     graph : Graph
-        The optimized execution graph.
+        The optimized graph.
     """
+    # pylint: disable=unused-argument
+    cfg = BuildConfig.current
+    if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
+        graph = graph_attr.set_shape_inputs(graph, shape)
+        graph = graph.apply(["InferShape", "SimplifyBatchNormInference"])
     return graph


-def build(graph, target, shape, dtype="float32"):
+def build(graph, target, shape, dtype="float32", params=None):
     """Build graph into runtime library.

     This is the final step of graph compilation.
@@ -54,6 +140,11 @@ def build(graph, target, shape, dtype="float32"):
     dtype : str or dict of str to str
         The input types to the graph

+    params : dict of str to NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for the pre-compute
+        folding optimization.
+
     Returns
     -------
     graph : Graph
@@ -61,20 +152,33 @@ def build(graph, target, shape, dtype="float32"):

     libmod : tvm.Module
         The module that comes with the execution graph
+
+    params : dict of str to NDArray
+        The updated parameters of the graph if params is passed.
+        These can be different from the params passed in.
     """
     if not isinstance(target, str):
         raise TypeError("require target to be str")
     if not isinstance(shape, dict):
         raise TypeError("require shape to be dict")
-
+    cfg = BuildConfig.current
     graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
+    shape, dtype = _update_shape_dtype(shape, dtype, params)
+    # Apply optimization
+    graph = optimize(graph, shape, dtype)
+    # Precompute prune
+    if params and cfg.opt_level >= OPT_PASS_LEVEL["PrecomputePrune"]:
+        graph, params = precompute_prune(graph, params)
+        shape, dtype = _update_shape_dtype(shape, dtype, params)
+    # Operator fusion and code generation
     graph = graph_attr.set_shape_inputs(graph, shape)
     graph = graph_attr.set_dtype_inputs(graph, dtype)
     graph._set_json_attr("target", target, "str")
+    graph._set_json_attr("opt_level", cfg.opt_level, "int")
     graph = graph.apply("InferShape").apply("InferType")
     graph = graph.apply("GraphFusePartition").apply("GraphFuse")
     libmod = graph_attr._move_out_module(graph, "module")
-    return graph, libmod
+    return graph, libmod, params


 def _run_graph(graph, params):
@@ -98,9 +202,9 @@ def _run_graph(graph, params):
     dtype = {k : v.dtype for k, v in params.items()}
     target = "llvm"
     ctx = tvm.cpu(0)
-    _, oshape = graph_pass.infer_shape(graph, **shape)
-    _, odtype = graph_pass.infer_dtype(graph, **dtype)
-    graph, libmod = build(graph, target, shape, dtype)
+    _, oshape = graph_util.infer_shape(graph, **shape)
+    _, odtype = graph_util.infer_dtype(graph, **dtype)
+    graph, libmod, _ = build(graph, target, shape, dtype)
     m = runtime.create(graph, libmod, ctx)
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
     for k, v in params.items():
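For orientation, here is a minimal sketch of how the reworked API fits together (it mirrors test_build.py further down; the symbol and shapes are illustrative only). build now returns a third value carrying the possibly-updated params, and build_config scopes the optimization level:

import nnvm.symbol as sym
import nnvm.compiler

x = sym.Variable("x")
y = sym.Variable("y")
z = sym.exp(x + y)
shape_dict = {"x": (10, 128), "y": (10, 128)}

# The default opt_level is 2; build_config temporarily lowers it to 0,
# which disables operator fusion and leaves more nodes in the graph.
with nnvm.compiler.build_config(opt_level=0):
    graph, lib, params = nnvm.compiler.build(z, "llvm", shape_dict)
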
""" + # pylint: disable=unused-argument + cfg = BuildConfig.current + if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]: + graph = graph_attr.set_shape_inputs(graph, shape) + graph = graph.apply(["InferShape", "SimplifyBatchNormInference"]) return graph -def build(graph, target, shape, dtype="float32"): +def build(graph, target, shape, dtype="float32", params=None): """Build graph into runtime library. This is the final step of graph compilation. @@ -54,6 +140,11 @@ def build(graph, target, shape, dtype="float32"): dtype : str or dict of str to str The input types to the graph + params : dict of str to NDArray + Input parameetrs to the graph that do not change + during inference time. Used for pre-compute + folding optimization. + Returns ------- graph : Graph @@ -61,20 +152,33 @@ def build(graph, target, shape, dtype="float32"): libmod : tvm.Module The modue that comes with the execution graph + + params : dict of str to NDArray + The updated parameters of graph if params is passed. + This can be different from the params passed in. """ if not isinstance(target, str): raise TypeError("require target to be str") if not isinstance(shape, dict): raise TypeError("require shape to be dict") - + cfg = BuildConfig.current graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph) + shape, dtype = _update_shape_dtype(shape, dtype, params) + # Apply optimization + graph = optimize(graph, shape, dtype) + # Precompute prune + if params and cfg.opt_level >= OPT_PASS_LEVEL["PrecomputePrune"]: + graph, params = precompute_prune(graph, params) + shape, dtype = _update_shape_dtype(shape, dtype, params) + # Operator Fusion and generatiom graph = graph_attr.set_shape_inputs(graph, shape) graph = graph_attr.set_dtype_inputs(graph, dtype) graph._set_json_attr("target", target, "str") + graph._set_json_attr("opt_level", cfg.opt_level, "int") graph = graph.apply("InferShape").apply("InferType") graph = graph.apply("GraphFusePartition").apply("GraphFuse") libmod = graph_attr._move_out_module(graph, "module") - return graph, libmod + return graph, libmod, params def _run_graph(graph, params): @@ -98,9 +202,9 @@ def _run_graph(graph, params): dtype = {k : v.dtype for k, v in params.items()} target = "llvm" ctx = tvm.cpu(0) - _, oshape = graph_pass.infer_shape(graph, **shape) - _, odtype = graph_pass.infer_dtype(graph, **dtype) - graph, libmod = build(graph, target, shape, dtype) + _, oshape = graph_util.infer_shape(graph, **shape) + _, odtype = graph_util.infer_dtype(graph, **dtype) + graph, libmod, _ = build(graph, target, shape, dtype) m = runtime.create(graph, libmod, ctx) set_input, run, get_output = m["set_input"], m["run"], m["get_output"] for k, v in params.items(): diff --git a/nnvm/python/nnvm/compiler/graph_pass.py b/nnvm/python/nnvm/compiler/graph_pass.py index 3e98615d8ff2..a37e83a2c5c0 100644 --- a/nnvm/python/nnvm/compiler/graph_pass.py +++ b/nnvm/python/nnvm/compiler/graph_pass.py @@ -6,81 +6,3 @@ - Composable API: break graph transformation pass as segments of small transformations. """ from __future__ import absolute_import as _abs - -import tvm -from . import graph_attr - - -def infer_shape(graph, **shape): - """Infer the shape given the shape of inputs. 
diff --git a/nnvm/python/nnvm/compiler/graph_util.py b/nnvm/python/nnvm/compiler/graph_util.py
new file mode 100644
index 000000000000..fcca00b0abe0
--- /dev/null
+++ b/nnvm/python/nnvm/compiler/graph_util.py
@@ -0,0 +1,80 @@
+# pylint: disable=invalid-name
+"""Utility functions to get information from graph."""
+from __future__ import absolute_import as _abs
+
+import tvm
+from . import graph_attr
+
+def infer_shape(graph, **shape):
+    """Infer the shape given the shapes of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform shape inference on
+
+    Returns
+    -------
+    in_shape : list of tuple
+        Shape of inputs
+
+    out_shape : list of tuple
+        Shape of outputs
+    """
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph.apply("InferShape")
+    shape = graph.json_attr("shape")
+    index = graph.index
+    input_shape = [shape[index.entry_id(x)] for x in index.input_names]
+    output_shape = [shape[index.entry_id(x)] for x in index.output_entries]
+    return input_shape, output_shape
+
+
+def infer_dtype(graph, **dtype):
+    """Infer the dtype given the dtypes of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform type inference on
+
+    Returns
+    -------
+    in_dtype : list of str
+        Dtype of inputs
+
+    out_dtype : list of str
+        Dtype of outputs
+    """
+    graph = graph_attr.set_dtype_inputs(graph, dtype)
+    graph = graph.apply("InferType")
+    dtype = graph.json_attr("dtype")
+    index = graph.index
+    input_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                   for x in index.input_names]
+    output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                    for x in index.output_entries]
+    return input_dtype, output_dtype
+
+
+_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
+
+def check_graph_equal(grapha, graphb):
+    """Check if two graphs have equal structure.
+
+    Parameters
+    ----------
+    grapha : Graph
+        The first graph
+
+    graphb : Graph
+        The second graph
+
+    Raises
+    ------
+    ValueError
+        ValueError is raised with an error message when the graphs are not equal.
+    """
+    err = _deep_compare(grapha, graphb)
+    if err:
+        raise ValueError("Graph compare error: " + err)
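The helpers keep their old behavior under the new module name. A quick usage sketch, matching test_graph_pass.py below:

import nnvm
from nnvm import symbol as sym
from nnvm.compiler import graph_util

x = sym.Variable("x")
g = nnvm.graph.create(x * 2)
ishape, oshape = graph_util.infer_shape(g, x=(10, 20))  # oshape[0] is (10, 20)
itype, otype = graph_util.infer_dtype(g, x="float32")   # otype[0] is "float32"
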
diff --git a/nnvm/python/nnvm/runtime.py b/nnvm/python/nnvm/runtime.py
index 733514db71af..dd9866cb4b3a 100644
--- a/nnvm/python/nnvm/runtime.py
+++ b/nnvm/python/nnvm/runtime.py
@@ -2,6 +2,82 @@
 import tvm
 from tvm.contrib import rpc

+class Module(object):
+    """Wrapper runtime module.
+
+    This is a thin wrapper around the underlying TVM module.
+    You can also directly call set_input, run, and get_output
+    of the underlying module functions.
+
+    Parameters
+    ----------
+    tvm_module : tvm.Module
+        The internal TVM module
+    """
+    def __init__(self, tvm_module):
+        self.tvm_module = tvm_module
+        self._set_input = tvm_module["set_input"]
+        self._run = tvm_module["run"]
+        self._get_output = tvm_module["get_output"]
+
+    def set_input(self, key=None, value=None, **params):
+        """Set inputs to the module via kwargs
+
+        Parameters
+        ----------
+        key : int or str
+            The input key
+
+        value : NDArray or numpy array
+            The input value
+
+        params : dict of str to NDArray
+            Additional named inputs
+        """
+        if key:
+            self._set_input(key, tvm.nd.array(value))
+        for k, v in params.items():
+            self._set_input(k, tvm.nd.array(v))
+        return self
+
+    def run(self, **input_dict):
+        """Run forward execution of the graph
+
+        Parameters
+        ----------
+        input_dict : dict of str to NDArray
+            Input values to be fed to the graph
+        """
+        if input_dict:
+            self.set_input(**input_dict)
+        self._run()
+
+    def get_output(self, index, out):
+        """Get the index-th output into out
+
+        Parameters
+        ----------
+        index : int
+            The output index
+
+        out : tvm.NDArray
+            The output array container
+        """
+        self._get_output(index, out)
+        return out
+
+    def __getitem__(self, key):
+        """Get an internal module function
+
+        Parameters
+        ----------
+        key : str
+            The key to the module.
+        """
+        return self.tvm_module[key]
+
+

 def create(graph, libmod, ctx):
     """Create a runtime executor module given the graph and module.
@@ -30,7 +106,6 @@ def create(graph, libmod, ctx):
         hmod = rpc._ModuleHandle(libmod)
         fcreate = ctx._rpc_sess.get_function("nnvm.runtime.remote_create")
        device_type = device_type % rpc.RPC_SESS_MASK
-        return fcreate(json_str, hmod, device_type, device_id)
-
+        return Module(fcreate(json_str, hmod, device_type, device_id))
     fcreate = tvm.get_global_func("nnvm.runtime.create")
-    return fcreate(json_str, libmod, device_type, device_id)
+    return Module(fcreate(json_str, libmod, device_type, device_id))
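A short sketch of the wrapper in use, assuming graph and lib come from nnvm.compiler.build as above, and nx, ny, oshape are an input NDArray pair and the output shape. Keyword arguments to run are forwarded to set_input before execution:

m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
m.set_input("x", nx)   # set one input by key
m.run(y=ny)            # remaining inputs can be fed as kwargs
out = m.get_output(0, tvm.nd.empty(oshape))
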
diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
index 9496e110ba64..acf0f5677187 100644
--- a/nnvm/src/compiler/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -41,6 +41,11 @@ DLDataType GetDLType(int type_flag) {
 nnvm::Graph GraphFusePartition(nnvm::Graph g) {
   // setup ref counter
   const IndexedGraph& idx = g.indexed_graph();
+  int opt_level = 2;
+  if (g.attrs.count("opt_level") != 0) {
+    opt_level = g.MoveCopyAttr<int>("opt_level");
+  }
+
   // Get attributes from the graph
   const ShapeVector& shape_vec = g.GetAttr<ShapeVector>("shape");
   const DTypeVector& dtype_vec = g.GetAttr<DTypeVector>("dtype");
@@ -65,7 +70,7 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
     // this line will realize all the outputs
     ref_count[e.node_id] += 2;
   }
-  // Pattern fo the subgraph
+  // Pattern for the subgraph
   std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kExtern);
   // Whether node can be fused to parent.
   std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
@@ -123,7 +128,7 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
     }
     pattern_vec[nid] = pt;
-    if (ref_count[nid] > 1) {
+    if (ref_count[nid] > 1 || opt_level < 1) {
       fuse_vec[nid] = FuseRule::kRealize;
       if (master_vec[nid] == -1) {
         master_vec[nid] = nid;
diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc
index 5444ece9c110..e67f74312605 100644
--- a/nnvm/src/compiler/packed_func_ext.cc
+++ b/nnvm/src/compiler/packed_func_ext.cc
@@ -20,6 +20,10 @@ TVM_REGISTER_EXT_TYPE(nnvm::compiler::AttrDict);
 }  // namespace runtime
 }  // namespace tvm

+namespace nnvm {
+DMLC_JSON_ENABLE_ANY(int, int);
+}  // namespace nnvm
+
 namespace nnvm {
 namespace compiler {
diff --git a/nnvm/tests/python/compiler/test_build.py b/nnvm/tests/python/compiler/test_build.py
index 3abe94913a8f..bb10e8400bb1 100644
--- a/nnvm/tests/python/compiler/test_build.py
+++ b/nnvm/tests/python/compiler/test_build.py
@@ -4,6 +4,7 @@
 import nnvm.symbol as sym
 import nnvm.compiler
 import nnvm.runtime
+from nnvm.compiler.build_module import _run_graph, precompute_prune

 def test_compile():
     x = sym.Variable("x")
@@ -12,23 +13,34 @@ def test_compile():
     shape = (10, 128)
     dtype = tvm.float32
     shape_dict = {"x": shape, "y": shape}
+    def verify(graph, lib):
+        m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
+        # get member functions
+        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+        na = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+        nb = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+        # set inputs
+        set_input("x", na)
+        set_input("y", nb)
+        # execute
+        run()
+        # get outputs
+        out = tvm.nd.empty(shape, dtype)
+        get_output(0, out)
+        np.testing.assert_allclose(
+            out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))
+
+    graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)
+    assert graph.index.num_nodes == 3
+    verify(graph, lib)
+
+    with nnvm.compiler.build_config(opt_level=0):
+        graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)
+    # print(graph.ir())
+    assert graph.index.num_nodes == 4
+    verify(graph, lib)
+
-    graph, lib = nnvm.compiler.build(z, "llvm", shape_dict)
-    m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
-    # get member functions
-    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    na = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
-    nb = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
-    # set inputs
-    set_input("x", na)
-    set_input("y", nb)
-    # execute
-    run()
-    # get outputs
-    out = tvm.nd.empty(shape, dtype)
-    get_output(0, out)
-    np.testing.assert_allclose(
-        out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))


 def test_run():
@@ -39,7 +51,7 @@ def test_run():
     dtype = tvm.float32
     nx = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
     ny = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
-    res = nnvm.compiler._run_graph(z, {"x": nx, "y": ny})
+    res = _run_graph(z, {"x": nx, "y": ny})
     np.testing.assert_allclose(
         res[0].asnumpy(), np.exp(nx.asnumpy() + ny.asnumpy()))
@@ -53,11 +65,16 @@ def test_precompute_prune():
     nx = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
     ny = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
     params = {"x": nx}
-    graph, pdict = nnvm.compiler.precompute_prune(z, params)
-    pdict["y"] = ny
-    res = nnvm.compiler._run_graph(z, pdict)
+    graph, lib, params = nnvm.compiler.build(
+        z, "llvm", shape={"y": ny.shape}, params=params)
+    assert graph.index.num_nodes == 3
+    m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
+    params["y"] = ny
+    res = tvm.nd.empty(shape)
+    m.run(**params)
+    out = m.get_output(0, out=res)
     np.testing.assert_allclose(
-        res[0].asnumpy(), nx.asnumpy() + 1 + ny.asnumpy())
+        res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy())


 if __name__ == "__main__":
diff --git a/nnvm/tests/python/compiler/test_graph_pass.py b/nnvm/tests/python/compiler/test_graph_pass.py
index 9f2c36a1d3f9..ec5ab0479389 100644
--- a/nnvm/tests/python/compiler/test_graph_pass.py
+++ b/nnvm/tests/python/compiler/test_graph_pass.py
@@ -2,16 +2,16 @@
 import nnvm
 import nnvm.compiler
 from nnvm import symbol as sym
-from nnvm.compiler import graph_pass, graph_attr
+from nnvm.compiler import graph_util, graph_attr

 def test_infer_attr():
     x = sym.Variable("x")
     y = x * 2
     g = nnvm.graph.create(y)
-    ishape, oshape = graph_pass.infer_shape(g, x=(10,20))
+    ishape, oshape = graph_util.infer_shape(g, x=(10,20))
     assert tuple(oshape[0]) == (10, 20)

-    itype, otype = graph_pass.infer_dtype(g, x="float32")
+    itype, otype = graph_util.infer_dtype(g, x="float32")
     assert otype[0] == "float32"

 if __name__ == "__main__":
diff --git a/nnvm/tests/python/compiler/test_rpc_exec.py b/nnvm/tests/python/compiler/test_rpc_exec.py
index 89f67d5945e6..4a94a1d7686e 100644
--- a/nnvm/tests/python/compiler/test_rpc_exec.py
+++ b/nnvm/tests/python/compiler/test_rpc_exec.py
@@ -19,7 +19,7 @@ def test_rpc_executor():

     tmp = util.tempdir()
     lib_name = tmp.relpath("net.o")
-    graph, lib = nnvm.compiler.build(z, "llvm", shape_dict)
+    graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)
     # save module
     lib.save(lib_name)
     remote = rpc.connect(host, port)
diff --git a/nnvm/tests/python/compiler/test_simplify_batchnorm.py b/nnvm/tests/python/compiler/test_simplify_batchnorm.py
index 67d582769ef2..ec6dfb86ac47 100644
--- a/nnvm/tests/python/compiler/test_simplify_batchnorm.py
+++ b/nnvm/tests/python/compiler/test_simplify_batchnorm.py
@@ -1,7 +1,7 @@
 """Unittest cases for simplify batch_norm"""
 import nnvm
 from nnvm import symbol as sym
-from nnvm.compiler import graph_pass, graph_attr
+from nnvm.compiler import graph_util, graph_attr

 def test_simplify_batchnorm():
     def simple_bn(x, gamma, beta, moving_mean, moving_var,
@@ -40,7 +40,7 @@ def check(dim, axis, nstep):
         # Some prints for debug
         # print(g1.ir())
         # assert graph equals as expected
-        graph_pass.check_graph_equal(g1, g2)
+        graph_util.check_graph_equal(g1, g2)

     check(2, 1, 1)
     check(4, 0, 3)
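To make the gating explicit: a pass listed in OPT_PASS_LEVEL runs only when the active opt_level reaches its threshold, both in the Python passes above and in the C++ fusion check (ref_count[nid] > 1 || opt_level < 1). A sketch, not part of the diff; pass_enabled is a hypothetical helper:

from nnvm.compiler.build_module import OPT_PASS_LEVEL, BuildConfig

def pass_enabled(name):
    # Hypothetical helper: a pass applies when the current
    # opt_level reaches its threshold in OPT_PASS_LEVEL.
    return BuildConfig.current.opt_level >= OPT_PASS_LEVEL[name]

assert pass_enabled("OpFusion")                    # threshold 1, default level 2
assert pass_enabled("SimplifyBatchNormInference")  # threshold 2, default level 2
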
diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py
index 88aeb3459886..14108d18db0e 100644
--- a/nnvm/tests/python/compiler/test_top_level1.py
+++ b/nnvm/tests/python/compiler/test_top_level1.py
@@ -26,7 +26,7 @@ def test_relu():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
+    graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
     m = nnvm.runtime.create(graph, lib, default_ctx())
     # get member functions
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
@@ -48,7 +48,7 @@ def test_exp():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
+    graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
     m = nnvm.runtime.create(graph, lib, default_ctx())
     # get member functions
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
@@ -70,7 +70,8 @@ def test_log():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
+    with nnvm.compiler.build_config(opt_level=1):
+        graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
     m = nnvm.runtime.create(graph, lib, default_ctx())
     # get member functions
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
@@ -92,7 +93,8 @@ def test_tanh():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
+    with nnvm.compiler.build_config(opt_level=1):
+        graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
     m = nnvm.runtime.create(graph, lib, default_ctx())
     # get member functions
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
@@ -114,7 +116,7 @@ def test_sigmoid():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
+    graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
     m = nnvm.runtime.create(graph, lib, default_ctx())
     # get member functions
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
@@ -136,7 +138,8 @@ def test_softmax():
     dtype = "float32"
     dshape = (10, 1000)
     oshape = dshape
-    graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
+    with nnvm.compiler.build_config(opt_level=1):
+        graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
     m = nnvm.runtime.create(graph, lib, default_ctx())
     # get member functions
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py
index 133b22fdcb65..32d84158b336 100644
--- a/nnvm/tests/python/compiler/test_top_level2.py
+++ b/nnvm/tests/python/compiler/test_top_level2.py
@@ -29,7 +29,7 @@ def test_conv2d():
     kshape = (10, 3, 3, 3)
     oshape = (1, 10, 18, 18)
     shape_dict = {"x": dshape}
-    graph, lib = nnvm.compiler.build(y, default_target(), shape_dict)
+    graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict)
     m = nnvm.runtime.create(graph, lib, default_ctx())
     # get member functions
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
@@ -57,7 +57,7 @@ def test_grouped_conv2d():
     kshape = (32, 1, 3, 3)
     oshape = (1, 32, 18, 18)
     shape_dict = {"x": dshape}
-    graph, lib = nnvm.compiler.build(y, default_target(), shape_dict)
+    graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict)
     m = nnvm.runtime.create(graph, lib, default_ctx())
     # get member functions
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
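Putting it together, a compact sketch of the parameter-folding path exercised by test_precompute_prune above; the shape is illustrative. Because "x" is supplied as a parameter, the x + 1 subexpression is folded at build time and the returned params dict reflects the pruned graph:

import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime

x = sym.Variable("x")
y = sym.Variable("y")
z = y + x + 1
shape = (10, 10)
nx = tvm.nd.array(np.random.uniform(size=shape).astype("float32"))
ny = tvm.nd.array(np.random.uniform(size=shape).astype("float32"))

# "x" is folded ahead of time; only "y" remains a runtime input.
graph, lib, params = nnvm.compiler.build(
    z, "llvm", shape={"y": shape}, params={"x": nx})
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
params["y"] = ny
m.run(**params)
out = m.get_output(0, tvm.nd.empty(shape))  # equals nx + 1 + ny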