From fa351045e6a0fd2f4ab57506e4899e525b3c7cef Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Mon, 17 Jun 2019 09:55:08 -0700 Subject: [PATCH] [relay][frontend] Return module from frontend parsers (#3353) --- python/tvm/relay/backend/interpreter.py | 79 ++++++----- python/tvm/relay/backend/vm.py | 8 +- python/tvm/relay/build_module.py | 11 +- python/tvm/relay/frontend/caffe2.py | 14 +- python/tvm/relay/frontend/coreml.py | 7 +- python/tvm/relay/frontend/darknet.py | 8 +- python/tvm/relay/frontend/keras.py | 7 +- python/tvm/relay/frontend/mxnet.py | 20 ++- python/tvm/relay/frontend/onnx.py | 16 ++- python/tvm/relay/frontend/tensorflow.py | 19 +-- python/tvm/relay/frontend/tflite.py | 7 +- tests/python/frontend/caffe2/test_forward.py | 5 +- tests/python/frontend/caffe2/test_graph.py | 5 +- tests/python/frontend/coreml/test_forward.py | 8 +- tests/python/frontend/darknet/test_forward.py | 6 +- tests/python/frontend/keras/test_forward.py | 6 +- tests/python/frontend/mxnet/test_forward.py | 128 +++++++++--------- tests/python/frontend/mxnet/test_graph.py | 32 ++--- tests/python/frontend/onnx/test_forward.py | 6 +- .../frontend/tensorflow/test_control_flow.py | 6 +- .../frontend/tensorflow/test_forward.py | 11 +- tests/python/frontend/tflite/test_forward.py | 10 +- tests/python/relay/test_vm.py | 4 +- tutorials/frontend/deploy_model_on_android.py | 4 +- tutorials/frontend/deploy_model_on_rasp.py | 3 +- tutorials/frontend/deploy_ssd_gluoncv.py | 4 +- tutorials/frontend/from_caffe2.py | 4 +- tutorials/frontend/from_coreml.py | 6 +- tutorials/frontend/from_darknet.py | 7 +- tutorials/frontend/from_keras.py | 6 +- tutorials/frontend/from_mxnet.py | 7 +- tutorials/frontend/from_onnx.py | 6 +- tutorials/frontend/from_tensorflow.py | 9 +- tutorials/frontend/from_tflite.py | 8 +- 34 files changed, 271 insertions(+), 216 deletions(-) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index ea25b970f87f..c54a65b78fb2 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -21,7 +21,8 @@ import numpy as np from . import _backend -from .. import _make, ir_pass +from .. import _make, ir_pass, transform +from .. import module from ... import register_func, nd from ..base import NodeBase, register_relay_node from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const @@ -191,14 +192,14 @@ def _convert_args(self, expr, args, kwargs): return tuple(cargs) - def _make_executor(self, _): + def _make_executor(self, expr=None): """ Construct a Python function that implements the evaluation of expression. Parameters ---------- - expr: relay.Expr + expr: Optional[relay.Expr] The Relay expression to execute. Returns @@ -208,16 +209,16 @@ def _make_executor(self, _): """ raise NotImplementedError() - def evaluate(self, expr, binds=None): + def evaluate(self, expr=None, binds=None): """ Evaluate a Relay expression on the executor. Parameters ---------- - expr: tvm.relay.Expr + expr: Optional[tvm.relay.Expr] The expression to evaluate. - binds: Map[tvm.relay.Var, tvm.relay.Expr] + binds: Optional[Map[tvm.relay.Var, tvm.relay.Expr]] Additional binding of free variable. Returns @@ -232,6 +233,9 @@ def evaluate(self, expr, binds=None): scope_builder.ret(expr) expr = scope_builder.get() + if not expr: + return self._make_executor() + if isinstance(expr, Function): assert not ir_pass.free_vars(expr) @@ -264,46 +268,47 @@ def __init__(self, mod, ctx, target): self.target = target self._intrp = _backend.CreateInterpreter(mod, ctx, target) - def optimize(self, expr): - """Optimize an expr. - - Parameters - ---------- - expr : Expr - The expression to be optimized. + def optimize(self): + """Optimize functions in a module. Returns ------- - opt_expr : Expr - The optimized expression. + opt_mod : tvm.relay.Module + The optimized module. """ - # TODO: We need to move this optimization code into the optimizer/pass manager - wrapped_expr = expr if isinstance(expr, Function) else Function([], expr) - if self.mod: - self.mod[self.mod.entry_func] = wrapped_expr - ck_expr = ir_pass.infer_type(wrapped_expr, mod=self.mod) - simp_expr = ir_pass.simplify_inference(ck_expr) - ck_simp = ir_pass.infer_type(simp_expr, mod=self.mod) - fused_expr = ir_pass.fuse_ops(ck_simp, 0, mod=self.mod) - ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod) - return ck_fused if isinstance(expr, Function) else Call(ck_fused, []) - - def _make_executor(self, expr): + seq = transform.Sequential([transform.SimplifyInference(), + transform.FuseOps(0), + transform.InferType()]) + return seq(self.mod) + + def _make_executor(self, expr=None): + if expr is None or isinstance(expr, GlobalVar): + assert self.mod is not None def _interp_wrapper(*args, **kwargs): - args = self._convert_args(expr, args, kwargs) + if expr is None: + args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs) + else: + args = self._convert_args(expr, args, kwargs) relay_args = [] for arg in args: relay_args.append(_arg_to_ast(arg)) - if isinstance(expr, GlobalVar): - func = self.mod[expr] - func = self.optimize(func) - self.mod._add(expr, func, True) - opt_expr = Call(expr, relay_args) - return self._intrp(opt_expr) + # Set the entry function for the module. + if expr is None: + pass + elif isinstance(expr, GlobalVar): + self.mod[self.mod.entry_func] = self.mod[expr] else: - call = Call(expr, relay_args) - opt_expr = self.optimize(call) - return self._intrp(opt_expr) + assert isinstance(expr, Function) + func = Function([], Call(expr, relay_args)) + relay_args = [] + if self.mod: + self.mod[self.mod.entry_func] = func + else: + self.mod = module.Module.from_expr(func) + + mod = self.optimize() + opt_expr = Call(mod[self.mod.entry_func.name_hint], relay_args) + return self._intrp(opt_expr) return _interp_wrapper diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 4cb3d611abd4..ceb403fe7717 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -130,9 +130,11 @@ def __init__(self, mod, ctx, target): self.ctx = ctx self.target = target - def _make_executor(self, expr): - assert isinstance(expr, Expr) - self.mod[self.mod.entry_func] = expr + def _make_executor(self, expr=None): + expr = expr if expr else self.mod + assert expr, "either expr or self.mod should be not null." + if isinstance(expr, Expr): + self.mod[self.mod.entry_func] = expr main = self.mod[self.mod.entry_func] def _vm_wrapper(*args, **kwargs): diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 8f9b0481a22c..1aa4d5ae57c4 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -219,16 +219,19 @@ def __init__(self, mod, ctx, target): self.ctx = ctx self.target = target - def _make_executor(self, func): - ret_type = ir_pass.infer_type(func).ret_type + def _make_executor(self, expr=None): + if not expr: + assert self.mod, "either expr or self.mod should be not null." + expr = self.mod[self.mod.entry_func] + ret_type = ir_pass.infer_type(expr).ret_type num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1 - graph_json, mod, params = build(func, target=self.target) + graph_json, mod, params = build(expr, target=self.target) gmodule = _graph_rt.create(graph_json, mod, self.ctx) if params: gmodule.set_input(**params) def _graph_wrapper(*args, **kwargs): - args = self._convert_args(func, args, kwargs) + args = self._convert_args(expr, args, kwargs) # Create map of inputs. for i, arg in enumerate(args): gmodule.set_input(i, arg) diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index eb8e717bb343..18489b380ee7 100644 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -20,6 +20,7 @@ import tvm from .. import ir_pass from .. import expr as _expr +from .. import module as _module from .. import op as _op from ... import nd as _nd from .common import AttrCvt, Renamer @@ -382,6 +383,7 @@ def __init__(self, shape, dtype): self._ops = {} self._shape = shape self._dtype = dtype + self._mod = _module.Module({}) def from_caffe2(self, init_net, predict_net): """Construct Relay expression from caffe2 graph. @@ -393,8 +395,9 @@ def from_caffe2(self, init_net, predict_net): Returns ------- - func : tvm.relay.expr.Function - Compatible relay function + mod : tvm.relay.Module + The module that optimizations will be performed on. + params : dict A dict of name: tvm.nd.array pairs, used as pretrained weights """ @@ -448,8 +451,9 @@ def from_caffe2(self, init_net, predict_net): outputs = out[0] func = _expr.Function(ir_pass.free_vars(outputs), outputs) + self._mod[self._mod.entry_func] = func - return func, self._params + return self._mod, self._params def _get_node(self, blob): """Get the Symbol of blob and detect cyclic dependency in the graph.""" @@ -560,8 +564,8 @@ def from_caffe2(init_net, predict_net, shape=None, dtype="float32"): Returns ------- - sym : tvm.relay.expr.Function - Compatible relay function + mod : tvm.relay.Module + The module that optimizations will be performed on. params : dict of str to tvm.ndarray Dict of converted parameters stored in tvm.ndarray format diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 653df92b71fc..1cac547d07c9 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -21,6 +21,7 @@ import tvm from .. import ir_pass from .. import expr as _expr +from .. import module as _module from .. import op as _op from ... import nd as _nd from ..._ffi import base as _base @@ -416,8 +417,8 @@ def from_coreml(model, shape=None): Returns ------- - func : tvm.relay.Function - Compatible relay Function. + mod : tvm.relay.Module + The relay module for compilation. params : dict of str to tvm.NDArray The parameter dict to be used by Relay. @@ -463,4 +464,4 @@ def from_coreml(model, shape=None): outexpr = outexpr[0] func = _expr.Function(ir_pass.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} - return func, params + return _module.Module.from_expr(func), params diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py index 6da3525eec21..7b26ed5692df 100644 --- a/python/tvm/relay/frontend/darknet.py +++ b/python/tvm/relay/frontend/darknet.py @@ -25,6 +25,7 @@ import tvm from .. import ir_pass from .. import expr as _expr +from .. import module as _module from .common import get_relay_op, new_var __all__ = ['from_darknet'] @@ -820,7 +821,7 @@ def from_darknet(self): outputs = _as_list(sym) + self._outs outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) sym = _expr.Function(ir_pass.free_vars(outputs), outputs) - return sym, self._tvmparams + return _module.Module.from_expr(sym), self._tvmparams def from_darknet(net, shape=None, @@ -838,8 +839,9 @@ def from_darknet(net, Returns ------- - sym : tvm.relay.Function - Compatible relay Function + mod : tvm.relay.Module + The relay module for compilation. + params : dict of str to tvm.NDArray The parameter dict to be used by relay """ diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 5d5e50ff3559..ad033f9bf326 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -22,6 +22,7 @@ import tvm from .. import ir_pass from .. import expr as _expr +from .. import module as _module from .. import op as _op from ... import nd as _nd from .common import ExprTable, new_var @@ -679,8 +680,8 @@ def from_keras(model, shape=None): Returns ------- - func : tvm.relay.Function - Compatible relay Function. + mod : tvm.relay.Module + The relay module for compilation. params : dict of str to tvm.NDArray The parameter dict to be used by Relay. @@ -744,4 +745,4 @@ def _convert_input_layer(keras_layer): outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr) func = _expr.Function(ir_pass.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} - return func, params + return _module.Module.from_expr(func), params diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index ff5f81dc7069..00cbc7067f98 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -23,6 +23,7 @@ from .. import ir_pass from .. import expr as _expr from .. import op as _op +from .. import module as _module from ... import nd as _nd from .common import StrAttrsDict @@ -992,7 +993,8 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): _convert_map.update({k : _rename(k) for k in _identity_list}) -def _from_mxnet_impl(symbol, shape_dict, dtype_info): +def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None): + #pylint: disable=unused-argument """Convert mxnet symbol to compatible relay Function. Reconstruct a relay Function by traversing the mxnet symbol. @@ -1009,6 +1011,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): dtype_info : dict or str. Known parameter dtypes + mod : tvm.relay.Module + The module that contains global information. It will be used for + converting ops that need global information, e.g. control-flow ops. + Returns: ------- func : tvm.relay.Function @@ -1097,8 +1103,8 @@ def from_mxnet(symbol, Returns ------- - sym : tvm.relay.Function - Compatible relay Function + mod : tvm.relay.Module + The relay module for compilation params : dict of str to tvm.NDArray The parameter dict to be used by nnvm @@ -1108,6 +1114,7 @@ def from_mxnet(symbol, except ImportError as e: raise ImportError("{}. MXNet is required to parse symbols.".format(e)) + mod = _module.Module() if isinstance(symbol, mx.sym.Symbol): params = {} arg_params = arg_params if arg_params else {} @@ -1117,7 +1124,7 @@ def from_mxnet(symbol, for k, v in aux_params.items(): params[k] = _nd.array(v.asnumpy()) shape, dtype = _update_shape_dtype(shape, dtype, params) - sym = _from_mxnet_impl(symbol, shape, dtype) + func = _from_mxnet_impl(symbol, shape, dtype, mod) elif isinstance(symbol, mx.gluon.HybridBlock): if arg_params is not None or aux_params is not None: raise ValueError("arg_params and aux_params ae not used when importing HybridBlock") @@ -1129,10 +1136,11 @@ def from_mxnet(symbol, if isinstance(sym, (list, tuple)): sym = mx.sym.Group(sym) shape, dtype = _update_shape_dtype(shape, dtype, params) - sym = _from_mxnet_impl(sym, shape, dtype) + func = _from_mxnet_impl(sym, shape, dtype, mod) elif isinstance(symbol, mx.gluon.Block): raise NotImplementedError("Only Hybrid Blocks are supported now.") else: msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol)) raise ValueError(msg) - return sym, params + mod[mod.entry_func] = func + return mod, params diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 98ff10bd8318..468a7486ca5c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -24,6 +24,7 @@ from ... import nd as _nd from .. import ir_pass from .. import expr as _expr +from .. import module as _module from .. import op as _op from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels, get_name @@ -999,8 +1000,9 @@ def from_onnx(self, graph, opset): Returns ------- - sym : tvm.relay.expr.Function - The returned relay function + mod : tvm.relay.Module + The returned relay module + params : dict A dict of name: tvm.nd.array pairs, used as pretrained weights """ @@ -1090,7 +1092,7 @@ def from_onnx(self, graph, opset): outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) func = _expr.Function(ir_pass.free_vars(outputs), outputs) - return func, self._params + return _module.Module.from_expr(func), self._params def _parse_value_proto(self, value_proto): """Parse ValueProto or raw str.""" @@ -1219,8 +1221,8 @@ def from_onnx(model, Returns ------- - sym : tvm.relay.expr.Function - Compatible relay function + mod : tvm.relay.Module + The relay module for compilation params : dict of str to tvm.NDArray The parameter dict to be used by relay @@ -1243,5 +1245,5 @@ def from_onnx(model, opset = model.opset_import[0].version if model.opset_import else 1 except AttributeError: opset = 1 - sym, params = g.from_onnx(graph, opset) - return sym, params + mod, params = g.from_onnx(graph, opset) + return mod, params diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 866a6228980e..c0df8e679153 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -31,6 +31,7 @@ from .. import expr as _expr from .. import op as _op from ..expr_functor import ExprMutator +from .. import module as _module __all__ = ['from_tensorflow'] @@ -1823,6 +1824,7 @@ def __init__(self): self._input_shapes = {} self._loops = {} self._branches = {} + self._mod = _module.Module({}) def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): """Construct relay nodes from tensorflow graph definition - GraphDef. @@ -1856,8 +1858,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): Returns ------- - sym : relay.op - The returned relay operator + mod : tvm.relay.Module + The module that optimizations will be performed on. + params : dict A dict of name: tvm.nd.array pairs, used as pretrained weights """ @@ -2046,8 +2049,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): out = out[0] if len(out) == 1 else _expr.Tuple(out) func = _expr.Function(ir_pass.free_vars(out), out) - - return func, self._params + self._mod[self._mod.entry_func] = func + return self._mod, self._params def _parse_import_prerequisites(self, graph): """ Calculate the named preconditions from TensorFlow `graph`. @@ -2336,12 +2339,12 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): Returns ------- - sym : relay.op - Compatible relay operator + mod : tvm.relay.Module + The module that optimizations will be performed on. params : dict of str to tvm.ndarray Dict of converted parameters stored in tvm.ndarray format """ g = GraphProto() - sym, params = g.from_tensorflow(graph, layout, shape, outputs) - return sym, params + mod, params = g.from_tensorflow(graph, layout, shape, outputs) + return mod, params diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 3b27428537e9..7b4394e7facb 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -22,6 +22,7 @@ import tvm from .. import ir_pass from .. import expr as _expr +from .. import module as _module from .. import op as _op from ... import nd as _nd from .common import ExprTable @@ -749,8 +750,8 @@ def from_tflite(model, shape_dict, dtype_dict): Returns ------- - func : tvm.relay.Function - Compatible relay Function + mod : tvm.relay.Module + The relay module for compilation. params : dict of str to tvm.NDArray The parameter dict to be used by relay @@ -788,4 +789,4 @@ def from_tflite(model, shape_dict, dtype_dict): outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) func = _expr.Function(ir_pass.free_vars(outputs), outputs) - return func, params + return _module.Module.from_expr(func), params diff --git a/tests/python/frontend/caffe2/test_forward.py b/tests/python/frontend/caffe2/test_forward.py index 73cde08ea8b5..465ac70afc5f 100644 --- a/tests/python/frontend/caffe2/test_forward.py +++ b/tests/python/frontend/caffe2/test_forward.py @@ -40,9 +40,10 @@ def get_tvm_output(model, input_names = model.predict_net.op[0].input[0] shape_dict = {input_names: input_data.shape} dtype_dict = {input_names: input_data.dtype} - func, params = relay.frontend.from_caffe2(model.init_net, model.predict_net, shape_dict, dtype_dict) + mod, params = relay.frontend.from_caffe2( + model.init_net, model.predict_net, shape_dict, dtype_dict) with relay.build_config(opt_level=3): - graph, lib, params = relay.build(func, target, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], target, params=params) m = graph_runtime.create(graph, lib, ctx) diff --git a/tests/python/frontend/caffe2/test_graph.py b/tests/python/frontend/caffe2/test_graph.py index e5280932463d..ea3a36e60663 100644 --- a/tests/python/frontend/caffe2/test_graph.py +++ b/tests/python/frontend/caffe2/test_graph.py @@ -28,9 +28,10 @@ def compare_graph(f1, f2): def test_squeeze_net(): shape_dict = {'data': (1, 3, 224, 224)} dtype_dict = {'data': 'float32'} - from_c2_func, _ = relay.frontend.from_caffe2(c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict) + mod, _, = relay.frontend.from_caffe2( + c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict) relay_func, _ = relay_squeezenet() - compare_graph(from_c2_func, relay_func) + compare_graph(mod[mod.entry_func], relay_func) if __name__ == '__main__': diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index 0b6f91bed54f..7f0cdd1a5120 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -46,9 +46,9 @@ def run_model_checkonly(model_file, model_name='', input_name='image'): model = cm.models.MLModel(model_file) x = model_zoo.get_cat_image() shape_dict = {input_name : x.shape} - func, params = relay.frontend.from_coreml(model, shape_dict) + mod, params = relay.frontend.from_coreml(model, shape_dict) for target, ctx in ctx_list(): - tvm_output = get_tvm_output(func, x, params, target, ctx) + tvm_output = get_tvm_output(mod[mod.entry_func], x, params, target, ctx) print(target, ctx, model_name, 'prediction id: ', np.argmax(tvm_output.flat)) def test_mobilenet_checkonly(): @@ -71,9 +71,9 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap shape_dict = {input_name: input_data.shape} dtype_dict = {input_name: input_data.dtype} - func, params = relay.frontend.from_coreml(coreml_model, shape_dict) + mod, params = relay.frontend.from_coreml(coreml_model, shape_dict) with relay.transform.build_config(opt_level=3): - graph, lib, params = relay.build(func, target, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], target, params=params) from tvm.contrib import graph_runtime m = graph_runtime.create(graph, lib, ctx) diff --git a/tests/python/frontend/darknet/test_forward.py b/tests/python/frontend/darknet/test_forward.py index 3545e8a902bd..06916172f2bf 100644 --- a/tests/python/frontend/darknet/test_forward.py +++ b/tests/python/frontend/darknet/test_forward.py @@ -52,10 +52,12 @@ def _read_memory_buffer(shape, data, dtype='float32'): def _get_tvm_output(net, data, build_dtype='float32', states=None): '''Compute TVM output''' dtype = 'float32' - sym, params = relay.frontend.from_darknet(net, data.shape, dtype) + mod, params = relay.frontend.from_darknet(net, data.shape, dtype) target = 'llvm' shape_dict = {'data': data.shape} - graph, library, params = relay.build(sym, target, params=params) + graph, library, params = relay.build(mod[mod.entry_func], + target, + params=params) # Execute on TVM ctx = tvm.cpu(0) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 0794db987892..2dfb43b70455 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -42,9 +42,11 @@ def get_keras_output(xs, dtype='float32'): def get_tvm_output(xs, target, ctx, dtype='float32'): shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)} - func, params = relay.frontend.from_keras(keras_model, shape_dict) + mod, params = relay.frontend.from_keras(keras_model, shape_dict) with relay.transform.build_config(opt_level=2): - graph, lib, params = relay.build(func, target, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], + target, + params=params) m = graph_runtime.create(graph, lib, ctx) for name, x in zip(keras_model.input_names, xs): m.set_input(name, tvm.nd.array(x.astype(dtype))) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 45e2ab58cae3..d70b22284706 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -59,14 +59,14 @@ def get_mxnet_output(symbol, x, dtype='float32'): def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'): shape_dict = {"data": x.shape} if gluon_impl: - new_sym, params = relay.frontend.from_mxnet(symbol, shape_dict) + mod, params = relay.frontend.from_mxnet(symbol, shape_dict) else: - new_sym, params = relay.frontend.from_mxnet(symbol, - shape_dict, - arg_params=args, - aux_params=auxs) + mod, params = relay.frontend.from_mxnet(symbol, + shape_dict, + arg_params=args, + aux_params=auxs) with relay.build_config(opt_level=3): - graph, lib, params = relay.build(new_sym, target, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], target, params=params) m = graph_runtime.create(graph, lib, ctx) # set inputs m.set_input("data", tvm.nd.array(x.astype(dtype))) @@ -242,11 +242,11 @@ def test_forward_where(): args, auxs = mod.get_params() mx_out = mx.nd.where(mx_cond, mx_x, mx_y).asnumpy() - new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, args, auxs) + mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, args, auxs) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(np_cond, np_x, np_y) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(np_cond, np_x, np_y) tvm.testing.assert_allclose(op_res.asnumpy(), mx_out) @@ -265,11 +265,11 @@ def _mx_symbol(F, start, stop, step): def verify(start, stop, step): ref_res = _mx_symbol(mx.nd, start, stop, step).asnumpy() mx_sym = _mx_symbol(mx.sym, start, stop, step) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)() + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()() tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) verify(0, 20, None) verify(0, 20, 2) @@ -304,11 +304,11 @@ def test_forward_broadcast_ops(): mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')]) ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)]) shapes = {'a': a_shape, 'b': b_shape} - new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(a_np, b_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(a_np, b_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) def test_forward_elemwise_ops(): @@ -321,11 +321,11 @@ def test_forward_elemwise_ops(): mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')]) ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)]) shapes = {'a': shape, 'b': shape} - new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(a_np, b_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(a_np, b_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) def test_forward_scalar_ops(): @@ -339,11 +339,11 @@ def test_forward_scalar_ops(): mx_sym = op(mx.sym.var('a'), b_scalar) ref_res = op(mx.nd.array(a_np), b_scalar) shapes = {'a': a_shape} - new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(a_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(a_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) for op in ["maximum", "minimum"]: dtype='float32' @@ -353,11 +353,11 @@ def test_forward_scalar_ops(): mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), b_scalar]) ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), b_scalar]) shapes = {'a': a_shape} - new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(a_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(a_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) def test_forward_slice_axis(): @@ -365,11 +365,11 @@ def verify(shape, axis, begin, end): data_np = np.random.uniform(size=shape).astype("float32") ref_res = mx.nd.slice_axis(mx.nd.array(data_np), axis, begin, end) mx_sym = mx.sym.slice_axis(mx.sym.var("data"), axis, begin, end) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(data_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(data_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) verify((3, 4), 0, 1, 2) verify((3, 4), 0, 1, None) @@ -387,11 +387,11 @@ def verify(x_shape, y_shape, axes): else: ref_res = mx.nd.slice_like(mx.nd.array(x_np), mx.nd.array(y_np), axes=axes) mx_sym = mx.sym.slice_like(mx.sym.var("x"), mx.sym.var("y"), axes=axes) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": x_shape, "y": y_shape}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": x_shape, "y": y_shape}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(x_np, y_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_np, y_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) verify((3, 4), (2, 3), None) verify((3, 4), (2, 3), (0, 1)) @@ -408,11 +408,11 @@ def verify(shape): x_np = np.random.uniform(size=shape).astype("float32") ref_res = mx.nd.shape_array(mx.nd.array(x_np)) mx_sym = mx.sym.shape_array(mx.sym.var("x")) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, ctx in ctx_list(): for kind in ["debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(x_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) verify((1,)) verify((3, 4, 5)) @@ -427,11 +427,11 @@ def verify(shape, axis): else: ref_res = mx.nd.squeeze(mx.nd.array(x_np), axis=axis) mx_sym = mx.sym.squeeze(mx.sym.var("x"), axis=axis) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(x_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) verify((1, 3, 1), None) verify((1, 3, 1), 0) @@ -443,11 +443,11 @@ def verify(shape, axis, size): x_np = np.random.uniform(size=shape).astype("float32") ref_res = mx.nd.broadcast_axis(mx.nd.array(x_np), axis=axis, size=size) mx_sym = mx.sym.broadcast_axis(mx.sym.var("x"), axis=axis, size=size) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(x_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) verify((1, 2, 1), 2, 3) verify((1, 2, 1), (0, 2), (2, 3)) @@ -457,13 +457,13 @@ def verify(val, shape, dtype): ctx = mx.cpu() ref_res = mx.nd.full(shape, val, dtype=dtype) mx_sym = mx.sym.full(shape, val, dtype=dtype) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {}) for target, ctx in ctx_list(): # Skip testing graph runtime because this op will be optimized out # by constant folding. for kind in ["debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)() + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()() tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) verify(2, (3, 4), "float32") verify(2, (3, 4), "int32") @@ -478,12 +478,12 @@ def verify(data_shape, weight_shape): input_dim=in_dim, output_dim=out_dim) mx_sym = mx.sym.Embedding(mx.sym.var("x"), mx.sym.var("w"), input_dim=in_dim, output_dim=out_dim) - new_sym, _ = relay.frontend.from_mxnet( + mod, _ = relay.frontend.from_mxnet( mx_sym, {"x": data_shape, "w": weight_shape}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(x=x_np, w=w_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x=x_np, w=w_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) verify((2, 2), (4, 5)) verify((2, 3, 4), (4, 5)) @@ -501,11 +501,11 @@ def verify(shape, indices_src, axis, mode="clip"): indices_np = np.array(indices_src, dtype="float32") ref_res = mx.nd.take(mx.nd.array(x_np), mx.nd.array(indices_np), axis, mode) mx_sym = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(x_np, indices_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_np, indices_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) verify((2,2), [[[1,0],[0,1]]], 0) verify((2,2), [[[1,0],[0,1]]], 1) @@ -520,11 +520,11 @@ def verify(xshape, yshape, y_data): x_data = np.random.uniform(size=xshape).astype("float32") ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data)) mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data")) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x_data": xshape, "y_data": yshape}, {"x_data": "float32", "y_data": "int32"}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x_data": xshape, "y_data": yshape}, {"x_data": "float32", "y_data": "int32"}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(x_data, y_data) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_data, y_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]]) @@ -575,13 +575,13 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, init_states=True) for name, param in layer.collect_params().items(): mx_params[name] = param._reduce() - new_sym, params = relay.frontend.from_mxnet( + mod, params = relay.frontend.from_mxnet( mx_sym, shape=shape_dict, arg_params=mx_params) for target, ctx in ctx_list(): # only test graph runtime because debug runtime is too slow for kind in ["graph"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(**inputs, **params) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(**inputs, **params) if init_states: assert len(op_res) == len(mx_res) for i, val in enumerate(op_res): @@ -607,14 +607,14 @@ def verify(xshape, yshape, offset=None): else: mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y"), offset=offset) ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data), offset=offset) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": xshape, "y": yshape}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": xshape, "y": yshape}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) if offset is None or offset == (0, 0): - op_res = intrp.evaluate(new_sym)(x_data, y_data) + op_res = intrp.evaluate()(x_data, y_data) else: - op_res = intrp.evaluate(new_sym)(x_data) + op_res = intrp.evaluate()(x_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) verify((1, 3, 40, 40), (1, 3, 20, 20)) verify((1, 3, 40, 40), (1, 3, 20, 20), (0, 0)) @@ -627,11 +627,11 @@ def verify(shape, axis, is_ascend, dtype="float32"): x_np = np.random.uniform(size=shape).astype("float32") ref_res = mx.nd.argsort(mx.nd.array(x_np), axis=axis, is_ascend=is_ascend, dtype=dtype) mx_sym = mx.sym.argsort(mx.sym.var("x"), axis=axis, is_ascend=is_ascend, dtype=dtype) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(x_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) verify((2, 3, 4), axis=0, is_ascend=False) verify((1, 4, 6), axis=1, is_ascend=True) @@ -644,11 +644,11 @@ def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"): is_ascend=is_ascend, dtype=dtype) mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype) - new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(new_sym)(x_np) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_np) if isinstance(ref_res, list): assert len(op_res) == len(ref_res) for i, t in enumerate(op_res): diff --git a/tests/python/frontend/mxnet/test_graph.py b/tests/python/frontend/mxnet/test_graph.py index 86312648c32c..b7d3ba4a5b60 100644 --- a/tests/python/frontend/mxnet/test_graph.py +++ b/tests/python/frontend/mxnet/test_graph.py @@ -26,60 +26,60 @@ def compare_graph(f1, f2): def test_mlp(): shape = {"data": (1, 1, 28, 28)} mx_fun = model_zoo.mx_mlp() - from_mx_fun, _ = relay.frontend.from_mxnet(mx_fun, shape=shape) + mod, _ = relay.frontend.from_mxnet(mx_fun, shape=shape) relay_fun = model_zoo.relay_mlp() - compare_graph(from_mx_fun, relay_fun) + compare_graph(mod[mod.entry_func], relay_fun) def test_vgg(): shape = {"data": (1, 3, 224, 224)} for n in [11, 13, 16, 19]: mx_sym = model_zoo.mx_vgg(n) - from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape) + mod, _ = relay.frontend.from_mxnet(mx_sym, shape=shape) relay_sym = model_zoo.relay_vgg(n) - compare_graph(from_mx_sym, relay_sym) + compare_graph(mod[mod.entry_func], relay_sym) def test_resnet(): shape = {"data": (1, 3, 224, 224)} for n in [18, 34, 50, 101]: mx_sym = model_zoo.mx_resnet(n) - from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape=shape) + mod, _ = relay.frontend.from_mxnet(mx_sym, shape=shape) relay_sym = model_zoo.relay_resnet(n) - compare_graph(from_mx_sym, relay_sym) + compare_graph(mod[mod.entry_func], relay_sym) def test_squeezenet(): shape = {"data": (1, 3, 224, 224)} for version in ['1.0', '1.1']: mx_sym = model_zoo.mx_squeezenet(version) - from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) + mod, _ = relay.frontend.from_mxnet(mx_sym, shape) relay_sym = model_zoo.relay_squeezenet(version) - compare_graph(from_mx_sym, relay_sym) + compare_graph(mod[mod.entry_func], relay_sym) def test_inception_v3(): shape = {"data": (1, 3, 299, 299)} mx_sym = model_zoo.mx_inception_v3() - from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) + mod, _ = relay.frontend.from_mxnet(mx_sym, shape) relay_sym = model_zoo.relay_inception_v3() - compare_graph(from_mx_sym, relay_sym) + compare_graph(mod[mod.entry_func], relay_sym) def test_dqn(): shape = {"data": (1, 4, 84, 84)} mx_sym = model_zoo.mx_dqn() - from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) + mod, _ = relay.frontend.from_mxnet(mx_sym, shape) relay_sym = model_zoo.relay_dqn() - compare_graph(from_mx_sym, relay_sym) + compare_graph(mod[mod.entry_func], relay_sym) def test_dcgan(): shape = {"data": (2, 100)} mx_sym = model_zoo.mx_dcgan() - from_mx_sym, _ = relay.frontend.from_mxnet(mx_sym, shape) + mod, _ = relay.frontend.from_mxnet(mx_sym, shape) relay_sym = model_zoo.relay_dcgan(batch_size=2) - compare_graph(from_mx_sym, relay_sym) + compare_graph(mod[mod.entry_func], relay_sym) def test_multi_outputs(): @@ -100,10 +100,10 @@ def relay_compose(F, **kwargs): return relay.Function(relay.ir_pass.free_vars(z), z) mx_sym = mx_compose(mx, num_outputs=3, axis=1) - from_mx_sym, _ = relay.frontend.from_mxnet( + mod, _ = relay.frontend.from_mxnet( mx_sym, shape={"x":xshape, "y":yshape}) relay_sym = relay_compose(relay, indices_or_sections=3, axis=1) - compare_graph(from_mx_sym, relay_sym) + compare_graph(mod[mod.entry_func], relay_sym) if __name__ == "__main__": diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 095f1feb246a..7371a88ca677 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -42,9 +42,11 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output shape_dict = {input_names: input_data.shape} dtype_dict = {input_names: input_data.dtype} - sym, params = relay.frontend.from_onnx(graph_def, shape_dict) + mod, params = relay.frontend.from_onnx(graph_def, shape_dict) with relay.build_config(opt_level=1): - graph, lib, params = relay.build(sym, target, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], + target, + params=params) ctx = tvm.cpu(0) from tvm.contrib import graph_runtime diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py index b1860658a961..b08da1476601 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -22,9 +22,9 @@ def check_equal(graph, tf_out): - expr, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) - ex = relay.create_executor('debug') - relay_out = ex.evaluate(expr)(**params) + mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) + ex = relay.create_executor('debug', mod=mod) + relay_out = ex.evaluate()(**params) if isinstance(relay_out, relay.backend.interpreter.TensorValue): np.testing.assert_allclose(tf_out, relay_out.asnumpy()) else: diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 6fc825a8924c..bda14c18cde8 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -60,13 +60,12 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, shape_dict = {e: i.shape for e, i in zip(input_node, input_data)} - sym, params = relay.frontend.from_tensorflow(graph_def, + mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict, outputs=out_names) - with relay.build_config(opt_level=opt_level): - graph, lib, params = relay.build(sym, target, target_host, params) + graph, lib, params = relay.build(mod[mod.entry_func], target, target_host, params) ctx = tvm.context(target, 0) from tvm.contrib import graph_runtime @@ -1442,14 +1441,16 @@ def _get_tvm_graph_module(graph_def): 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':(num_layers, batch_size, num_hidden), 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':(num_layers, batch_size, num_hidden)} - sym, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict) + mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict) dtype_dict = {'Model/Placeholder': 'int32', 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c':'float32', 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h':'float32'} target = 'llvm' with relay.build_config(opt_level=0): - graph, lib, params = relay.build(sym, target, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], + target, + params=params) from tvm.contrib import graph_runtime ctx = tvm.cpu(0) return params, graph_runtime.create(graph, lib, ctx) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index ec345ee78961..3b76fad1c073 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -63,11 +63,13 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target shape_dict[e] = input_data[i].shape dtype_dict[e] = input_data[i].dtype.name - func, params = relay.frontend.from_tflite(tflite_model, - shape_dict=shape_dict, - dtype_dict=dtype_dict) + mod, params = relay.frontend.from_tflite(tflite_model, + shape_dict=shape_dict, + dtype_dict=dtype_dict) with relay.build_config(opt_level=3): - graph, lib, params = relay.build(func, target, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], + target, + params=params) ctx = tvm.context(target, 0) from tvm.contrib import graph_runtime diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 12e343be02ac..dbe0c1741d48 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -35,9 +35,9 @@ def veval(f, *args, ctx=tvm.cpu()): mod = f ex = relay.create_executor('vm', mod=mod, ctx=ctx) if len(args) == 0: - return ex.evaluate(mod[mod.entry_func]) + return ex.evaluate() else: - return ex.evaluate(mod[mod.entry_func])(*args) + return ex.evaluate()(*args) def test_split(): x = relay.var('x', shape=(12,)) diff --git a/tutorials/frontend/deploy_model_on_android.py b/tutorials/frontend/deploy_model_on_android.py index a3ea8651b110..2e416b7f379a 100644 --- a/tutorials/frontend/deploy_model_on_android.py +++ b/tutorials/frontend/deploy_model_on_android.py @@ -260,10 +260,10 @@ def transform_image(image): input_name = 'input_1' shape_dict = {input_name: x.shape} -func, params = relay.frontend.from_keras(keras_mobilenet_v2, shape_dict) +mod, params = relay.frontend.from_keras(keras_mobilenet_v2, shape_dict) with relay.build_config(opt_level=3): - graph, lib, params = relay.build(func, target=target, + graph, lib, params = relay.build(mod[mod.entry_func], target=target, target_host=target_host, params=params) # After `relay.build`, you will get three return values: graph, diff --git a/tutorials/frontend/deploy_model_on_rasp.py b/tutorials/frontend/deploy_model_on_rasp.py index c471e8228840..78377849c10b 100644 --- a/tutorials/frontend/deploy_model_on_rasp.py +++ b/tutorials/frontend/deploy_model_on_rasp.py @@ -140,8 +140,9 @@ def transform_image(image): # We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon shape_dict = {'data': x.shape} -func, params = relay.frontend.from_mxnet(block, shape_dict) +mod, params = relay.frontend.from_mxnet(block, shape_dict) # we want a probability so add a softmax operator +func = mod[mod.entry_func] func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) ###################################################################### diff --git a/tutorials/frontend/deploy_ssd_gluoncv.py b/tutorials/frontend/deploy_ssd_gluoncv.py index 829957b3c658..92b488f8fa07 100644 --- a/tutorials/frontend/deploy_ssd_gluoncv.py +++ b/tutorials/frontend/deploy_ssd_gluoncv.py @@ -76,9 +76,9 @@ block = model_zoo.get_model(model_name, pretrained=True) def build(target): - net, params = relay.frontend.from_mxnet(block, {"data": dshape}) + mod, params = relay.frontend.from_mxnet(block, {"data": dshape}) with relay.build_config(opt_level=3): - graph, lib, params = relay.build(net, target, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], target, params=params) return graph, lib, params ###################################################################### diff --git a/tutorials/frontend/from_caffe2.py b/tutorials/frontend/from_caffe2.py index ceec8c0ad119..082dfafb33e7 100644 --- a/tutorials/frontend/from_caffe2.py +++ b/tutorials/frontend/from_caffe2.py @@ -83,13 +83,13 @@ def transform_image(image): # parse Caffe2 model and convert into Relay computation graph from tvm import relay -func, params = relay.frontend.from_caffe2(resnet50.init_net, resnet50.predict_net, shape_dict, dtype_dict) +mod, params = relay.frontend.from_caffe2(resnet50.init_net, resnet50.predict_net, shape_dict, dtype_dict) # compile the model # target x86 CPU target = 'llvm' with relay.build_config(opt_level=3): - graph, lib, params = relay.build(func, target, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], target, params=params) ###################################################################### # Execute on TVM diff --git a/tutorials/frontend/from_coreml.py b/tutorials/frontend/from_coreml.py index e0c31445f6a8..7eeefb3f2e5b 100644 --- a/tutorials/frontend/from_coreml.py +++ b/tutorials/frontend/from_coreml.py @@ -68,10 +68,12 @@ shape_dict = {'image': x.shape} # Parse CoreML model and convert into Relay computation graph -func, params = relay.frontend.from_coreml(mlmodel, shape_dict) +mod, params = relay.frontend.from_coreml(mlmodel, shape_dict) with relay.build_config(opt_level=3): - graph, lib, params = relay.build(func, target, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], + target, + params=params) ###################################################################### # Execute on TVM diff --git a/tutorials/frontend/from_darknet.py b/tutorials/frontend/from_darknet.py index 2658a353e34e..d9014e092ff0 100644 --- a/tutorials/frontend/from_darknet.py +++ b/tutorials/frontend/from_darknet.py @@ -82,7 +82,7 @@ data = np.empty([batch_size, net.c, net.h, net.w], dtype) shape_dict = {'data': data.shape} print("Converting darknet to relay functions...") -sym, params = relay.frontend.from_darknet(net, dtype=dtype, shape=data.shape) +mod, params = relay.frontend.from_darknet(net, dtype=dtype, shape=data.shape) ###################################################################### # Import the graph to Relay @@ -95,7 +95,10 @@ shape = {'data': data.shape} print("Compiling the model...") with relay.build_config(opt_level=3): - graph, lib, params = relay.build(sym, target=target, target_host=target_host, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], + target=target, + target_host=target_host, + params=params) [neth, netw] = shape['data'][2:] # Current image shape is 608x608 ###################################################################### diff --git a/tutorials/frontend/from_keras.py b/tutorials/frontend/from_keras.py index 23c7eaf6b2b1..c1f3471bb644 100644 --- a/tutorials/frontend/from_keras.py +++ b/tutorials/frontend/from_keras.py @@ -74,18 +74,18 @@ # ---------------------------- # convert the keras model(NHWC layout) to Relay format(NCHW layout). shape_dict = {'input_1': data.shape} -func, params = relay.frontend.from_keras(keras_resnet50, shape_dict) +mod, params = relay.frontend.from_keras(keras_resnet50, shape_dict) # compile the model target = 'cuda' ctx = tvm.gpu(0) with relay.build_config(opt_level=3): - executor = relay.build_module.create_executor('graph', func, ctx, target) + executor = relay.build_module.create_executor('graph', mod, ctx, target) ###################################################################### # Execute on TVM # --------------- dtype = 'float32' -tvm_out = executor.evaluate(func)(tvm.nd.array(data.astype(dtype)), **params) +tvm_out = executor.evaluate()(tvm.nd.array(data.astype(dtype)), **params) top1_tvm = np.argmax(tvm_out.asnumpy()[0]) ##################################################################### diff --git a/tutorials/frontend/from_mxnet.py b/tutorials/frontend/from_mxnet.py index 2629dfaafaab..1109fd9b7d1c 100644 --- a/tutorials/frontend/from_mxnet.py +++ b/tutorials/frontend/from_mxnet.py @@ -82,8 +82,9 @@ def transform_image(image): # It's as easy as several lines. # We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon shape_dict = {'data': x.shape} -func, params = relay.frontend.from_mxnet(block, shape_dict) +mod, params = relay.frontend.from_mxnet(block, shape_dict) ## we want a probability so add a softmax operator +func = mod[mod.entry_func] func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) ###################################################################### @@ -132,6 +133,6 @@ def block2symbol(block): # for a normal mxnet model, we start from here mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v1', 0) # now we use the same API to get Relay computation graph -relay_func, relay_params = relay.frontend.from_mxnet(mx_sym, shape_dict, - arg_params=args, aux_params=auxs) +mod, relay_params = relay.frontend.from_mxnet(mx_sym, shape_dict, + arg_params=args, aux_params=auxs) # repeat the same steps to run this model using TVM diff --git a/tutorials/frontend/from_onnx.py b/tutorials/frontend/from_onnx.py index b177516b9a92..7a615930a905 100644 --- a/tutorials/frontend/from_onnx.py +++ b/tutorials/frontend/from_onnx.py @@ -71,16 +71,16 @@ input_name = '1' shape_dict = {input_name: x.shape} -sym, params = relay.frontend.from_onnx(onnx_model, shape_dict) +mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) with relay.build_config(opt_level=1): - intrp = relay.build_module.create_executor('graph', sym, tvm.cpu(0), target) + intrp = relay.build_module.create_executor('graph', mod, tvm.cpu(0), target) ###################################################################### # Execute on TVM # --------------------------------------------- dtype = 'float32' -tvm_output = intrp.evaluate(sym)(tvm.nd.array(x.astype(dtype)), **params).asnumpy() +tvm_output = intrp.evaluate()(tvm.nd.array(x.astype(dtype)), **params).asnumpy() ###################################################################### # Display results diff --git a/tutorials/frontend/from_tensorflow.py b/tutorials/frontend/from_tensorflow.py index 8d402820377e..6603c2da15bc 100644 --- a/tutorials/frontend/from_tensorflow.py +++ b/tutorials/frontend/from_tensorflow.py @@ -124,7 +124,9 @@ # params: params converted from tensorflow params (tensor protobuf). shape_dict = {'DecodeJpeg/contents': x.shape} dtype_dict = {'DecodeJpeg/contents': 'uint8'} -sym, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict) +mod, params = relay.frontend.from_tensorflow(graph_def, + layout=layout, + shape=shape_dict) print("Tensorflow protobuf imported to relay frontend.") ###################################################################### @@ -138,7 +140,10 @@ # lib: target library which can be deployed on target with TVM runtime. with relay.build_config(opt_level=3): - graph, lib, params = relay.build(sym, target=target, target_host=target_host, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], + target=target, + target_host=target_host, + params=params) ###################################################################### # Execute the portable graph on TVM diff --git a/tutorials/frontend/from_tflite.py b/tutorials/frontend/from_tflite.py index f8cdd991c984..5a8525133f7c 100644 --- a/tutorials/frontend/from_tflite.py +++ b/tutorials/frontend/from_tflite.py @@ -138,14 +138,14 @@ def extract(path): # parse TFLite model and convert into Relay computation graph from tvm import relay -func, params = relay.frontend.from_tflite(tflite_model, - shape_dict={input_tensor: input_shape}, - dtype_dict={input_tensor: input_dtype}) +mod, params = relay.frontend.from_tflite(tflite_model, + shape_dict={input_tensor: input_shape}, + dtype_dict={input_tensor: input_dtype}) # target x86 CPU target = "llvm" with relay.build_config(opt_level=3): - graph, lib, params = relay.build(func, target, params=params) + graph, lib, params = relay.build(mod[mod.entry_func], target, params=params) ###################################################################### # Execute on TVM