diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h index 87128a663c4a..22ee9d7118e6 100644 --- a/nnvm/include/nnvm/top/tensor.h +++ b/nnvm/include/nnvm/top/tensor.h @@ -48,6 +48,16 @@ struct SplitParam : public dmlc::Parameter { } }; + +struct TakeParam : public dmlc::Parameter { + dmlc::optional axis; + + DMLC_DECLARE_PARAMETER(TakeParam) { + DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional()) + .describe("the axis over which to select values."); + } +}; + struct StridedSliceParam : public dmlc::Parameter { // numpy convention, only support indices, not support list. Tuple begin; diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py index c87b4735d73a..b5e00f012996 100644 --- a/nnvm/python/nnvm/top/transform.py +++ b/nnvm/python/nnvm/top/transform.py @@ -61,6 +61,10 @@ def compute_reshape_like(attrs, inputs, out_info): reg.register_pattern("split", OpPattern.INJECTIVE) reg.register_schedule("split", _fschedule_injective) +# take +reg.register_pattern("take", OpPattern.INJECTIVE) +reg.register_schedule("take", _fschedule_injective) + # strided_slice reg.register_pattern("strided_slice", OpPattern.INJECTIVE) reg.register_schedule("strided_slice", _fschedule_injective) diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 72e49a040efe..5bb2ec137594 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -1001,6 +1001,126 @@ Examples:: return Array{ topi::flip(inputs[0], param.axis) }; }); + +// take +DMLC_REGISTER_PARAMETER(TakeParam); + +inline bool TakeInferShape(const NodeAttrs& attrs, + std::vector* in_shape, + std::vector* out_shape) { + CHECK_EQ(in_shape->size(), 2U); + CHECK_EQ(out_shape->size(), 1U); + const TShape& dshape = (*in_shape)[0]; + const TShape& indicesshape = (*in_shape)[1]; + if (dshape.ndim() == 0) return false; + if (indicesshape.ndim() == 0) return false; + + const TakeParam& param = nnvm::get(attrs.parsed); + TShape oshape((!param.axis ? 0: dshape.ndim() - 1) + indicesshape.ndim()); + if (!param.axis) { + for (size_t j = 0; j < indicesshape.ndim(); ++j) { + oshape[j] = indicesshape[j]; + } + } else { + int axis = param.axis.value(); + if (axis < 0) { + axis += dshape.ndim(); + } + CHECK_LT(axis, dshape.ndim()); + + size_t posi = 0; + for (size_t i = 0; i < dshape.ndim(); ++i) { + if (static_cast(i) == axis) { + for (size_t j = 0; j < indicesshape.ndim(); ++j) { + oshape[posi++] = indicesshape[j]; + } + } else { + oshape[posi++] = dshape[i]; + } + } + } + NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape); + NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, indicesshape); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); + return dshape.Size() != 0; +} + +inline bool TakeInferType(const NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK_EQ((*in_attrs)[1], kInt32); + NNVM_ASSIGN_INPUT_TYPE(attrs, *in_attrs, 0, (*in_attrs)[0]); + NNVM_ASSIGN_INPUT_TYPE(attrs, *in_attrs, 1, static_cast(kInt32)); + NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, (*in_attrs)[0]); + return true; +} + +inline bool TakeCorrectLayout(const NodeAttrs& attrs, + std::vector *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts) { + CHECK_EQ(ilayouts->size(), last_ilayouts->size()); + CHECK_EQ(olayouts->size(), 1U); + + for (size_t i = 0; i < ilayouts->size(); ++i) { + const Layout& input = last_ilayouts->at(i).defined() ? + last_ilayouts->at(i) : ilayouts->at(i); + NNVM_ASSIGN_LAYOUT(*ilayouts, i, input); + } + + return true; +} + +NNVM_REGISTER_OP(take) +.describe(R"code(Take elements from an array along an axis. + +When axis is not None, this function does the same thing as 'fancy' indexing +(indexing arrays using arrays); however, it can be easier to use if you need +elements along a given axis. + +**Note** that when axis is none the flattened input array is used. + +Examples:: + + a = [[ 1, 2], + [ 3, 4]] + indices = [3, 0, 2] + take(a, indices) = [ 4, 1, 3] + + a = [[ 1., 2.], + [ 3., 4.]] + indices = [1, 0] + take(a, indices, axis=1) = [[ 2., 1.], + [ 4., 3.]] + + )code" NNVM_ADD_FILELINE) +.add_argument("data", "Tensor", "Array to be indexed") +.add_argument("indices", "Tensor", "The indices of the values to extract") +.add_arguments(TakeParam::__FIELDS__()) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", TakeInferShape) +.set_attr("FInferType", TakeInferType) +.set_attr("FCorrectLayout", TakeCorrectLayout) +.set_num_inputs(2) +.set_num_outputs(1) +.set_support_level(1) +.set_attr( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + const TakeParam& param = nnvm::get(attrs.parsed); + if (!param.axis) { + return Array{ + topi::take(inputs[0], inputs[1]) }; + } else { + return Array{ + topi::take(inputs[0], inputs[1], param.axis.value()) }; + } + }); + + // SliceLike DMLC_REGISTER_PARAMETER(SliceLikeParam); diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py index b97aff8effb6..d9c6655fea1d 100644 --- a/nnvm/tests/python/compiler/test_top_level1.py +++ b/nnvm/tests/python/compiler/test_top_level1.py @@ -365,6 +365,40 @@ def test_strided_slice(): verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4]) verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3]) +def verify_take(src_shape, indices_src, axis=None): + src_dtype = "float32" + indices_dtype = "int32" + indices_src = np.array(indices_src, dtype=indices_dtype) + a = sym.Variable("a") + indices = sym.Variable("indices") + y = sym.take(a, indices, axis=axis) + for target, ctx in ctx_list(): + # set input + shape_dict = {"a":src_shape, "indices":indices_src.shape} + type_dict = {"a":src_dtype, "indices":indices_dtype} + graph, lib, _ = nnvm.compiler.build(y, target, shape=shape_dict, dtype=type_dict) + m = graph_runtime.create(graph, lib, ctx) + + shape_size = 1 + for i in range(len(src_shape)): + shape_size = shape_size * src_shape[i] + a_src = np.arange(shape_size, dtype=src_dtype).reshape((src_shape)) + out_np = np.take(a_src, indices_src, axis=axis) + m.run(a=a_src, indices=indices_src) + out = m.get_output(0, tvm.nd.empty(out_np.shape, dtype=src_dtype)) + np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5) + +def test_take(): + verify_take((4,), [1]) + verify_take((4,), [[0,1,2,3]]) + verify_take((3,3,3), [[11,25]]) + verify_take((4,), [[0,1],[2,3]]) + verify_take((4,), [1], 0) + verify_take((2,2), [[[1,0],[0,1]]], 0) + verify_take((2,2), [[[1,0],[0,1]]], 1) + verify_take((4,3,5,6), [[2,1,0,0]], -2) + + def verify_squeeze(dshape, axis): x = sym.Variable("x") if axis: @@ -481,6 +515,7 @@ def test_l2_normalize(): test_softmax() test_squeeze() test_pad() + test_take() test_lrn() test_l2_normalize() test_strided_slice()