From 98262b9dba57841bc5f483d5e89a20f14ed0b79b Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 19 Mar 2019 16:51:56 -0700 Subject: [PATCH 1/6] Update take --- 3rdparty/HalideIR | 2 +- include/tvm/relay/attrs/transform.h | 5 ++ python/tvm/relay/frontend/mxnet.py | 10 ++++ python/tvm/relay/op/transform.py | 9 +++- src/relay/op/tensor/transform.cc | 10 ++-- tests/python/frontend/mxnet/test_forward.py | 22 ++++++++- tests/python/relay/test_op_level3.py | 12 +++-- topi/include/topi/transform.h | 54 ++++++++++++++++----- topi/python/topi/transform.py | 11 +++-- topi/src/topi.cc | 8 +-- topi/tests/python/test_topi_transform.py | 16 ++++-- 11 files changed, 126 insertions(+), 33 deletions(-) diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index 86351c40824d..55ba1778fd26 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit 86351c40824dfc4cbb7447d70e5e63d9bd76eb90 +Subproject commit 55ba1778fd264c7507953552d8e51212ed11f748 diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index af4938236054..25b40dcd40ac 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -75,10 +75,15 @@ struct ReshapeAttrs : public tvm::AttrsNode { struct TakeAttrs : public tvm::AttrsNode { Integer axis; + std::string mode; TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") { TVM_ATTR_FIELD(axis).set_default(NullValue()) .describe("The axis over which to select values."); + TVM_ATTR_FIELD(mode).set_default("CLIP") + .describe("Specify how out-of-bound indices will behave." + "CLIP - clip to the range (default)" + "WRAP - wrap around the indices"); } }; diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 47ad5d7a1fa0..3c1922cde337 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -444,6 +444,15 @@ def _mx_tile(inputs, attrs): return _op.tile(inputs[0], **new_attrs) +def _mx_take(inputs, attrs): + assert len(inputs) == 2 + mode = attrs.get_str("mode", "clip") + if mode == "raise": + raise RuntimeError("take doesn't support raise mode") + axis = attrs.get_int("axis", 0) + return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode.upper()) + + def _mx_reverse(inputs, attrs): assert len(inputs) == 1 new_attrs = {} @@ -749,6 +758,7 @@ def _mx_deformable_convolution(inputs, attrs): "_full" : _mx_full, "repeat" : _mx_repeat, "tile" : _mx_tile, + "take" : _mx_take, "reverse" : _mx_reverse, "squeeze" : _mx_squeeze, "broadcast_axis": _mx_broadcast_axis, diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 37aace5afe4a..0df626134b39 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -186,7 +186,7 @@ def reshape_like(data, shape_like): return _make.reshape_like(data, shape_like) -def take(data, indices, axis=None): +def take(data, indices, axis=None, mode="CLIP"): """Take elements from an array along an axis. Parameters @@ -201,12 +201,17 @@ def take(data, indices, axis=None): The axis over which to select values. By default, the flattened input array is used. + mode : str, optional + Specifies how out-of-bound indices will behave. + CLIP - clip to the range (default) + WRAP - wrap around the indices + Returns ------- ret : relay.Expr The computed result. """ - return _make.take(data, indices, axis) + return _make.take(data, indices, axis, mode) def full(fill_value, shape=(), dtype=""): diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index a0ea8f2e60a3..08b06a2a084f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -753,24 +753,26 @@ Array TakeCompute(const Attrs& attrs, const auto* param = attrs.as(); CHECK(param != nullptr); if (!param->axis.defined()) { - return Array{ topi::take(inputs[0], inputs[1]) }; + return Array{ topi::take(inputs[0], inputs[1], param->mode) }; } else { - return Array{ topi::take(inputs[0], inputs[1], param->axis) }; + return Array{ topi::take(inputs[0], inputs[1], param->axis, param->mode) }; } } Expr MakeTake(Expr data, Expr indices, - Integer axis) { + Integer axis, + std::string mode) { auto attrs = make_node(); attrs->axis = std::move(axis); + attrs->mode = std::move(mode); static const Op& op = Op::Get("take"); return CallNode::make(op, {data, indices}, Attrs(attrs), {}); } TVM_REGISTER_API("relay.op._make.take") .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeTake, args, rv); + runtime::detail::unpack_call(MakeTake, args, rv); }); RELAY_REGISTER_OP("take") diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index faccfbfd12fe..9d0d59402ecd 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -464,7 +464,6 @@ def verify(data_shape, weight_shape): verify((2, 2), (4, 5)) verify((2, 3, 4), (4, 5)) - def test_forward_smooth_l1(): data = mx.sym.var('data') mx_sym = mx.sym.smooth_l1(data) @@ -472,6 +471,26 @@ def test_forward_smooth_l1(): mx_sym = mx.sym.smooth_l1(data, scalar=1.0) verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4)) +def test_forward_take(): + def verify(shape, indices_src, axis, mode="clip"): + x_np = np.random.uniform(size=shape).astype("float32") + indices_np = np.array(indices_src, dtype="float32") + ref_res = mx.nd.take(mx.nd.array(x_np), mx.nd.array(indices_np), axis, mode) + mx_sym = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(x_np, indices_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((2,2), [[[1,0],[0,1]]], 0) + verify((2,2), [[[1,0],[0,1]]], 1) + verify((4,3,5,6), [[2,1,0,0]], -2) + verify((3,4), [-1, 5], 0) + verify((3,4), [-1, 5], 0, mode="wrap") + verify((3,4), [-1, 5], 1) + verify((3,4), [-1, 5], 1, mode="wrap") + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -507,3 +526,4 @@ def test_forward_smooth_l1(): test_forward_full() test_forward_embedding() test_forward_smooth_l1() + test_forward_take() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 10ace54e8b12..e088ed00c3c8 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -243,17 +243,17 @@ def verify_take(dshape, indices_shape, oshape, axis=None): verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2) def test_take(): - def verify_take(src_shape, indices_src, axis=None): + def verify_take(src_shape, indices_src, axis=None, mode="CLIP"): src_dtype = "float32" indices_dtype = "int32" indices_src = np.array(indices_src, dtype=indices_dtype) x = relay.var("x", relay.TensorType(src_shape, src_dtype)) indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype)) - z = relay.take(x, indices, axis=axis) + z = relay.take(x, indices, axis=axis, mode=mode) func = relay.Function([x, indices], z) x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype) - ref_res = np.take(x_data, indices=indices_src, axis=axis) + ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode.lower()) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: @@ -269,6 +269,12 @@ def verify_take(src_shape, indices_src, axis=None): verify_take((2,2), [[[1,0],[0,1]]], 0) verify_take((2,2), [[[1,0],[0,1]]], 1) verify_take((4,3,5,6), [[2,1,0,0]], -2) + verify_take((3,4), [-5, 20]) + verify_take((3,4), [-5, 20], mode="WRAP") + verify_take((3,4), [-1, 2], axis=0) + verify_take((3,4), [-1, 2], axis=0, mode="WRAP") + verify_take((3,4), [-1, 2], axis=1) + verify_take((3,4), [-1, 2], axis=1, mode="WRAP") def test_split_infer_type(): diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 464bd6facad5..620915f07903 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -604,22 +604,29 @@ inline Array split_sections(const Tensor& x, */ inline Tensor take(const Tensor& a, const Tensor& indices, + std::string mode = "CLIP", std::string name = "tensor", std::string tag = kInjective) { Array a_shape = a->shape; - Array out_shape; - for (size_t j = 0; j < indices->shape.size(); ++j) { - out_shape.push_back(indices->shape[j]); + Array out_shape = indices->shape; + Expr a_size = 1; + for (size_t i = 0; i < a_shape.size(); ++i) { + a_size = a_size * a_shape[i]; } - return compute( + if (mode == "CLIP") { + return compute( out_shape, [&](const Array& out_index) { - Array indices_position; - for (size_t j = 0; j < indices->shape.size(); ++j) { - indices_position.push_back(out_index[j]); - } - return a(UnravelIndex(indices(indices_position), a_shape)); + auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1); + return a(UnravelIndex(idx, a_shape)); + }, name, tag); + } else { // mode == "WRAP" + return compute( + out_shape, [&](const Array& out_index) { + auto idx = (indices(out_index) % a_size + a_size) % a_size; + return a(UnravelIndex(idx, a_shape)); }, name, tag); + } } /*! @@ -637,12 +644,15 @@ inline Tensor take(const Tensor& a, inline Tensor take(const Tensor& a, const Tensor& indices, int axis, + std::string mode = "CLIP", std::string name = "tensor", std::string tag = kInjective) { if (axis < 0) { axis += static_cast(a->shape.size()); } + CHECK_GE(axis, 0) << "axis out of bounds"; CHECK_LT(axis, a->shape.size()) << "axis out of bounds"; + auto axis_dim = a->shape[axis]; int indices_len = static_cast(indices->shape.size()); Array out_shape; @@ -655,7 +665,27 @@ inline Tensor take(const Tensor& a, out_shape.push_back(a->shape[i]); } } - return compute( + if (mode == "CLIP") { + return compute( + out_shape, [&](const Array& out_index) { + Array indices_position; + for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + indices_position.push_back(out_index[j]); + } + Array real_indices; + for (size_t j = 0; j < static_cast(axis); ++j) { + real_indices.push_back(out_index[j]); + } + auto idx = tvm::min(tvm::max(0, indices(indices_position)), + axis_dim - 1); + real_indices.push_back(idx); + for (size_t j = axis + indices_len; j < out_index.size(); ++j) { + real_indices.push_back(out_index[j]); + } + return a(real_indices); + }, name, tag); + } else { // mode == "WRAP" + return compute( out_shape, [&](const Array& out_index) { Array indices_position; for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { @@ -665,12 +695,14 @@ inline Tensor take(const Tensor& a, for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } - real_indices.push_back(indices(indices_position)); + auto idx = (indices(indices_position) % axis_dim + axis_dim) % axis_dim; + real_indices.push_back(idx); for (size_t j = axis + indices_len; j < out_index.size(); ++j) { real_indices.push_back(out_index[j]); } return a(real_indices); }, name, tag); + } } /*! diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 2c109cd92c52..4c34ab2bcaf8 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -228,7 +228,7 @@ def split(ary, indices_or_sections, axis=0): return cpp.split(ary, indices_or_sections, axis) -def take(a, indices, axis=None): +def take(a, indices, axis=None, mode="CLIP"): """Take elements from an array along an axis. Parameters @@ -243,13 +243,18 @@ def take(a, indices, axis=None): The axis over which to select values. By default, the flattened input array is used. + mode : str, optional + Specifies how out-of-bound indices will behave. + CLIP - clip to the range (default) + WRAP - wrap around the indices + Returns ------- ret : tvm.Tensor """ if axis is None: - return cpp.take(a, indices) - return cpp.take(a, indices, int(axis)) + return cpp.take(a, indices, mode) + return cpp.take(a, indices, int(axis), mode) def gather_nd(a, indices): diff --git a/topi/src/topi.cc b/topi/src/topi.cc index aed2eab9c6bc..1df73d8f1a03 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -297,11 +297,13 @@ TVM_REGISTER_GLOBAL("topi.layout_transform") TVM_REGISTER_GLOBAL("topi.take") .set_body([](TVMArgs args, TVMRetValue *rv) { - if (args.size() == 2) { - *rv = take(args[0], args[1]); + if (args.size() == 3) { + std::string mode = args[2]; + *rv = take(args[0], args[1], mode); } else { int axis = args[2]; - *rv = take(args[0], args[1], axis); + std::string mode = args[3]; + *rv = take(args[0], args[1], axis, mode); } }); diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 59c1090480c2..340189ec39c8 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -232,16 +232,16 @@ def check_device(device): for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]: check_device(device) -def verify_take(src_shape, indices_src, axis=None): +def verify_take(src_shape, indices_src, axis=None, mode="CLIP"): src_dtype = "float32" indices_dtype = "int32" indices_src = np.array(indices_src, dtype=indices_dtype) A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A") indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices") if axis is None: - out_tensor = topi.take(a=A, indices=indices) + out_tensor = topi.take(a=A, indices=indices, mode=mode) else: - out_tensor = topi.take(a=A, indices=indices, axis=axis) + out_tensor = topi.take(a=A, indices=indices, axis=axis, mode=mode) def check_device(device): ctx = tvm.context(device, 0) @@ -259,9 +259,9 @@ def check_device(device): data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape)) if axis is None: - out_npys = np.take(data_npy, indices_src) + out_npys = np.take(data_npy, indices_src, mode=mode.lower()) else: - out_npys = np.take(data_npy, indices_src, axis=axis) + out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode.lower()) data_nd = tvm.nd.array(data_npy, ctx) indices_nd = tvm.nd.array(indices_src, ctx) out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype) @@ -498,6 +498,12 @@ def test_take(): verify_take((2,2), [[[1,0],[0,1]]], 0) verify_take((2,2), [[[1,0],[0,1]]], 1) verify_take((4,3,5,6), [[2,1,0,0]], -2) + verify_take((3,4), [-5, 20]) + verify_take((3,4), [-5, 20], mode="WRAP") + verify_take((3,4), [-1, 2], axis=0) + verify_take((3,4), [-1, 2], axis=0, mode="WRAP") + verify_take((3,4), [-1, 2], axis=1) + verify_take((3,4), [-1, 2], axis=1, mode="WRAP") def test_gather_nd(): for indices_dtype in ['int32', 'float32']: From 76bc47339d9c498a9e12f635f5627c4f610146b4 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 20 Mar 2019 14:13:06 -0700 Subject: [PATCH 2/6] Add special case for canonical simplify and fix test cases --- tests/python/unittest/test_arith_simplify.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py index a327650fd045..9d2eeb7bea32 100644 --- a/tests/python/unittest/test_arith_simplify.py +++ b/tests/python/unittest/test_arith_simplify.py @@ -36,9 +36,14 @@ def test_simplify_mod(): with ib.for_range(0, 16, name="i") as i: A[i] = A[(j * 32 + i+1) % 16] body = ib.get() - stmt = tvm.ir_pass.CanonicalSimplify(body) - diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16) + stmt = tvm.ir_pass.CanonicalSimplify(body, {j: tvm.Range(0, 6), n: tvm.Range(0, 10)}) + diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16) assert diff.value == 0 + # if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16 + index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16) + assert index != j + index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)}) + assert index == j # if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16 index = tvm.ir_pass.CanonicalSimplify( (j + n * 32) % 16, {j: tvm.Range(0, 6)}) From 16ffd4956c27808b156f27fa3c54440479860e85 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Fri, 29 Mar 2019 19:19:56 -0700 Subject: [PATCH 3/6] Use lower case for wrap and clip --- include/tvm/relay/attrs/transform.h | 6 +++--- python/tvm/relay/op/transform.py | 6 +++--- tests/python/relay/test_op_level3.py | 10 +++++----- topi/include/topi/transform.h | 12 ++++++------ topi/python/topi/transform.py | 6 +++--- topi/tests/python/test_topi_transform.py | 8 ++++---- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 25b40dcd40ac..9f1720540599 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -80,10 +80,10 @@ struct TakeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") { TVM_ATTR_FIELD(axis).set_default(NullValue()) .describe("The axis over which to select values."); - TVM_ATTR_FIELD(mode).set_default("CLIP") + TVM_ATTR_FIELD(mode).set_default("clip") .describe("Specify how out-of-bound indices will behave." - "CLIP - clip to the range (default)" - "WRAP - wrap around the indices"); + "clip - clip to the range (default)" + "wrap - wrap around the indices"); } }; diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 0df626134b39..73573043946c 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -186,7 +186,7 @@ def reshape_like(data, shape_like): return _make.reshape_like(data, shape_like) -def take(data, indices, axis=None, mode="CLIP"): +def take(data, indices, axis=None, mode="clip"): """Take elements from an array along an axis. Parameters @@ -203,8 +203,8 @@ def take(data, indices, axis=None, mode="CLIP"): mode : str, optional Specifies how out-of-bound indices will behave. - CLIP - clip to the range (default) - WRAP - wrap around the indices + clip - clip to the range (default) + wrap - wrap around the indices Returns ------- diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index e088ed00c3c8..0cfbcc2c0378 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -243,7 +243,7 @@ def verify_take(dshape, indices_shape, oshape, axis=None): verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2) def test_take(): - def verify_take(src_shape, indices_src, axis=None, mode="CLIP"): + def verify_take(src_shape, indices_src, axis=None, mode="clip"): src_dtype = "float32" indices_dtype = "int32" indices_src = np.array(indices_src, dtype=indices_dtype) @@ -253,7 +253,7 @@ def verify_take(src_shape, indices_src, axis=None, mode="CLIP"): func = relay.Function([x, indices], z) x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype) - ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode.lower()) + ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: @@ -270,11 +270,11 @@ def verify_take(src_shape, indices_src, axis=None, mode="CLIP"): verify_take((2,2), [[[1,0],[0,1]]], 1) verify_take((4,3,5,6), [[2,1,0,0]], -2) verify_take((3,4), [-5, 20]) - verify_take((3,4), [-5, 20], mode="WRAP") + verify_take((3,4), [-5, 20], mode="wrap") verify_take((3,4), [-1, 2], axis=0) - verify_take((3,4), [-1, 2], axis=0, mode="WRAP") + verify_take((3,4), [-1, 2], axis=0, mode="wrap") verify_take((3,4), [-1, 2], axis=1) - verify_take((3,4), [-1, 2], axis=1, mode="WRAP") + verify_take((3,4), [-1, 2], axis=1, mode="wrap") def test_split_infer_type(): diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 620915f07903..bbe1a316f038 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -604,7 +604,7 @@ inline Array split_sections(const Tensor& x, */ inline Tensor take(const Tensor& a, const Tensor& indices, - std::string mode = "CLIP", + std::string mode = "clip", std::string name = "tensor", std::string tag = kInjective) { Array a_shape = a->shape; @@ -614,13 +614,13 @@ inline Tensor take(const Tensor& a, a_size = a_size * a_shape[i]; } - if (mode == "CLIP") { + if (mode == "clip") { return compute( out_shape, [&](const Array& out_index) { auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1); return a(UnravelIndex(idx, a_shape)); }, name, tag); - } else { // mode == "WRAP" + } else { // mode == "wrap" return compute( out_shape, [&](const Array& out_index) { auto idx = (indices(out_index) % a_size + a_size) % a_size; @@ -644,7 +644,7 @@ inline Tensor take(const Tensor& a, inline Tensor take(const Tensor& a, const Tensor& indices, int axis, - std::string mode = "CLIP", + std::string mode = "clip", std::string name = "tensor", std::string tag = kInjective) { if (axis < 0) { @@ -665,7 +665,7 @@ inline Tensor take(const Tensor& a, out_shape.push_back(a->shape[i]); } } - if (mode == "CLIP") { + if (mode == "clip") { return compute( out_shape, [&](const Array& out_index) { Array indices_position; @@ -684,7 +684,7 @@ inline Tensor take(const Tensor& a, } return a(real_indices); }, name, tag); - } else { // mode == "WRAP" + } else { // mode == "wrap" return compute( out_shape, [&](const Array& out_index) { Array indices_position; diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 4c34ab2bcaf8..e674b9e11d14 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -228,7 +228,7 @@ def split(ary, indices_or_sections, axis=0): return cpp.split(ary, indices_or_sections, axis) -def take(a, indices, axis=None, mode="CLIP"): +def take(a, indices, axis=None, mode="clip"): """Take elements from an array along an axis. Parameters @@ -245,8 +245,8 @@ def take(a, indices, axis=None, mode="CLIP"): mode : str, optional Specifies how out-of-bound indices will behave. - CLIP - clip to the range (default) - WRAP - wrap around the indices + clip - clip to the range (default) + wrap - wrap around the indices Returns ------- diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 340189ec39c8..106fd68c3f9c 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -232,7 +232,7 @@ def check_device(device): for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]: check_device(device) -def verify_take(src_shape, indices_src, axis=None, mode="CLIP"): +def verify_take(src_shape, indices_src, axis=None, mode="clip"): src_dtype = "float32" indices_dtype = "int32" indices_src = np.array(indices_src, dtype=indices_dtype) @@ -499,11 +499,11 @@ def test_take(): verify_take((2,2), [[[1,0],[0,1]]], 1) verify_take((4,3,5,6), [[2,1,0,0]], -2) verify_take((3,4), [-5, 20]) - verify_take((3,4), [-5, 20], mode="WRAP") + verify_take((3,4), [-5, 20], mode="wrap") verify_take((3,4), [-1, 2], axis=0) - verify_take((3,4), [-1, 2], axis=0, mode="WRAP") + verify_take((3,4), [-1, 2], axis=0, mode="wrap") verify_take((3,4), [-1, 2], axis=1) - verify_take((3,4), [-1, 2], axis=1, mode="WRAP") + verify_take((3,4), [-1, 2], axis=1, mode="wrap") def test_gather_nd(): for indices_dtype in ['int32', 'float32']: From c9984fc6e0bfad00011125f864a5a7a1f781f513 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Fri, 29 Mar 2019 19:21:26 -0700 Subject: [PATCH 4/6] remove unnecssary lower --- topi/tests/python/test_topi_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 106fd68c3f9c..b56df9f2acdd 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -259,9 +259,9 @@ def check_device(device): data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape)) if axis is None: - out_npys = np.take(data_npy, indices_src, mode=mode.lower()) + out_npys = np.take(data_npy, indices_src, mode=mode) else: - out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode.lower()) + out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode) data_nd = tvm.nd.array(data_npy, ctx) indices_nd = tvm.nd.array(indices_src, ctx) out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype) From 861c166cbab60516b0eacabfae140c587f09f895 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Fri, 29 Mar 2019 22:35:33 -0700 Subject: [PATCH 5/6] Fix mxnet converter for take --- python/tvm/relay/frontend/mxnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 3c1922cde337..8e36801f98ba 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -450,7 +450,7 @@ def _mx_take(inputs, attrs): if mode == "raise": raise RuntimeError("take doesn't support raise mode") axis = attrs.get_int("axis", 0) - return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode.upper()) + return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode) def _mx_reverse(inputs, attrs): From 9da223f99394c6621b5499eec58b8f3351933b0e Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 1 Apr 2019 12:22:40 -0700 Subject: [PATCH 6/6] fix --- tests/python/unittest/test_arith_simplify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py index 9d2eeb7bea32..6ee3bc6b57e5 100644 --- a/tests/python/unittest/test_arith_simplify.py +++ b/tests/python/unittest/test_arith_simplify.py @@ -36,8 +36,8 @@ def test_simplify_mod(): with ib.for_range(0, 16, name="i") as i: A[i] = A[(j * 32 + i+1) % 16] body = ib.get() - stmt = tvm.ir_pass.CanonicalSimplify(body, {j: tvm.Range(0, 6), n: tvm.Range(0, 10)}) - diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16) + stmt = tvm.ir_pass.CanonicalSimplify(body) + diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16) assert diff.value == 0 # if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16 index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16)