diff --git a/HalideIR b/HalideIR
index 0b7e25275138..9204453ae8de 160000
--- a/HalideIR
+++ b/HalideIR
@@ -1 +1 @@
-Subproject commit 0b7e25275138768bb05edb9b9db2c86d0fb09c9a
+Subproject commit 9204453ae8de77e7dfc32c4d80f58dd788ad75ff
diff --git a/docs/nnvm_top.rst b/docs/nnvm_top.rst
index 5f3fd9fa1b0b..4e1e536dbb26 100644
--- a/docs/nnvm_top.rst
+++ b/docs/nnvm_top.rst
@@ -88,6 +88,8 @@ This level enables typical convnet models.
    nnvm.symbol.__rdiv_scalar__
    nnvm.symbol.__pow_scalar__
    nnvm.symbol.__rpow_scalar__
+   nnvm.symbol.__lshift_scalar__
+   nnvm.symbol.__rshift_scalar__
 
 **Level 4: Broadcast and Reductions**
 
@@ -164,6 +166,8 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.__rdiv_scalar__
 .. autofunction:: nnvm.symbol.__pow_scalar__
 .. autofunction:: nnvm.symbol.__rpow_scalar__
+.. autofunction:: nnvm.symbol.__lshift_scalar__
+.. autofunction:: nnvm.symbol.__rshift_scalar__
 
 .. autofunction:: nnvm.symbol.transpose
 .. autofunction:: nnvm.symbol.broadcast_to
diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py
index bc479a65e03b..6997ecc64654 100644
--- a/nnvm/python/nnvm/symbol.py
+++ b/nnvm/python/nnvm/symbol.py
@@ -100,6 +100,20 @@ def __rdiv__(self, other):
         else:
             raise TypeError('type %s not supported' % str(type(other)))
 
+    def __lshift__(self, other):
+        """x.__lshift__(y) <=> x << y"""
+        if isinstance(other, _Number):
+            return __lshift_scalar__(self, scalar=other)
+        else:
+            raise TypeError('type %s not supported' % str(type(other)))
+
+    def __rshift__(self, other):
+        """x.__rshift__(y) <=> x >> y"""
+        if isinstance(other, _Number):
+            return __rshift_scalar__(self, scalar=other)
+        else:
+            raise TypeError('type %s not supported' % str(type(other)))
+
     def __truediv__(self, other):
         return self.__div__(other)
 
diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py
index c11486f5c77f..0250f6ddfad8 100644
--- a/nnvm/python/nnvm/top/tensor.py
+++ b/nnvm/python/nnvm/top/tensor.py
@@ -133,6 +133,14 @@ def compute_cast(attrs, inputs, _):
 reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
 reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
 
+# lshift_scalar
+reg.register_pattern("__lshift_scalar__", OpPattern.ELEMWISE)
+reg.register_schedule("__lshift_scalar__", _fschedule_broadcast)
+
+# rshift_scalar
+reg.register_pattern("__rshift_scalar__", OpPattern.ELEMWISE)
+reg.register_schedule("__rshift_scalar__", _fschedule_broadcast)
+
 # elemwise_add
 reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
 reg.register_schedule("elemwise_add", _fschedule_broadcast)
diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc
index cff85497be4f..cd38817b51a5 100644
--- a/nnvm/src/top/tensor/elemwise.cc
+++ b/nnvm/src/top/tensor/elemwise.cc
@@ -512,6 +512,39 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rsub_scalar__)
   };
 });
 
+
+NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__lshift_scalar__)
+.describe(R"code(Tensor left shift by scalar
+
+)code" NNVM_ADD_FILELINE)
+.set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
+    int scalar_val = static_cast<int>(param.scalar);
+    return Array<Tensor>{
+      topi::left_shift(inputs[0],
+                       make_const(inputs[0]->dtype, scalar_val))};
+  });
+
+NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__rshift_scalar__)
+.describe(R"code(Tensor right shift by scalar
+
+)code" NNVM_ADD_FILELINE)
+.set_support_level(3)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ScalarParam& param = nnvm::get<ScalarParam>(attrs.parsed);
+    int scalar_val = static_cast<int>(param.scalar);
+    return Array<Tensor>{
+      topi::right_shift(inputs[0],
+                        make_const(inputs[0]->dtype, scalar_val))};
+  });
+
 NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__mul_scalar__)
 .describe(R"code(Tensor multiplies scalar
 
diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py
index 37798d37f400..73391c80dd12 100644
--- a/nnvm/tests/python/compiler/test_top_level1.py
+++ b/nnvm/tests/python/compiler/test_top_level1.py
@@ -7,17 +7,21 @@ from nnvm.testing.config import ctx_list
 
 def helper(symbol, inputs, dtype,
-           np_forward, np_backward=None, need_input=True, need_head_grads=True):
+           np_forward, np_backward=None,
+           need_input=True, need_head_grads=True,
+           rnd_min=-1, rnd_max=1):
     ishapes = {}
