[NNVM] Introduce const shift ops (apache#1325)
tqchen authored Jun 24, 2018
1 parent 5cdfd5b commit 83ce1d6
Showing 9 changed files with 97 additions and 13 deletions.
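In short, this change wires constant shifts end to end: Python <</>>  operator overloads on nnvm symbols, the scalar ops __lshift_scalar__ and __rshift_scalar__, their TOPI compute rules, and a fix to topi.right_shift, which previously dispatched to the left-shift kernel. A minimal usage sketch, assuming this revision of nnvm:

    import nnvm.symbol as sym

    x = sym.Variable("x")    # symbolic tensor; the shift ops expect an integer dtype
    y = (x << 3) >> 1        # lowered to __lshift_scalar__ / __rshift_scalar__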
2 changes: 1 addition & 1 deletion HalideIR (submodule pointer update; no diff shown)
4 changes: 4 additions & 0 deletions docs/nnvm_top.rst
@@ -88,6 +88,8 @@ This level enables typical convnet models.
    nnvm.symbol.__rdiv_scalar__
    nnvm.symbol.__pow_scalar__
    nnvm.symbol.__rpow_scalar__
+   nnvm.symbol.__lshift_scalar__
+   nnvm.symbol.__rshift_scalar__

 **Level 4: Broadcast and Reductions**

@@ -164,6 +166,8 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.__rdiv_scalar__
 .. autofunction:: nnvm.symbol.__pow_scalar__
 .. autofunction:: nnvm.symbol.__rpow_scalar__
+.. autofunction:: nnvm.symbol.__lshift_scalar__
+.. autofunction:: nnvm.symbol.__rshift_scalar__

 .. autofunction:: nnvm.symbol.transpose
 .. autofunction:: nnvm.symbol.broadcast_to
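The two new entries are ordinary scalar ops, so besides the operator overloads they can be invoked by name; a minimal sketch:

    import nnvm.symbol as sym

    x = sym.Variable("x")
    y = sym.__lshift_scalar__(x, scalar=3)   # equivalent to x << 3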
14 changes: 14 additions & 0 deletions nnvm/python/nnvm/symbol.py
@@ -100,6 +100,20 @@ def __rdiv__(self, other):
         else:
             raise TypeError('type %s not supported' % str(type(other)))

+    def __lshift__(self, other):
+        """x.__lshift__(y) <=> x << y"""
+        if isinstance(other, _Number):
+            return __lshift_scalar__(self, scalar=other)
+        else:
+            raise TypeError('type %s not supported' % str(type(other)))
+
+    def __rshift__(self, other):
+        """x.__rshift__(y) <=> x >> y"""
+        if isinstance(other, _Number):
+            return __rshift_scalar__(self, scalar=other)
+        else:
+            raise TypeError('type %s not supported' % str(type(other)))
+
     def __truediv__(self, other):
         return self.__div__(other)
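Note that only Python numbers are accepted on the right-hand side, and no reflected variants (__rlshift__, __rrshift__) are added, so the behavior under this change is roughly:

    x = sym.Variable("x")
    x << 2    # ok: becomes __lshift_scalar__(x, scalar=2)
    x << x    # TypeError: the shift amount must be a Python number
    2 << x    # TypeError as well, since no __rlshift__ overload exists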

8 changes: 8 additions & 0 deletions nnvm/python/nnvm/top/tensor.py
@@ -133,6 +133,14 @@ def compute_cast(attrs, inputs, _):
 reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)

+# lshift_scalar
+reg.register_pattern("__lshift_scalar__", OpPattern.ELEMWISE)
+reg.register_schedule("__lshift_scalar__", _fschedule_broadcast)
+
+# rshift_scalar
+reg.register_pattern("__rshift_scalar__", OpPattern.ELEMWISE)
+reg.register_schedule("__rshift_scalar__", _fschedule_broadcast)
+
 # elemwise_add
 reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
 reg.register_schedule("elemwise_add", _fschedule_broadcast)
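Both ops reuse the elementwise pattern and the generic broadcast schedule, the same recipe as the other scalar ops in this file. A hypothetical new scalar op would be wired up identically:

    # hypothetical op name, purely to illustrate the registration recipe
    reg.register_pattern("__mod_scalar__", OpPattern.ELEMWISE)
    reg.register_schedule("__mod_scalar__", _fschedule_broadcast)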
33 changes: 33 additions & 0 deletions nnvm/src/top/tensor/elemwise.cc
@@ -512,6 +512,39 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rsub_scalar__)
     };
 });

+
+NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__lshift_scalar__)
+.describe(R"code(Tensor left shift by scalar
+
+)code" NNVM_ADD_FILELINE)
+.set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
+    int scalar_val = static_cast<int>(param.scalar);
+    return Array<Tensor>{
+      topi::left_shift(inputs[0],
+                       make_const(inputs[0]->dtype, scalar_val))};
+});
+
+NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rshift_scalar__)
+.describe(R"code(Tensor right shift by scalar
+
+)code" NNVM_ADD_FILELINE)
+.set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
+    int scalar_val = static_cast<int>(param.scalar);
+    return Array<Tensor>{
+      topi::right_shift(inputs[0],
+                        make_const(inputs[0]->dtype, scalar_val))};
+});
+
 NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__mul_scalar__)
 .describe(R"code(Tensor multiplies scalar
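In both compute rules the scalar attribute (stored as a double in ScalarParam) is truncated to int and materialized via make_const in the input tensor's dtype, so the shift amount always has the same type as the data. The numeric behavior matches numpy's elementwise shifts; for instance:

    import numpy as np

    x = np.array([1, -2, 5], dtype="int8")
    x << 3    # array([  8, -16,  40], dtype=int8), shifted by an int8 constant 3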
14 changes: 9 additions & 5 deletions nnvm/tests/python/compiler/test_top_level1.py
@@ -7,17 +7,21 @@
 from nnvm.testing.config import ctx_list

 def helper(symbol, inputs, dtype,
-           np_forward, np_backward=None, need_input=True, need_head_grads=True):
+           np_forward, np_backward=None,
+           need_input=True, need_head_grads=True,
+           rnd_min=-1, rnd_max=1):
     ishapes = {}
+    itypes = {}
     input_syms = []
     np_inputs = {}
     for (name, shape, s) in inputs:
         ishapes.update({name: shape})
-        np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
+        itypes.update({name: dtype})
+        np_inputs.update({name: np.random.uniform(rnd_min, rnd_max, size=shape).astype(dtype)})
         input_syms.append(s)

     for target, ctx in ctx_list():
-        graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
+        graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes, itypes)
         m = graph_runtime.create(graph, lib, ctx)
         m.run(**np_inputs)
         y_np = np_forward(**np_inputs)
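With the input dtypes (itypes) now passed to nnvm.compiler.build and the random input range exposed, the helper can exercise integer ops as well; a sketch of a call this enables, mirroring test_shift in test_top_level3.py:

    x = sym.Variable("x")
    helper(x >> 2, [('x', (1, 3, 32, 32), x)], "int32",
           lambda x: x >> 2, rnd_min=-100, rnd_max=100)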
@@ -164,7 +168,7 @@ def backward(head_grads, x):
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     inputs = [('x', dshape, x)]
-    helper(y, inputs, dtype, forward, backward)
+    helper(y, inputs, dtype, forward, backward, rnd_min=0.001)


 def test_tanh():
@@ -277,7 +281,7 @@ def forward(x, gamma, beta, moving_mean, moving_var):
         ('moving_var', (20,), moving_mean)
     ]

-    helper(y, inputs, dtype, forward)
+    helper(y, inputs, dtype, forward, rnd_min=0.001)


 def verify_concatenate(ishape, axis):
13 changes: 10 additions & 3 deletions nnvm/tests/python/compiler/test_top_level3.py
@@ -7,13 +7,13 @@
 from nnvm.testing.config import ctx_list
 from test_top_level1 import helper

-def check_map(symfunc, np_func, np_backward=None):
+def check_map(symfunc, np_func, np_backward=None, dtype="float32", rnd_min=-1, rnd_max=1):
     x = sym.Variable("x")
     y = symfunc(x)
-    dtype = "float32"
     dshape = (1, 3, 32, 32)
     inputs = [('x', dshape, x)]
-    helper(y, inputs, dtype, lambda x: np_func(x), np_backward)
+    helper(y, inputs, dtype, lambda x: np_func(x), np_backward,
+           rnd_min=rnd_min, rnd_max=rnd_max)


 def test_floor():
@@ -29,7 +29,14 @@ def test_round():
     check_map(sym.round, np.round)


+def test_shift():
+    n = 3
+    for dtype in ["int32", "int8"]:
+        check_map(lambda x : x >> n, lambda x: x >> n, dtype=dtype, rnd_min=-100, rnd_max=100)
+        check_map(lambda x : x << n, lambda x: x << n, dtype=dtype, rnd_min=-100, rnd_max=100)
+
 if __name__ == "__main__":
+    test_shift()
     test_floor()
     test_ceil()
     test_round()
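For reference, the numpy semantics this test checks against: right shift on signed integers is arithmetic (it floors), and left shift wraps in the narrow dtype:

    import numpy as np

    x = np.array([-100, 7], dtype="int8")
    x >> 3    # array([-13,  0], dtype=int8): arithmetic shift, floor(x / 8)
    x << 3    # array([-32, 56], dtype=int8): -800 wraps modulo 2**8 to -32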
2 changes: 1 addition & 1 deletion topi/python/topi/broadcast.py
@@ -210,7 +210,7 @@ def right_shift(lhs, rhs):
     Returns Expr if both operands are Expr.
     Otherwise returns Tensor.
     """
-    return _cpp.left_shift(lhs, rhs)
+    return _cpp.right_shift(lhs, rhs)


 def greater(lhs, rhs):
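Per the docstring above, right_shift returns an Expr when both operands are Exprs and a Tensor otherwise; it previously dispatched to the C++ left_shift by mistake. A small sketch of both cases, assuming this era's tvm/topi API:

    import tvm
    import topi

    A = tvm.placeholder((4,), dtype="int32", name="A")
    B = topi.right_shift(A, 2)                   # Tensor: elementwise A >> 2
    e = topi.right_shift(tvm.const(8, "int32"),
                         tvm.const(2, "int32"))  # Expr: 8 >> 2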
20 changes: 17 additions & 3 deletions topi/tests/python/test_topi_broadcast.py
@@ -68,7 +68,7 @@ def check_device(device):
         if rhs_shape is None:
             rhs_npy = float(np.random.uniform(low=rhs_min, high=rhs_max))
             if dtype.startswith('int'):
-                lhs_npy = int(lhs_npy)
+                rhs_npy = int(rhs_npy)
             rhs_nd = rhs_npy
         else:
             rhs_npy = np.random.uniform(low=rhs_min, high=rhs_max,
Expand All @@ -77,8 +77,7 @@ def check_device(device):

out_npy = fnumpy(lhs_npy, rhs_npy)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
for _ in range(1):
foo(lhs_nd, rhs_nd, out_nd)
foo(lhs_nd, rhs_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)

check_device("opencl")
@@ -142,8 +141,23 @@ def less(x, y):
     verify_broadcast_binary_ele(
         (2, 1, 2), (2, 3, 1), less, np.less)

+def test_shift():
+    # explicit specify the output type
+    verify_broadcast_binary_ele(
+        (2, 1, 2), None, topi.right_shift, np.right_shift,
+        dtype="int32", rhs_min=0, rhs_max=32)
+
+    verify_broadcast_binary_ele(
+        (1, 2, 2), (2,), topi.left_shift, np.left_shift,
+        dtype="int32", rhs_min=0, rhs_max=32)
+
+    verify_broadcast_binary_ele(
+        (1, 2, 2), (2,), topi.left_shift, np.left_shift,
+        dtype="int8", rhs_min=0, rhs_max=32)
+

 if __name__ == "__main__":
+    test_shift()
     test_cmp()
     test_mod()
     test_add()
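The tensor-rhs cases also exercise broadcasting of the shift amounts; the numpy reference behaves like:

    import numpy as np

    lhs = np.arange(4, dtype="int32").reshape(1, 2, 2)   # [[[0, 1], [2, 3]]]
    rhs = np.array([1, 2], dtype="int32")                # broadcast over last axis
    np.left_shift(lhs, rhs)    # array([[[ 0,  4], [ 4, 12]]], dtype=int32)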