Skip to content

Commit

Permalink
[relay][frontend] Return module from frontend parsers (#3353)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and tqchen committed Jun 17, 2019
1 parent 07fbe5c commit fa35104
Show file tree
Hide file tree
Showing 34 changed files with 271 additions and 216 deletions.
79 changes: 42 additions & 37 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import numpy as np

from . import _backend
from .. import _make, ir_pass
from .. import _make, ir_pass, transform
from .. import module
from ... import register_func, nd
from ..base import NodeBase, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
Expand Down Expand Up @@ -191,14 +192,14 @@ def _convert_args(self, expr, args, kwargs):

return tuple(cargs)

def _make_executor(self, _):
def _make_executor(self, expr=None):
"""
Construct a Python function that implements the evaluation
of expression.
Parameters
----------
expr: relay.Expr
expr: Optional[relay.Expr]
The Relay expression to execute.
Returns
Expand All @@ -208,16 +209,16 @@ def _make_executor(self, _):
"""
raise NotImplementedError()

def evaluate(self, expr, binds=None):
def evaluate(self, expr=None, binds=None):
"""
Evaluate a Relay expression on the executor.
Parameters
----------
expr: tvm.relay.Expr
expr: Optional[tvm.relay.Expr]
The expression to evaluate.
binds: Map[tvm.relay.Var, tvm.relay.Expr]
binds: Optional[Map[tvm.relay.Var, tvm.relay.Expr]]
Additional binding of free variable.
Returns
Expand All @@ -232,6 +233,9 @@ def evaluate(self, expr, binds=None):
scope_builder.ret(expr)
expr = scope_builder.get()

if not expr:
return self._make_executor()

if isinstance(expr, Function):
assert not ir_pass.free_vars(expr)

Expand Down Expand Up @@ -264,46 +268,47 @@ def __init__(self, mod, ctx, target):
self.target = target
self._intrp = _backend.CreateInterpreter(mod, ctx, target)

def optimize(self, expr):
"""Optimize an expr.
Parameters
----------
expr : Expr
The expression to be optimized.
def optimize(self):
"""Optimize functions in a module.
Returns
-------
opt_expr : Expr
The optimized expression.
opt_mod : tvm.relay.Module
The optimized module.
"""
# TODO: We need to move this optimization code into the optimizer/pass manager
wrapped_expr = expr if isinstance(expr, Function) else Function([], expr)
if self.mod:
self.mod[self.mod.entry_func] = wrapped_expr
ck_expr = ir_pass.infer_type(wrapped_expr, mod=self.mod)
simp_expr = ir_pass.simplify_inference(ck_expr)
ck_simp = ir_pass.infer_type(simp_expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(ck_simp, 0, mod=self.mod)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused if isinstance(expr, Function) else Call(ck_fused, [])

def _make_executor(self, expr):
seq = transform.Sequential([transform.SimplifyInference(),
transform.FuseOps(0),
transform.InferType()])
return seq(self.mod)

def _make_executor(self, expr=None):
if expr is None or isinstance(expr, GlobalVar):
assert self.mod is not None
def _interp_wrapper(*args, **kwargs):
args = self._convert_args(expr, args, kwargs)
if expr is None:
args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs)
else:
args = self._convert_args(expr, args, kwargs)

relay_args = []
for arg in args:
relay_args.append(_arg_to_ast(arg))

if isinstance(expr, GlobalVar):
func = self.mod[expr]
func = self.optimize(func)
self.mod._add(expr, func, True)
opt_expr = Call(expr, relay_args)
return self._intrp(opt_expr)
# Set the entry function for the module.
if expr is None:
pass
elif isinstance(expr, GlobalVar):
self.mod[self.mod.entry_func] = self.mod[expr]
else:
call = Call(expr, relay_args)
opt_expr = self.optimize(call)
return self._intrp(opt_expr)
assert isinstance(expr, Function)
func = Function([], Call(expr, relay_args))
relay_args = []
if self.mod:
self.mod[self.mod.entry_func] = func
else:
self.mod = module.Module.from_expr(func)

mod = self.optimize()
opt_expr = Call(mod[self.mod.entry_func.name_hint], relay_args)
return self._intrp(opt_expr)
return _interp_wrapper
8 changes: 5 additions & 3 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,11 @@ def __init__(self, mod, ctx, target):
self.ctx = ctx
self.target = target

def _make_executor(self, expr):
assert isinstance(expr, Expr)
self.mod[self.mod.entry_func] = expr
def _make_executor(self, expr=None):
expr = expr if expr else self.mod
assert expr, "either expr or self.mod should be not null."
if isinstance(expr, Expr):
self.mod[self.mod.entry_func] = expr
main = self.mod[self.mod.entry_func]

def _vm_wrapper(*args, **kwargs):
Expand Down
11 changes: 7 additions & 4 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,19 @@ def __init__(self, mod, ctx, target):
self.ctx = ctx
self.target = target

def _make_executor(self, func):
ret_type = ir_pass.infer_type(func).ret_type
def _make_executor(self, expr=None):
if not expr:
assert self.mod, "either expr or self.mod should be not null."
expr = self.mod[self.mod.entry_func]
ret_type = ir_pass.infer_type(expr).ret_type
num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
graph_json, mod, params = build(func, target=self.target)
graph_json, mod, params = build(expr, target=self.target)
gmodule = _graph_rt.create(graph_json, mod, self.ctx)
if params:
gmodule.set_input(**params)

def _graph_wrapper(*args, **kwargs):
args = self._convert_args(func, args, kwargs)
args = self._convert_args(expr, args, kwargs)
# Create map of inputs.
for i, arg in enumerate(args):
gmodule.set_input(i, arg)
Expand Down
14 changes: 9 additions & 5 deletions python/tvm/relay/frontend/caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm
from .. import ir_pass
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from ... import nd as _nd
from .common import AttrCvt, Renamer
Expand Down Expand Up @@ -382,6 +383,7 @@ def __init__(self, shape, dtype):
self._ops = {}
self._shape = shape
self._dtype = dtype
self._mod = _module.Module({})

def from_caffe2(self, init_net, predict_net):
"""Construct Relay expression from caffe2 graph.
Expand All @@ -393,8 +395,9 @@ def from_caffe2(self, init_net, predict_net):
Returns
-------
func : tvm.relay.expr.Function
Compatible relay function
mod : tvm.relay.Module
The module that optimizations will be performed on.
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
Expand Down Expand Up @@ -448,8 +451,9 @@ def from_caffe2(self, init_net, predict_net):
outputs = out[0]

func = _expr.Function(ir_pass.free_vars(outputs), outputs)
self._mod[self._mod.entry_func] = func

return func, self._params
return self._mod, self._params

def _get_node(self, blob):
"""Get the Symbol of blob and detect cyclic dependency in the graph."""
Expand Down Expand Up @@ -560,8 +564,8 @@ def from_caffe2(init_net, predict_net, shape=None, dtype="float32"):
Returns
-------
sym : tvm.relay.expr.Function
Compatible relay function
mod : tvm.relay.Module
The module that optimizations will be performed on.
params : dict of str to tvm.ndarray
Dict of converted parameters stored in tvm.ndarray format
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/relay/frontend/coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tvm
from .. import ir_pass
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from ... import nd as _nd
from ..._ffi import base as _base
Expand Down Expand Up @@ -416,8 +417,8 @@ def from_coreml(model, shape=None):
Returns
-------
func : tvm.relay.Function
Compatible relay Function.
mod : tvm.relay.Module
The relay module for compilation.
params : dict of str to tvm.NDArray
The parameter dict to be used by Relay.
Expand Down Expand Up @@ -463,4 +464,4 @@ def from_coreml(model, shape=None):
outexpr = outexpr[0]
func = _expr.Function(ir_pass.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return func, params
return _module.Module.from_expr(func), params
8 changes: 5 additions & 3 deletions python/tvm/relay/frontend/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import tvm
from .. import ir_pass
from .. import expr as _expr
from .. import module as _module
from .common import get_relay_op, new_var

__all__ = ['from_darknet']
Expand Down Expand Up @@ -820,7 +821,7 @@ def from_darknet(self):
outputs = _as_list(sym) + self._outs
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
sym = _expr.Function(ir_pass.free_vars(outputs), outputs)
return sym, self._tvmparams
return _module.Module.from_expr(sym), self._tvmparams

def from_darknet(net,
shape=None,
Expand All @@ -838,8 +839,9 @@ def from_darknet(net,
Returns
-------
sym : tvm.relay.Function
Compatible relay Function
mod : tvm.relay.Module
The relay module for compilation.
params : dict of str to tvm.NDArray
The parameter dict to be used by relay
"""
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tvm
from .. import ir_pass
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable, new_var
Expand Down Expand Up @@ -679,8 +680,8 @@ def from_keras(model, shape=None):
Returns
-------
func : tvm.relay.Function
Compatible relay Function.
mod : tvm.relay.Module
The relay module for compilation.
params : dict of str to tvm.NDArray
The parameter dict to be used by Relay.
Expand Down Expand Up @@ -744,4 +745,4 @@ def _convert_input_layer(keras_layer):
outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
func = _expr.Function(ir_pass.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return func, params
return _module.Module.from_expr(func), params
20 changes: 14 additions & 6 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .. import ir_pass
from .. import expr as _expr
from .. import op as _op
from .. import module as _module
from ... import nd as _nd

from .common import StrAttrsDict
Expand Down Expand Up @@ -992,7 +993,8 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
_convert_map.update({k : _rename(k) for k in _identity_list})


def _from_mxnet_impl(symbol, shape_dict, dtype_info):
def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None):
#pylint: disable=unused-argument
"""Convert mxnet symbol to compatible relay Function.
Reconstruct a relay Function by traversing the mxnet symbol.
Expand All @@ -1009,6 +1011,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
dtype_info : dict or str.
Known parameter dtypes
mod : tvm.relay.Module
The module that contains global information. It will be used for
converting ops that need global information, e.g. control-flow ops.
Returns:
-------
func : tvm.relay.Function
Expand Down Expand Up @@ -1097,8 +1103,8 @@ def from_mxnet(symbol,
Returns
-------
sym : tvm.relay.Function
Compatible relay Function
mod : tvm.relay.Module
The relay module for compilation
params : dict of str to tvm.NDArray
The parameter dict to be used by nnvm
Expand All @@ -1108,6 +1114,7 @@ def from_mxnet(symbol,
except ImportError as e:
raise ImportError("{}. MXNet is required to parse symbols.".format(e))

mod = _module.Module()
if isinstance(symbol, mx.sym.Symbol):
params = {}
arg_params = arg_params if arg_params else {}
Expand All @@ -1117,7 +1124,7 @@ def from_mxnet(symbol,
for k, v in aux_params.items():
params[k] = _nd.array(v.asnumpy())
shape, dtype = _update_shape_dtype(shape, dtype, params)
sym = _from_mxnet_impl(symbol, shape, dtype)
func = _from_mxnet_impl(symbol, shape, dtype, mod)
elif isinstance(symbol, mx.gluon.HybridBlock):
if arg_params is not None or aux_params is not None:
raise ValueError("arg_params and aux_params ae not used when importing HybridBlock")
Expand All @@ -1129,10 +1136,11 @@ def from_mxnet(symbol,
if isinstance(sym, (list, tuple)):
sym = mx.sym.Group(sym)
shape, dtype = _update_shape_dtype(shape, dtype, params)
sym = _from_mxnet_impl(sym, shape, dtype)
func = _from_mxnet_impl(sym, shape, dtype, mod)
elif isinstance(symbol, mx.gluon.Block):
raise NotImplementedError("Only Hybrid Blocks are supported now.")
else:
msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol))
raise ValueError(msg)
return sym, params
mod[mod.entry_func] = func
return mod, params
Loading

0 comments on commit fa35104

Please sign in to comment.