+    itypes = {}
     input_syms = []
     np_inputs = {}
     for (name, shape, s) in inputs:
         ishapes.update({name: shape})
-        np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
+        itypes.update({name: dtype})
+        np_inputs.update({name: np.random.uniform(rnd_min, rnd_max, size=shape).astype(dtype)})
         input_syms.append(s)
 
     for target, ctx in ctx_list():
-        graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
+        graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes, itypes)
         m = graph_runtime.create(graph, lib, ctx)
         m.run(**np_inputs)
         y_np = np_forward(**np_inputs)
@@ -164,7 +168,7 @@ def backward(head_grads, x):
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     inputs = [('x', dshape, x)]
-    helper(y, inputs, dtype, forward, backward)
+    helper(y, inputs, dtype, forward, backward, rnd_min=0.001)
 
 
 def test_tanh():
@@ -277,7 +281,7 @@ def forward(x, gamma, beta, moving_mean, moving_var):
         ('moving_var', (20,), moving_mean)
     ]
 
-    helper(y, inputs, dtype, forward)
+    helper(y, inputs, dtype, forward, rnd_min=0.001)
 
 
 def verify_concatenate(ishape, axis):
diff --git a/nnvm/tests/python/compiler/test_top_level3.py b/nnvm/tests/python/compiler/test_top_level3.py
index 27a99bfb530f..125836a7848e 100644
--- a/nnvm/tests/python/compiler/test_top_level3.py
+++ b/nnvm/tests/python/compiler/test_top_level3.py
@@ -7,13 +7,13 @@ from nnvm.testing.config import ctx_list
 from test_top_level1 import helper
 
-def check_map(symfunc, np_func, np_backward=None):
+def check_map(symfunc, np_func, np_backward=None, dtype="float32", rnd_min=-1, rnd_max=1):
     x = sym.Variable("x")
     y = symfunc(x)
-    dtype = "float32"
     dshape = (1, 3, 32, 32)
     inputs = [('x', dshape, x)]
-    helper(y, inputs, dtype, lambda x: np_func(x), np_backward)
+    helper(y, inputs, dtype, lambda x: np_func(x), np_backward,
+           rnd_min=rnd_min, rnd_max=rnd_max)
 
 
 def test_floor():
@@ -29,7 +29,14 @@ def test_round():
     check_map(sym.round, np.round)
 
 
+def test_shift():
+    n = 3
+    for dtype in ["int32", "int8"]:
+        check_map(lambda x: x >> n, lambda x: x >> n, dtype=dtype, rnd_min=-100, rnd_max=100)
+        check_map(lambda x: x << n, lambda x: x << n, dtype=dtype, rnd_min=-100, rnd_max=100)
+
 if __name__ == "__main__":
+    test_shift()
     test_floor()
     test_ceil()
     test_round()
diff --git a/topi/python/topi/broadcast.py b/topi/python/topi/broadcast.py
index caf5c77f5629..f088e48b0f14 100644
--- a/topi/python/topi/broadcast.py
+++ b/topi/python/topi/broadcast.py
@@ -210,7 +210,7 @@ def right_shift(lhs, rhs):
         Returns Expr if both operands are Expr.
         Otherwise returns Tensor.
""" - return _cpp.left_shift(lhs, rhs) + return _cpp.right_shift(lhs, rhs) def greater(lhs, rhs): diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py index 05a32ca0eac5..27e72e327232 100644 --- a/topi/tests/python/test_topi_broadcast.py +++ b/topi/tests/python/test_topi_broadcast.py @@ -68,7 +68,7 @@ def check_device(device): if rhs_shape is None: rhs_npy = float(np.random.uniform(low=rhs_min, high=rhs_max)) if dtype.startswith('int'): - lhs_npy = int(lhs_npy) + rhs_npy = int(rhs_npy) rhs_nd = rhs_npy else: rhs_npy = np.random.uniform(low=rhs_min, high=rhs_max, @@ -77,8 +77,7 @@ def check_device(device): out_npy = fnumpy(lhs_npy, rhs_npy) out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx) - for _ in range(1): - foo(lhs_nd, rhs_nd, out_nd) + foo(lhs_nd, rhs_nd, out_nd) np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4) check_device("opencl") @@ -142,8 +141,23 @@ def less(x, y): verify_broadcast_binary_ele( (2, 1, 2), (2, 3, 1), less, np.less) +def test_shift(): + # explicit specify the output type + verify_broadcast_binary_ele( + (2, 1, 2), None, topi.right_shift, np.right_shift, + dtype="int32", rhs_min=0, rhs_max=32) + + verify_broadcast_binary_ele( + (1, 2, 2), (2,), topi.left_shift, np.left_shift, + dtype="int32", rhs_min=0, rhs_max=32) + + verify_broadcast_binary_ele( + (1, 2, 2), (2,), topi.left_shift, np.left_shift, + dtype="int8", rhs_min=0, rhs_max=32) + if __name__ == "__main__": + test_shift() test_cmp() test_mod() test_add()