Improve NNVM to Relay conversion (apache#2734)
* Improve NNVM to Relay conversion

* fix pylint

* support __lshift_scalar__, abs, ceil, floor, and trunc to pass CI
kazum authored and wweic committed Mar 9, 2019
1 parent 6897580 commit 2db78d4
Showing 4 changed files with 145 additions and 161 deletions.
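For context, the conversion path this commit exercises can be run end to end roughly as follows. This is a minimal sketch against the TVM 0.5-era APIs used in the diffs below; the toy dense workload, shapes, and 'llvm' target are illustrative assumptions, not part of the commit.

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
import nnvm
from nnvm.to_relay import to_relay

# Assumed toy workload: one NNVM dense layer without bias.
x = nnvm.symbol.Variable('x')
w = nnvm.symbol.Variable('w')
sym = nnvm.symbol.dense(data=x, weight=w, units=4, use_bias=False)
graph = nnvm.graph.create(sym)

shape = {'x': (1, 8), 'w': (4, 8)}
dtype = {'x': 'float32', 'w': 'float32'}
input_values = {'x': np.ones((1, 8), dtype='float32'),
                'w': np.random.uniform(size=(4, 8)).astype('float32')}

# Convert the NNVM graph to a Relay function, then build and run it,
# mirroring the cross-check this commit adds to check_computation.py.
func, input_values = to_relay(graph, shape, dtype, params=input_values)
with relay.build_config(opt_level=3):
    graph_json, lib, params = relay.build(func, target='llvm')
m = graph_runtime.create(graph_json, lib, tvm.cpu(0))
m.set_input(**input_values)
m.set_input(**params)
m.run()
print(m.get_output(0).asnumpy())
```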
19 changes: 19 additions & 0 deletions nnvm/python/nnvm/testing/check_computation.py
@@ -8,10 +8,12 @@
import tvm
from tvm.contrib import graph_runtime
from tvm.testing import check_numerical_grads
from tvm import relay

import nnvm
from nnvm.compiler import graph_util
from nnvm.compiler.graph_attr import TCODE_TO_DTYPE, DTYPE_TO_TCODE
from nnvm.to_relay import to_relay
from .config import ctx_list

def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None):
@@ -441,6 +443,23 @@ def check_function(symbol, forward=None, backward=None, grad_input_vars=None,
    debug_stage = "running"
    nnvm_res = main_function(**np_inputs)

    try:
        logging.debug("checking to_relay conversion")
        inputs = np_inputs_without_head_grads.copy()
        func, inputs = to_relay(main_graph, shape, dtype, params=inputs)
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build(func, target=target)
        m = graph_runtime.create(graph, lib, ctx)
        m.set_input(**inputs)
        m.set_input(**params)
        m.run()
        for i in range(out_len):
            relay_out = m.get_output(i).asnumpy()
            tvm.testing.assert_allclose(nnvm_res[i], relay_out, atol=atol, rtol=rtol)
    except NotImplementedError as err:
        # the NNVM operator is not supported yet
        logging.warning(err)

    if backward_graph is not None:
        grad_var_names = [x.attr('name') for x in grad_input_vars]
        nnvm_grads = {x: v for x, v in zip(grad_var_names, nnvm_res[out_len:])}
258 changes: 107 additions & 151 deletions nnvm/python/nnvm/to_relay.py
@@ -6,7 +6,8 @@
from tvm import relay, nd
from tvm.relay import op, expr, var
from tvm.relay.frontend.common import StrAttrsDict
from tvm.relay.frontend.nnvm_common import _rename
from tvm.relay.frontend.nnvm_common import _rename, _binop_scalar, _rbinop_scalar, \
    _elemwise_sum, _softmax_op, _compare, _reduce
from .symbol import Symbol
from .compiler import graph_attr
from .graph import create as graph_create
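The helpers newly imported from nnvm_common above replace the hand-written per-op converters deleted below. Their real bodies live in tvm.relay.frontend.nnvm_common; the following is a rough sketch of three of them, inferred from the deleted converters they replace, so treat the bodies as assumptions rather than library source.

```python
from tvm import relay

def _binop_scalar(new_op):
    # __op_scalar__ variants: a single child plus a 'scalar' attribute.
    def _impl(children, attrs, odtype='float32'):
        scalar = relay.const(attrs.get_float('scalar'), dtype=odtype)
        return new_op(children[0], scalar)
    return _impl

def _rbinop_scalar(new_op):
    # __rop_scalar__ variants: same, with the operand order reversed.
    def _impl(children, attrs, odtype='float32'):
        scalar = relay.const(attrs.get_float('scalar'), dtype=odtype)
        return new_op(scalar, children[0])
    return _impl

def _compare(new_op):
    # Comparisons: cast the boolean result to 'out_type' when present.
    def _impl(children, attrs, odtype='float32'):
        ret = new_op(children[0], children[1])
        out_type = attrs.get_str('out_type', '')
        return ret.astype(out_type) if out_type else ret
    return _impl
```

Centralizing these shapes removes the near-duplicate per-op functions deleted in the hunks below.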
@@ -25,11 +26,6 @@ def _dense(children, attrs, odtype='float32'):
    else:
        return dense

def _nn_softmax(children, attrs, odtype='float32'):
    assert len(children) == 1
    axis = attrs.get_int('axis', 1)
    return op.nn.softmax(children[0], axis)

def _conv2d(children, attrs, odtype='float32'):
    use_bias = attrs.get_bool('use_bias', True)

@@ -150,84 +146,6 @@ def _transpose(children, attrs, odtype='float32'):
    return op.transpose(children[0], axes=axes)


def _add(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype=odtype)
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.add(left, right)


def _subtract(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype=odtype)
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.subtract(left, right)


def _rsubtract(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype=odtype)
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.subtract(right, left)


def _multiply(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype=odtype)
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.multiply(left, right)


def _divide(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype=odtype)
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.divide(left, right)


def _rshift(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype='int32')
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.right_shift(left, right)


def _clip(children, attrs, odtype='float32'):
    a_min = attrs.get_float('a_min')
    a_max = attrs.get_float('a_max')
@@ -255,9 +173,6 @@ def broadcast_to(children, attrs, odtype='float32'):
    rconst = relay.Constant(nd.array(array))
    return op.broadcast_to_like(data, rconst)

def _copy(children, attrs, odtype='float32'):
    return op.copy(children[0])


def _global_avg_pool2d(children, attrs, odtype='float32'):
    data = children[0]
@@ -309,42 +224,10 @@ def _full_like(children, attrs, odtype='float32'):
    return op.full_like(children[0], fill_value)


def _greater(children, attrs, odtype='float32'):
    out_type = attrs.get_str('out_type')
    if out_type:
        return op.greater(children[0], children[1]).astype(out_type)
    else:
        return op.greater(children[0], children[1])


def _greater_equal(children, attrs, odtype='float32'):
    out_type = attrs.get_str('out_type', None)
    if out_type:
        return op.greater_equal(children[0], children[1]).astype(out_type)
    else:
        return op.greater_equal(children[0], children[1])


def _less(children, attrs, odtype='float32'):
    out_type = attrs.get_str('out_type', None)
    if out_type:
        return op.less(children[0], children[1]).astype(out_type)
    else:
        return op.less(children[0], children[1])


def _less_equal(children, attrs, odtype='float32'):
    out_type = attrs.get_str('out_type', None)
    if out_type:
        return op.less_equal(children[0], children[1]).astype(out_type)
    else:
        return op.less_equal(children[0], children[1])


def _strided_slice(children, attrs, odtype='float32'):
    begin = attrs.get_int_list('begin')
    end = attrs.get_int_list('end')
    strides = attrs.get_int_list('strides', None)
    strides = attrs.get_int_list('stride', None)
    return op.strided_slice(children[0], begin, end, strides=strides)


@@ -358,14 +241,11 @@ def _split(children, attrs, odtype='float32'):

    axis = attrs.get_int('axis', 0)

    return op.split(children[0], indices_or_sections, axis)
    return op.split(children[0], indices_or_sections, axis).astuple()

def _squeeze(children, attrs, odtype='float32'):
    axis = None
    try:
        axis = [attrs.get_int('axis', None)]
    except ValueError:
        axis = axis or attrs.get_int_tuple('axis', None)
    axis = attrs.get_int_tuple('axis', None)
    axis = [axis] if isinstance(axis, int) else axis

    return op.squeeze(children[0], axis)

@@ -378,20 +258,60 @@ def _dropout(children, attrs, odtype='float32'):
    return op.nn.dropout(children[0], rate)

def _mean(children, attrs, odtype='float32'):
    axis = None
    try:
        axis = [attrs.get_int('axis', None)]
    except ValueError:
        axis = axis or attrs.get_int_tuple('axis', None)
    axis = attrs.get_int_tuple('axis', None)
    keepdims = attrs.get_bool('keepdims')

    return op.mean(children[0], axis, keepdims)


def _prelu(children, attrs, odtype='float32'):
    axis = attrs.get_int('axis', 1)
    return op.nn.prelu(children[0], children[1], axis)


def _lrn(children, attrs, odtype='float32'):
    size = attrs.get_int("size", 5)
    axis = attrs.get_int("axis", 1)
    bias = attrs.get_float("bias", 2)
    alpha = attrs.get_float("alpha", 1e-05)
    beta = attrs.get_float("beta", 0.75)
    return op.nn.lrn(children[0], size, axis, bias, alpha, beta)


def _l2_normalize(children, attrs, odtype='float32'):
    eps = attrs.get_float('eps')
    axis = attrs.get_int_tuple('axis', None)
    return op.nn.l2_normalize(children[0], eps, axis)


def _take(children, attrs, odtype='float32'):
    axis = attrs.get_int('axis', None)
    return op.take(children[0], children[1], axis)


def _matmul(children, attrs, odtype='float32'):
    input_1_t = op.transpose(children[1], axes=(1, 0))
    return op.nn.dense(children[0], input_1_t)


def _collapse_sum(children, attrs, odtype='float32'):
    for key in ["axis", "keepdims", "exclude"]:
        if key in attrs.attrs:
            raise NotImplementedError("Parameter '" + key + "' is not supported.")
    return op.collapse_sum_like(children[0], children[1])


def _not_implemented(new_op):
    def _impl(children, attrs, odtype='float32'):
        raise NotImplementedError(str(new_op) + " is not implemented.")
    return _impl
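These `_not_implemented` entries dovetail with the new try/except added to check_computation.py above: when a converter raises NotImplementedError, the Relay cross-check is skipped with a warning instead of failing the whole test. A hypothetical illustration (the empty children and attrs are placeholders; the converter raises before inspecting them):

```python
import logging
from tvm.relay.frontend.common import StrAttrsDict

convert = _not_implemented("gather_nd")
try:
    convert(children=[], attrs=StrAttrsDict({}))
except NotImplementedError as err:
    logging.warning(err)  # logs "gather_nd is not implemented."; the run continues
```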


NNVM_OP_2_RELAY_OP = {
    'flatten': _nn_batch_flatten,
    'dense': _dense,
    'softmax': _nn_softmax,
    'softmax': _softmax_op(op.nn.softmax),
    'log_softmax': _softmax_op(op.nn.log_softmax),
    'conv2d': _conv2d,
    'batch_norm': _batch_norm,
    'max_pool2d': _max_pool2d,
Expand All @@ -400,30 +320,47 @@ def _mean(children, attrs, odtype='float32'):
'dropout': _dropout,
'mean': _mean,
# Addition
'__add_scalar__': _add,
'broadcast_add': _add,
'elemwise_add': _add,
'__add_scalar__': _binop_scalar(op.add),
'broadcast_add' : _rename(op.add),
'elemwise_add' : _rename(op.add),
# Subtraction
'__sub_scalar__': _subtract,
'__rsub_scalar__': _rsubtract,
'broadcast_sub': _subtract,
'elemwise_sub': _subtract,
'__sub_scalar__' : _binop_scalar(op.subtract),
'__rsub_scalar__': _rbinop_scalar(op.subtract),
'broadcast_sub' : _rename(op.subtract),
'elemwise_sub' : _rename(op.subtract),
# Multiply
'__mul_scalar__': _multiply,
'broadcast_mul': _multiply,
'elemwise_mul': _multiply,
'__mul_scalar__': _binop_scalar(op.multiply),
'broadcast_mul' : _rename(op.multiply),
'elemwise_mul' : _rename(op.multiply),
# Division
'__div_scalar__': _divide,
'broadcast_div': _divide,
'elemwise_div': _divide,
'__div_scalar__': _binop_scalar(op.divide),
'broadcast_div' : _rename(op.divide),
'elemwise_div' : _rename(op.divide),
'broadcast_mod' : _rename(op.mod),
# Negative
'negative': _rename("negative"),
# Power
'__pow_scalar__': _binop_scalar(op.power),
'__rpow_scalar__': _rbinop_scalar(op.power),
'broadcast_pow': _rename(op.power),
# Sum
'sum': _reduce(op.sum),
'elemwise_sum': _elemwise_sum,
'collapse_sum': _collapse_sum,
'broadcast_max': _rename(op.maximum),
'broadcast_min': _rename(op.minimum),

# Comparsion
'greater': _greater,
'greater_equal': _greater_equal,
'less': _less,
'less_equal': _less_equal,
'greater': _compare(op.greater),
'broadcast_greater': _compare(op.greater),
'greater_equal': _compare(op.greater_equal),
'broadcast_greater_equal': _compare(op.greater_equal),
'less': _compare(op.less),
'broadcast_less': _compare(op.less),
'less_equal': _compare(op.less_equal),
'broadcast_less_equal': _compare(op.less_equal),
'broadcast_equal': _compare(op.equal),
'broadcast_not_equal': _compare(op.not_equal),

# Activations
'sigmoid': _rename('sigmoid'),
@@ -432,13 +369,17 @@ def _mean(children, attrs, odtype='float32'):
    'log': _rename('log'),
    'tanh': _rename('tanh'),
    'leaky_relu': _leaky_relu,
    'prelu': _prelu,
    'clip': _clip,
    'round': _rename('round'),
    'cast': _cast,
    'expand_dims': _expand_dims,
    'broadcast_to': broadcast_to,
    '__rshift_scalar__': _rshift,
    'copy': _copy,
    '__lshift_scalar__': _binop_scalar(op.left_shift),
    '__rshift_scalar__': _binop_scalar(op.right_shift),
    'broadcast_left_shift': _rename(op.left_shift),
    'broadcast_right_shift': _rename(op.right_shift),
    'copy': _rename(op.copy),
    'global_avg_pool2d': _global_avg_pool2d,
    'avg_pool2d': _avg_pool2d,
    'conv2d_transpose': _conv2d_transpose,
Expand All @@ -449,6 +390,21 @@ def _mean(children, attrs, odtype='float32'):
'split': _split,
'squeeze': _squeeze,
'concatenate': _concatenate,
'abs': _rename(op.abs),
'ceil': _rename(op.ceil),
'floor': _rename(op.floor),
'trunc': _rename(op.trunc),
'take': _take,
'lrn': _lrn,
'l2_normalize': _l2_nomalize,
'matmul': _matmul,
'zeros_like': _rename(op.zeros_like),
'reshape_like': _rename(op.reshape_like),
'ones_like': _rename(op.ones_like),

'expand_like': _not_implemented("expand_like"),
'gather_nd': _not_implemented("gather_nd"),
'block_grad': _not_implemented("block_grad"),
}
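For completeness: the dispatch through this table happens elsewhere in to_relay.py and is not shown in the truncated diff. The sketch below is a hypothetical reconstruction; `_convert_op` is an invented name, and only the `(children, attrs, odtype)` call shape is taken from the converter signatures above.

```python
def _convert_op(op_name, children, attrs, odtype='float32'):
    # Look the NNVM op up in the table and apply its converter.
    # Ops mapped to _not_implemented, or missing entirely, surface as
    # NotImplementedError, which check_computation.py catches and logs.
    if op_name not in NNVM_OP_2_RELAY_OP:
        raise NotImplementedError(op_name + ' is not supported.')
    return NNVM_OP_2_RELAY_OP[op_name](children, attrs, odtype)
```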


(Diff truncated here; the remaining two changed files are not shown.)
