From 2db78d4006563927fc48297570eb2f18bedb3a40 Mon Sep 17 00:00:00 2001 From: MORITA Kazutaka Date: Fri, 8 Mar 2019 14:35:46 +0900 Subject: [PATCH] Improve NNVM to Relay conversion (#2734) * Improve NNVM to Relay conversion * fix pylint * support __lshift_scalar__, abs, ceil, floor, and trunc to pass CI --- nnvm/python/nnvm/testing/check_computation.py | 19 ++ nnvm/python/nnvm/to_relay.py | 258 ++++++++---------- python/tvm/relay/frontend/nnvm_common.py | 25 +- src/relay/pass/type_solver.cc | 4 +- 4 files changed, 145 insertions(+), 161 deletions(-) diff --git a/nnvm/python/nnvm/testing/check_computation.py b/nnvm/python/nnvm/testing/check_computation.py index 7ab4dc0d4c6c..68419b73523b 100644 --- a/nnvm/python/nnvm/testing/check_computation.py +++ b/nnvm/python/nnvm/testing/check_computation.py @@ -8,10 +8,12 @@ import tvm from tvm.contrib import graph_runtime from tvm.testing import check_numerical_grads +from tvm import relay import nnvm from nnvm.compiler import graph_util from nnvm.compiler.graph_attr import TCODE_TO_DTYPE, DTYPE_TO_TCODE +from nnvm.to_relay import to_relay from .config import ctx_list def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None): @@ -441,6 +443,23 @@ def check_function(symbol, forward=None, backward=None, grad_input_vars=None, debug_stage = "running" nnvm_res = main_function(**np_inputs) + try: + logging.debug("checking to_relay conversion") + inputs = np_inputs_without_head_grads.copy() + func, inputs = to_relay(main_graph, shape, dtype, params=inputs) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(func, target=target) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**inputs) + m.set_input(**params) + m.run() + for i in range(out_len): + relay_out = m.get_output(i).asnumpy() + tvm.testing.assert_allclose(nnvm_res[i], relay_out, atol=atol, rtol=rtol) + except NotImplementedError as err: + # the NNVM operator is not supported yet + logging.warning(err) + if backward_graph is not None: grad_var_names = [x.attr('name') for x in grad_input_vars] nnvm_grads = {x: v for x, v in zip(grad_var_names, nnvm_res[out_len:])} diff --git a/nnvm/python/nnvm/to_relay.py b/nnvm/python/nnvm/to_relay.py index 264a18d90c77..7d792116b104 100644 --- a/nnvm/python/nnvm/to_relay.py +++ b/nnvm/python/nnvm/to_relay.py @@ -6,7 +6,8 @@ from tvm import relay, nd from tvm.relay import op, expr, var from tvm.relay.frontend.common import StrAttrsDict -from tvm.relay.frontend.nnvm_common import _rename +from tvm.relay.frontend.nnvm_common import _rename, _binop_scalar, _rbinop_scalar, \ + _elemwise_sum, _softmax_op, _compare, _reduce from .symbol import Symbol from .compiler import graph_attr from .graph import create as graph_create @@ -25,11 +26,6 @@ def _dense(children, attrs, odtype='float32'): else: return dense -def _nn_softmax(children, attrs, odtype='float32'): - assert len(children) == 1 - axis = attrs.get_int('axis', 1) - return op.nn.softmax(children[0], axis) - def _conv2d(children, attrs, odtype='float32'): use_bias = attrs.get_bool('use_bias', True) @@ -150,84 +146,6 @@ def _transpose(children, attrs, odtype='float32'): return op.transpose(children[0], axes=axes) -def _add(children, attrs, odtype='float32'): - if len(children) == 1: - left = children[0] - scalar = attrs.get_float('scalar') - right = relay.const(scalar, dtype=odtype) - else: - assert len(children) == 2 - left = children[0] - right = children[1] - - return op.add(left, right) - - -def _subtract(children, attrs, odtype='float32'): - if len(children) == 1: - left = children[0] - scalar = attrs.get_float('scalar') - right = relay.const(scalar, dtype=odtype) - else: - assert len(children) == 2 - left = children[0] - right = children[1] - - return op.subtract(left, right) - - -def _rsubtract(children, attrs, odtype='float32'): - if len(children) == 1: - left = children[0] - scalar = attrs.get_float('scalar') - right = relay.const(scalar, dtype=odtype) - else: - assert len(children) == 2 - left = children[0] - right = children[1] - - return op.subtract(right, left) - - -def _multiply(children, attrs, odtype='float32'): - if len(children) == 1: - left = children[0] - scalar = attrs.get_float('scalar') - right = relay.const(scalar, dtype=odtype) - else: - assert len(children) == 2 - left = children[0] - right = children[1] - - return op.multiply(left, right) - - -def _divide(children, attrs, odtype='float32'): - if len(children) == 1: - left = children[0] - scalar = attrs.get_float('scalar') - right = relay.const(scalar, dtype=odtype) - else: - assert len(children) == 2 - left = children[0] - right = children[1] - - return op.divide(left, right) - - -def _rshift(children, attrs, odtype='float32'): - if len(children) == 1: - left = children[0] - scalar = attrs.get_float('scalar') - right = relay.const(scalar, dtype='int32') - else: - assert len(children) == 2 - left = children[0] - right = children[1] - - return op.right_shift(left, right) - - def _clip(children, attrs, odtype='float32'): a_min = attrs.get_float('a_min') a_max = attrs.get_float('a_max') @@ -255,9 +173,6 @@ def broadcast_to(children, attrs, odtype='float32'): rconst = relay.Constant(nd.array(array)) return op.broadcast_to_like(data, rconst) -def _copy(children, attrs, odtype='float32'): - return op.copy(children[0]) - def _global_avg_pool2d(children, attrs, odtype='float32'): data = children[0] @@ -309,42 +224,10 @@ def _full_like(children, attrs, odtype='float32'): return op.full_like(children[0], fill_value) -def _greater(children, attrs, odtype='float32'): - out_type = attrs.get_str('out_type') - if out_type: - return op.greater(children[0], children[1]).astype(out_type) - else: - return op.greater(children[0], children[1]) - - -def _greater_equal(children, attrs, odtype='float32'): - out_type = attrs.get_str('out_type', None) - if out_type: - return op.greater_equal(children[0], children[1]).astype(out_type) - else: - return op.greater_equal(children[0], children[1]) - - -def _less(children, attrs, odtype='float32'): - out_type = attrs.get_str('out_type', None) - if out_type: - return op.less(children[0], children[1]).astype(out_type) - else: - return op.less(children[0], children[1]) - - -def _less_equal(children, attrs, odtype='float32'): - out_type = attrs.get_str('out_type', None) - if out_type: - return op.less_equal(children[0], children[1]).astype(out_type) - else: - return op.less_equal(children[0], children[1]) - - def _strided_slice(children, attrs, odtype='float32'): begin = attrs.get_int_list('begin') end = attrs.get_int_list('end') - strides = attrs.get_int_list('strides', None) + strides = attrs.get_int_list('stride', None) return op.strided_slice(children[0], begin, end, strides=strides) @@ -358,14 +241,11 @@ def _split(children, attrs, odtype='float32'): axis = attrs.get_int('axis', 0) - return op.split(children[0], indices_or_sections, axis) + return op.split(children[0], indices_or_sections, axis).astuple() def _squeeze(children, attrs, odtype='float32'): - axis = None - try: - axis = [attrs.get_int('axis', None)] - except ValueError: - axis = axis or attrs.get_int_tuple('axis', None) + axis = attrs.get_int_tuple('axis', None) + axis = [axis] if isinstance(axis, int) else axis return op.squeeze(children[0], axis) @@ -378,20 +258,60 @@ def _dropout(children, attrs, odtype='float32'): return op.nn.dropout(children[0], rate) def _mean(children, attrs, odtype='float32'): - axis = None - try: - axis = [attrs.get_int('axis', None)] - except ValueError: - axis = axis or attrs.get_int_tuple('axis', None) + axis = attrs.get_int_tuple('axis', None) keepdims = attrs.get_bool('keepdims') return op.mean(children[0], axis, keepdims) +def _prelu(children, attrs, odtype='float32'): + axis = attrs.get_int('axis', 1) + return op.nn.prelu(children[0], children[1], axis) + + +def _lrn(children, attrs, odtype='float32'): + size = attrs.get_int("size", 5) + axis = attrs.get_int("axis", 1) + bias = attrs.get_float("bias", 2) + alpha = attrs.get_float("alpha", 1e-05) + beta = attrs.get_float("beta", 0.75) + return op.nn.lrn(children[0], size, axis, bias, alpha, beta) + + +def _l2_nomalize(children, attrs, odtype='float32'): + eps = attrs.get_float('eps') + axis = attrs.get_int_tuple('axis', None) + return op.nn.l2_normalize(children[0], eps, axis) + + +def _take(children, attrs, odtype='float32'): + axis = attrs.get_int('axis', None) + return op.take(children[0], children[1], axis) + + +def _matmul(children, attrs, odtype='float32'): + input_1_t = op.transpose(children[1], axes=(1, 0)) + return op.nn.dense(children[0], input_1_t) + + +def _collapse_sum(children, attrs, odtype='float32'): + for key in ["axis", "keepdims", "exclude"]: + if key in attrs.attrs: + raise NotImplementedError("Parameter '" + key + "' is not supported.") + return op.collapse_sum_like(children[0], children[1]) + + +def _not_implemented(new_op): + def _impl(children, attrs, odtype='float32'): + raise NotImplementedError(str(new_op) + " is not implemented.") + return _impl + + NNVM_OP_2_RELAY_OP = { 'flatten': _nn_batch_flatten, 'dense': _dense, - 'softmax': _nn_softmax, + 'softmax': _softmax_op(op.nn.softmax), + 'log_softmax': _softmax_op(op.nn.log_softmax), 'conv2d': _conv2d, 'batch_norm': _batch_norm, 'max_pool2d': _max_pool2d, @@ -400,30 +320,47 @@ def _mean(children, attrs, odtype='float32'): 'dropout': _dropout, 'mean': _mean, # Addition - '__add_scalar__': _add, - 'broadcast_add': _add, - 'elemwise_add': _add, + '__add_scalar__': _binop_scalar(op.add), + 'broadcast_add' : _rename(op.add), + 'elemwise_add' : _rename(op.add), # Subtraction - '__sub_scalar__': _subtract, - '__rsub_scalar__': _rsubtract, - 'broadcast_sub': _subtract, - 'elemwise_sub': _subtract, + '__sub_scalar__' : _binop_scalar(op.subtract), + '__rsub_scalar__': _rbinop_scalar(op.subtract), + 'broadcast_sub' : _rename(op.subtract), + 'elemwise_sub' : _rename(op.subtract), # Multiply - '__mul_scalar__': _multiply, - 'broadcast_mul': _multiply, - 'elemwise_mul': _multiply, + '__mul_scalar__': _binop_scalar(op.multiply), + 'broadcast_mul' : _rename(op.multiply), + 'elemwise_mul' : _rename(op.multiply), # Division - '__div_scalar__': _divide, - 'broadcast_div': _divide, - 'elemwise_div': _divide, + '__div_scalar__': _binop_scalar(op.divide), + 'broadcast_div' : _rename(op.divide), + 'elemwise_div' : _rename(op.divide), + 'broadcast_mod' : _rename(op.mod), # Negative 'negative': _rename("negative"), + # Power + '__pow_scalar__': _binop_scalar(op.power), + '__rpow_scalar__': _rbinop_scalar(op.power), + 'broadcast_pow': _rename(op.power), + # Sum + 'sum': _reduce(op.sum), + 'elemwise_sum': _elemwise_sum, + 'collapse_sum': _collapse_sum, + 'broadcast_max': _rename(op.maximum), + 'broadcast_min': _rename(op.minimum), # Comparsion - 'greater': _greater, - 'greater_equal': _greater_equal, - 'less': _less, - 'less_equal': _less_equal, + 'greater': _compare(op.greater), + 'broadcast_greater': _compare(op.greater), + 'greater_equal': _compare(op.greater_equal), + 'broadcast_greater_equal': _compare(op.greater_equal), + 'less': _compare(op.less), + 'broadcast_less': _compare(op.less), + 'less_equal': _compare(op.less_equal), + 'broadcast_less_equal': _compare(op.less_equal), + 'broadcast_equal': _compare(op.equal), + 'broadcast_not_equal': _compare(op.not_equal), # Activations 'sigmoid': _rename('sigmoid'), @@ -432,13 +369,17 @@ def _mean(children, attrs, odtype='float32'): 'log': _rename('log'), 'tanh': _rename('tanh'), 'leaky_relu': _leaky_relu, + 'prelu': _prelu, 'clip': _clip, 'round': _rename('round'), 'cast': _cast, 'expand_dims': _expand_dims, 'broadcast_to': broadcast_to, - '__rshift_scalar__': _rshift, - 'copy': _copy, + '__lshift_scalar__': _binop_scalar(op.left_shift), + '__rshift_scalar__': _binop_scalar(op.right_shift), + 'broadcast_left_shift': _rename(op.left_shift), + 'broadcast_right_shift': _rename(op.right_shift), + 'copy': _rename(op.copy), 'global_avg_pool2d': _global_avg_pool2d, 'avg_pool2d': _avg_pool2d, 'conv2d_transpose': _conv2d_transpose, @@ -449,6 +390,21 @@ def _mean(children, attrs, odtype='float32'): 'split': _split, 'squeeze': _squeeze, 'concatenate': _concatenate, + 'abs': _rename(op.abs), + 'ceil': _rename(op.ceil), + 'floor': _rename(op.floor), + 'trunc': _rename(op.trunc), + 'take': _take, + 'lrn': _lrn, + 'l2_normalize': _l2_nomalize, + 'matmul': _matmul, + 'zeros_like': _rename(op.zeros_like), + 'reshape_like': _rename(op.reshape_like), + 'ones_like': _rename(op.ones_like), + + 'expand_like': _not_implemented("expand_like"), + 'gather_nd': _not_implemented("gather_nd"), + 'block_grad': _not_implemented("block_grad"), } diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py index 3838c3d4aa3b..7fd6f409cfd3 100644 --- a/python/tvm/relay/frontend/nnvm_common.py +++ b/python/tvm/relay/frontend/nnvm_common.py @@ -41,7 +41,7 @@ def _impl(inputs, attrs): def _softmax_op(new_op): """softmax/log_softmax""" - def _impl(inputs, attrs): + def _impl(inputs, attrs, _dtype='float32'): assert len(inputs) == 1 axis = attrs.get_int("axis", -1) return new_op(inputs[0], axis=axis) @@ -50,13 +50,14 @@ def _impl(inputs, attrs): def _reduce(new_op): """Reduction ops like sum/min/max""" - def _impl(inputs, attrs): + def _impl(inputs, attrs, _dtype='float32'): assert len(inputs) == 1 axis = attrs.get_int_tuple("axis", []) keepdims = attrs.get_bool("keepdims", False) + exclude = attrs.get_bool("exclude", False) # use None for reduce over all axis. axis = None if len(axis) == 0 else axis - return new_op(inputs[0], axis=axis, keepdims=keepdims) + return new_op(inputs[0], axis=axis, keepdims=keepdims, exclude=exclude) return _impl @@ -97,7 +98,7 @@ def _upsampling(inputs, attrs): return _op.nn.upsampling(inputs[0], scale=scale) -def _elemwise_sum(inputs, _): +def _elemwise_sum(inputs, _, _dtype='float32'): assert len(inputs) > 0 res = inputs[0] for x in inputs[1:]: @@ -106,20 +107,28 @@ def _elemwise_sum(inputs, _): def _binop_scalar(new_op): - def _impl(inputs, attrs): + def _impl(inputs, attrs, odtype='float32'): assert len(inputs) == 1 scalar = attrs.get_float("scalar") # Note: binary scalar only works for float op for now - scalar = _expr.const(scalar, dtype="float32") + scalar = _expr.const(scalar, dtype=odtype) return new_op(inputs[0], scalar) return _impl def _rbinop_scalar(new_op): - def _impl(inputs, attrs): + def _impl(inputs, attrs, odtype='float32'): assert len(inputs) == 1 scalar = attrs.get_float("scalar") # Note: binary scalar only works for float op for now - scalar = _expr.const(scalar, dtype="float32") + scalar = _expr.const(scalar, dtype=odtype) return new_op(scalar, inputs[0]) return _impl + + +def _compare(new_op): + """Compare ops like greater/less""" + def _impl(inputs, _, odtype='float32'): + assert len(inputs) == 2 + return new_op(inputs[0], inputs[1]).astype(odtype) + return _impl diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 179f90a2fe15..abbd82977499 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -476,8 +476,8 @@ bool TypeSolver::Solve() { rnode->resolved = false; this->ReportError( RELAY_ERROR( - "an internal invariant was violdated while" \ - "typechecking your program" << + "an internal invariant was violdated while " \ + "typechecking your program " << err.what()), rnode->location); }