diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 17cb148ec999..dabe55f72d1b 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -27,12 +27,29 @@ from .. import function as _function from .. import op as _op from .. import vision as _vision + +from ..function import Function +from ..expr import Call, Let +from ..expr import If, Tuple, TupleGetItem +from ..expr import RefCreate, RefRead, RefWrite +from ..expr_functor import ExprFunctor +from ..adt import Match, Clause + from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels -from .common import infer_type, infer_value, infer_value_simulated, get_name +from .common import infer_type, get_name +from .common import infer_value as _infer_value +from .common import infer_value_simulated as _infer_value_simulated __all__ = ['from_onnx'] +g = None + +def infer_value(input_val, params, mod=None): + return g.infer_value(input_val, params, mod) + +def infer_value_simulated(input_val, params): + return g.infer_value_simulated(input_val, params) class onnx_input(): """ Dual purpose list or dictionary access object.""" @@ -1891,8 +1908,7 @@ def _get_convert_map(opset): 'NonZero': NonZero.get_converter(opset), } - -class GraphProto(object): +class GraphProto(ExprFunctor): """A helper class for handling Relay expression copying from pb2.GraphProto. Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto @@ -1914,6 +1930,101 @@ def __init__(self, shape, dtype): self._shape = shape if shape else {} self._dtype = dtype + #For infering Values + self._tmp_params = {} + self._infer_simulated = True + self._mod = None + super(GraphProto, self).__init__() + + def infer_value(self, input_val, params, mod=None): + self._tmp_params = params + self._infer_simulated = False + self._mod = mod + return self.visit(input_val).data + #return _infer_value(input_val, params, mod) + + def infer_value_simulated(self, input_val, params): + self._tmp_params = params + self._infer_simulated = True + return self.visit(input_val).data + #return _infer_value_simulated(input_val, params) + + def infer(self, expr): + if self._infer_simulated: + out = _infer_value_simulated(expr, self._tmp_params) + else: + out = _infer_value(expr, self._tmp_params) + return _expr.const(out.asnumpy()) + + def visit_function(self, fn): + new_params = [self.visit(x) for x in fn.params] + new_body = self.visit(fn.body) + return self.infer(Function( + list(new_params), + new_body, + fn.ret_type, + fn.type_params, + fn.attrs)) + + def visit_let(self, let): + newvar = self.visit(let.var) + newval = self.visit(let.value) + newbody = self.visit(let.body) + return self.infer(Let(newvar, newval, newbody)) + + def visit_call(self, call): + new_fn = self.visit(call.op) + new_args = [self.visit(arg) for arg in call.args] + return self.infer(Call(new_fn, new_args, call.attrs)) + + def visit_var(self, var): + return self.infer(var) + + def visit_global_id(self, global_var): + return self.infer(global_var) + + def visit_if(self, ite): + return self.infer(If( + self.visit(ite.cond), + self.visit(ite.true_branch), + self.visit(ite.false_branch))) + + def visit_tuple(self, tup): + return Tuple([self.visit(field) for field in tup.fields]) + + def visit_tuple_getitem(self, op): + tuple_value = self.visit(op.tuple_value) + if not tuple_value.same_as(op.tuple_value): + return self.infer(TupleGetItem(tuple_value, op.index)) + return self.infer(op) + + def visit_global_var(self, gvar): + return self.infer(gvar) + + def visit_op(self, op): + return op + + def visit_constant(self, const): + return const + + def visit_constructor(self, con): + return con + + def visit_match(self, m): + return self.infer(Match( + self.visit(m.data), + [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses], + complete=m.complete)) + + def visit_ref_create(self, r): + return RefCreate(self.visit(r.value)) + + def visit_ref_write(self, r): + return RefWrite(self.visit(r.ref), self.visit(r.value)) + + def visit_ref_read(self, r): + return RefRead(self.visit(r.ref)) + def from_onnx(self, graph, opset): """Construct Relay expression from ONNX graph. @@ -2172,6 +2283,7 @@ def from_onnx(model, warnings.warn(str(e)) except ImportError: pass + global g g = GraphProto(shape, dtype) graph = model.graph if opset is None: @@ -2180,4 +2292,5 @@ def from_onnx(model, except AttributeError: opset = 1 mod, params = g.from_onnx(graph, opset) + g = None return mod, params