Improve NNVM to Relay conversion #2734

Merged · 3 commits · Mar 8, 2019

Changes from 2 commits
19 changes: 19 additions & 0 deletions nnvm/python/nnvm/testing/check_computation.py
@@ -8,10 +8,12 @@
import tvm
from tvm.contrib import graph_runtime
from tvm.testing import check_numerical_grads
from tvm import relay

import nnvm
from nnvm.compiler import graph_util
from nnvm.compiler.graph_attr import TCODE_TO_DTYPE, DTYPE_TO_TCODE
from nnvm.to_relay import to_relay
from .config import ctx_list

def infer_shapes_dtypes(graph, shape=None, dtype=None, fallback_dtype=None):
@@ -441,6 +443,23 @@ def check_function(symbol, forward=None, backward=None, grad_input_vars=None,
debug_stage = "running"
nnvm_res = main_function(**np_inputs)

try:
    logging.debug("checking to_relay conversion")
Member: Will this print every time this code is executed?

Contributor Author: I think it depends on the logging configuration, but, with our CI, this message will be printed only when an error occurs.

Member: ok, seems ok

    inputs = np_inputs_without_head_grads.copy()
    func, inputs = to_relay(main_graph, shape, dtype, params=inputs)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(func, target=target)
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**inputs)
    m.set_input(**params)
    m.run()
    for i in range(out_len):
        relay_out = m.get_output(i).asnumpy()
        tvm.testing.assert_allclose(nnvm_res[i], relay_out, atol=atol, rtol=rtol)
except NotImplementedError as err:
    # the NNVM operator is not supported yet
    logging.warning(err)

if backward_graph is not None:
    grad_var_names = [x.attr('name') for x in grad_input_vars]
    nnvm_grads = {x: v for x, v in zip(grad_var_names, nnvm_res[out_len:])}
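A note on the logging exchange above: whether the debug message appears is governed by the logging level configured in the test process. A minimal sketch of that behavior (the WARNING-level configuration is an assumption for illustration, not taken from the actual CI setup):

import logging

# Assumed CI-like configuration: only WARNING and above are emitted.
logging.basicConfig(level=logging.WARNING)

logging.debug("checking to_relay conversion")      # suppressed at WARNING level
logging.warning("operator is not supported yet")   # printed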
253 changes: 102 additions & 151 deletions nnvm/python/nnvm/to_relay.py
@@ -6,7 +6,8 @@
from tvm import relay, nd
from tvm.relay import op, expr, var
from tvm.relay.frontend.common import StrAttrsDict
from tvm.relay.frontend.nnvm_common import _rename
from tvm.relay.frontend.nnvm_common import _rename, _binop_scalar, _rbinop_scalar, \
    _elemwise_sum, _softmax_op, _compare, _reduce
from .symbol import Symbol
from .compiler import graph_attr
from .graph import create as graph_create
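The hunks below drop the hand-written per-op helpers in favor of the generic factories imported here from tvm.relay.frontend.nnvm_common (_softmax_op and _reduce follow the same pattern). A rough sketch of the three used most often, inferred from the removed helpers rather than copied from the real implementations, which may differ in detail:

from tvm import relay
from tvm.relay import op

def _binop_scalar(new_op):
    # One child plus a 'scalar' attribute: wrap the scalar in a Relay
    # constant of the output dtype, then apply the binary op.
    def _impl(children, attrs, odtype='float32'):
        scalar = relay.const(attrs.get_float('scalar'), dtype=odtype)
        return new_op(children[0], scalar)
    return _impl

def _rbinop_scalar(new_op):
    # Reversed operand order, e.g. __rsub_scalar__ computes scalar - x.
    def _impl(children, attrs, odtype='float32'):
        scalar = relay.const(attrs.get_float('scalar'), dtype=odtype)
        return new_op(scalar, children[0])
    return _impl

def _compare(new_op):
    # Comparisons: apply the op, casting when an 'out_type' attribute is set.
    def _impl(children, attrs, odtype='float32'):
        out_type = attrs.get_str('out_type', '')
        ret = new_op(children[0], children[1])
        return ret.astype(out_type) if out_type else ret
    return _impl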
@@ -25,11 +26,6 @@ def _dense(children, attrs, odtype='float32'):
    else:
        return dense

def _nn_softmax(children, attrs, odtype='float32'):
    assert len(children) == 1
    axis = attrs.get_int('axis', 1)
    return op.nn.softmax(children[0], axis)

def _conv2d(children, attrs, odtype='float32'):
    use_bias = attrs.get_bool('use_bias', True)

@@ -150,84 +146,6 @@ def _transpose(children, attrs, odtype='float32'):
    return op.transpose(children[0], axes=axes)


def _add(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype=odtype)
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.add(left, right)


def _subtract(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype=odtype)
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.subtract(left, right)


def _rsubtract(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype=odtype)
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.subtract(right, left)


def _multiply(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype=odtype)
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.multiply(left, right)


def _divide(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype=odtype)
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.divide(left, right)


def _rshift(children, attrs, odtype='float32'):
    if len(children) == 1:
        left = children[0]
        scalar = attrs.get_float('scalar')
        right = relay.const(scalar, dtype='int32')
    else:
        assert len(children) == 2
        left = children[0]
        right = children[1]

    return op.right_shift(left, right)


def _clip(children, attrs, odtype='float32'):
    a_min = attrs.get_float('a_min')
    a_max = attrs.get_float('a_max')
@@ -255,9 +173,6 @@ def broadcast_to(children, attrs, odtype='float32'):
    rconst = relay.Constant(nd.array(array))
    return op.broadcast_to_like(data, rconst)

def _copy(children, attrs, odtype='float32'):
    return op.copy(children[0])


def _global_avg_pool2d(children, attrs, odtype='float32'):
    data = children[0]
@@ -309,42 +224,10 @@ def _full_like(children, attrs, odtype='float32'):
    return op.full_like(children[0], fill_value)


def _greater(children, attrs, odtype='float32'):
    out_type = attrs.get_str('out_type')
    if out_type:
        return op.greater(children[0], children[1]).astype(out_type)
    else:
        return op.greater(children[0], children[1])


def _greater_equal(children, attrs, odtype='float32'):
    out_type = attrs.get_str('out_type', None)
    if out_type:
        return op.greater_equal(children[0], children[1]).astype(out_type)
    else:
        return op.greater_equal(children[0], children[1])


def _less(children, attrs, odtype='float32'):
    out_type = attrs.get_str('out_type', None)
    if out_type:
        return op.less(children[0], children[1]).astype(out_type)
    else:
        return op.less(children[0], children[1])


def _less_equal(children, attrs, odtype='float32'):
    out_type = attrs.get_str('out_type', None)
    if out_type:
        return op.less_equal(children[0], children[1]).astype(out_type)
    else:
        return op.less_equal(children[0], children[1])


def _strided_slice(children, attrs, odtype='float32'):
    begin = attrs.get_int_list('begin')
    end = attrs.get_int_list('end')
    strides = attrs.get_int_list('strides', None)
    strides = attrs.get_int_list('stride', None)
    return op.strided_slice(children[0], begin, end, strides=strides)


@@ -358,14 +241,11 @@ def _split(children, attrs, odtype='float32'):

    axis = attrs.get_int('axis', 0)

    return op.split(children[0], indices_or_sections, axis)
    return op.split(children[0], indices_or_sections, axis).astuple()

def _squeeze(children, attrs, odtype='float32'):
    axis = None
    try:
        axis = [attrs.get_int('axis', None)]
    except ValueError:
        axis = axis or attrs.get_int_tuple('axis', None)
    axis = attrs.get_int_tuple('axis', None)
    axis = [axis] if isinstance(axis, int) else axis

    return op.squeeze(children[0], axis)

@@ -378,20 +258,60 @@ def _dropout(children, attrs, odtype='float32'):

    return op.nn.dropout(children[0], rate)
def _mean(children, attrs, odtype='float32'):
    axis = None
    try:
        axis = [attrs.get_int('axis', None)]
    except ValueError:
        axis = axis or attrs.get_int_tuple('axis', None)
    axis = attrs.get_int_tuple('axis', None)
    keepdims = attrs.get_bool('keepdims')

    return op.mean(children[0], axis, keepdims)


def _prelu(children, attrs, odtype='float32'):
    axis = attrs.get_int('axis', 1)
    return op.nn.prelu(children[0], children[1], axis)


def _lrn(children, attrs, odtype='float32'):
    size = attrs.get_int("size", 5)
    axis = attrs.get_int("axis", 1)
    bias = attrs.get_float("bias", 2)
    alpha = attrs.get_float("alpha", 1e-05)
    beta = attrs.get_float("beta", 0.75)
    return op.nn.lrn(children[0], size, axis, bias, alpha, beta)


def _l2_normalize(children, attrs, odtype='float32'):
    eps = attrs.get_float('eps')
    axis = attrs.get_int_tuple('axis', None)
    return op.nn.l2_normalize(children[0], eps, axis)


def _take(children, attrs, odtype='float32'):
    axis = attrs.get_int('axis', None)
    return op.take(children[0], children[1], axis)


def _matmul(children, attrs, odtype='float32'):
    input_1_t = op.transpose(children[1], axes=(1, 0))
    return op.nn.dense(children[0], input_1_t)


def _collapse_sum(children, attrs, odtype='float32'):
    for key in ["axis", "keepdims", "exclude"]:
        if key in attrs.attrs:
            raise NotImplementedError("Parameter '" + key + "' is not supported.")
    return op.collapse_sum_like(children[0], children[1])


def _not_implemented(new_op):
    def _impl(children, attrs, odtype='float32'):
        raise NotImplementedError(str(new_op) + " is not implemented.")
    return _impl


NNVM_OP_2_RELAY_OP = {
    'flatten': _nn_batch_flatten,
    'dense': _dense,
    'softmax': _nn_softmax,
    'softmax': _softmax_op(op.nn.softmax),
    'log_softmax': _softmax_op(op.nn.log_softmax),
    'conv2d': _conv2d,
    'batch_norm': _batch_norm,
    'max_pool2d': _max_pool2d,
@@ -400,30 +320,47 @@ def _mean(children, attrs, odtype='float32'):
    'dropout': _dropout,
    'mean': _mean,
    # Addition
    '__add_scalar__': _add,
    'broadcast_add': _add,
    'elemwise_add': _add,
    '__add_scalar__': _binop_scalar(op.add),
    'broadcast_add' : _rename(op.add),
    'elemwise_add' : _rename(op.add),
    # Subtraction
    '__sub_scalar__': _subtract,
    '__rsub_scalar__': _rsubtract,
    'broadcast_sub': _subtract,
    'elemwise_sub': _subtract,
    '__sub_scalar__' : _binop_scalar(op.subtract),
    '__rsub_scalar__': _rbinop_scalar(op.subtract),
    'broadcast_sub' : _rename(op.subtract),
    'elemwise_sub' : _rename(op.subtract),
    # Multiply
    '__mul_scalar__': _multiply,
    'broadcast_mul': _multiply,
    'elemwise_mul': _multiply,
    '__mul_scalar__': _binop_scalar(op.multiply),
    'broadcast_mul' : _rename(op.multiply),
    'elemwise_mul' : _rename(op.multiply),
    # Division
    '__div_scalar__': _divide,
    'broadcast_div': _divide,
    'elemwise_div': _divide,
    '__div_scalar__': _binop_scalar(op.divide),
    'broadcast_div' : _rename(op.divide),
    'elemwise_div' : _rename(op.divide),
    'broadcast_mod' : _rename(op.mod),
    # Negative
    'negative': _rename("negative"),
    # Power
    '__pow_scalar__': _binop_scalar(op.power),
    '__rpow_scalar__': _rbinop_scalar(op.power),
    'broadcast_pow': _rename(op.power),
    # Sum
    'sum': _reduce(op.sum),
    'elemwise_sum': _elemwise_sum,
    'collapse_sum': _collapse_sum,
    'broadcast_max': _rename(op.maximum),
    'broadcast_min': _rename(op.minimum),

    # Comparison
    'greater': _greater,
    'greater_equal': _greater_equal,
    'less': _less,
    'less_equal': _less_equal,
    'greater': _compare(op.greater),
    'broadcast_greater': _compare(op.greater),
    'greater_equal': _compare(op.greater_equal),
    'broadcast_greater_equal': _compare(op.greater_equal),
    'less': _compare(op.less),
    'broadcast_less': _compare(op.less),
    'less_equal': _compare(op.less_equal),
    'broadcast_less_equal': _compare(op.less_equal),
    'broadcast_equal': _compare(op.equal),
    'broadcast_not_equal': _compare(op.not_equal),

    # Activations
    'sigmoid': _rename('sigmoid'),
@@ -432,13 +369,16 @@ def _mean(children, attrs, odtype='float32'):
    'log': _rename('log'),
    'tanh': _rename('tanh'),
    'leaky_relu': _leaky_relu,
    'prelu': _prelu,
    'clip': _clip,
    'round': _rename('round'),
    'cast': _cast,
    'expand_dims': _expand_dims,
    'broadcast_to': broadcast_to,
    '__rshift_scalar__': _rshift,
    'copy': _copy,
    '__rshift_scalar__': _binop_scalar(op.right_shift),
    'broadcast_left_shift': _rename(op.left_shift),
    'broadcast_right_shift': _rename(op.right_shift),
    'copy': _rename(op.copy),
    'global_avg_pool2d': _global_avg_pool2d,
    'avg_pool2d': _avg_pool2d,
    'conv2d_transpose': _conv2d_transpose,
@@ -449,6 +389,17 @@ def _mean(children, attrs, odtype='float32'):
    'split': _split,
    'squeeze': _squeeze,
    'concatenate': _concatenate,
    'take': _take,
    'lrn': _lrn,
    'l2_normalize': _l2_normalize,
    'matmul': _matmul,
    'zeros_like': _rename(op.zeros_like),
    'reshape_like': _rename(op.reshape_like),
    'ones_like': _rename(op.ones_like),

    'expand_like': _not_implemented("expand_like"),
    'gather_nd': _not_implemented("gather_nd"),
    'block_grad': _not_implemented("block_grad"),
}
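For reference, _rename (used throughout the table above) just forwards the children to the corresponding Relay op. A hedged sketch, assuming string arguments are resolved as attributes of relay.op; the real implementation lives in tvm.relay.frontend.nnvm_common and may differ:

from tvm.relay import op

def _rename(new_op):
    # new_op is either a Relay op function or its name as a string.
    def _impl(children, attrs, odtype='float32'):
        func = getattr(op, new_op) if isinstance(new_op, str) else new_op
        return func(*children)
    return _impl

So, for example, the 'elemwise_add' entry converts its two children to op.add(children[0], children[1]).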

