Skip to content

Commit

Permalink
add take frontend (apache#1307)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dayananda-V authored and tqchen committed Jul 4, 2018
1 parent e8d80db commit 4f184db
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 0 deletions.
10 changes: 10 additions & 0 deletions nnvm/include/nnvm/top/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
}
};


struct TakeParam : public dmlc::Parameter<TakeParam> {
dmlc::optional<int> axis;

DMLC_DECLARE_PARAMETER(TakeParam) {
DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>())
.describe("the axis over which to select values.");
}
};

struct StridedSliceParam : public dmlc::Parameter<StridedSliceParam> {
// numpy convention, only support indices, not support list.
Tuple<int64_t> begin;
Expand Down
4 changes: 4 additions & 0 deletions nnvm/python/nnvm/top/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def compute_reshape_like(attrs, inputs, out_info):
reg.register_pattern("split", OpPattern.INJECTIVE)
reg.register_schedule("split", _fschedule_injective)

# take
reg.register_pattern("take", OpPattern.INJECTIVE)
reg.register_schedule("take", _fschedule_injective)

# strided_slice
reg.register_pattern("strided_slice", OpPattern.INJECTIVE)
reg.register_schedule("strided_slice", _fschedule_injective)
Expand Down
120 changes: 120 additions & 0 deletions nnvm/src/top/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,126 @@ Examples::
return Array<Tensor>{ topi::flip(inputs[0], param.axis) };
});


// take
DMLC_REGISTER_PARAMETER(TakeParam);

inline bool TakeInferShape(const NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
CHECK_EQ(in_shape->size(), 2U);
CHECK_EQ(out_shape->size(), 1U);
const TShape& dshape = (*in_shape)[0];
const TShape& indicesshape = (*in_shape)[1];
if (dshape.ndim() == 0) return false;
if (indicesshape.ndim() == 0) return false;

const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
TShape oshape((!param.axis ? 0: dshape.ndim() - 1) + indicesshape.ndim());
if (!param.axis) {
for (size_t j = 0; j < indicesshape.ndim(); ++j) {
oshape[j] = indicesshape[j];
}
} else {
int axis = param.axis.value();
if (axis < 0) {
axis += dshape.ndim();
}
CHECK_LT(axis, dshape.ndim());

size_t posi = 0;
for (size_t i = 0; i < dshape.ndim(); ++i) {
if (static_cast<int>(i) == axis) {
for (size_t j = 0; j < indicesshape.ndim(); ++j) {
oshape[posi++] = indicesshape[j];
}
} else {
oshape[posi++] = dshape[i];
}
}
}
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, indicesshape);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return dshape.Size() != 0;
}

inline bool TakeInferType(const NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
CHECK_EQ((*in_attrs)[1], kInt32);
NNVM_ASSIGN_INPUT_TYPE(attrs, *in_attrs, 0, (*in_attrs)[0]);
NNVM_ASSIGN_INPUT_TYPE(attrs, *in_attrs, 1, static_cast<int>(kInt32));
NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, (*in_attrs)[0]);
return true;
}

inline bool TakeCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), last_ilayouts->size());
CHECK_EQ(olayouts->size(), 1U);

for (size_t i = 0; i < ilayouts->size(); ++i) {
const Layout& input = last_ilayouts->at(i).defined() ?
last_ilayouts->at(i) : ilayouts->at(i);
NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
}

return true;
}

NNVM_REGISTER_OP(take)
.describe(R"code(Take elements from an array along an axis.
When axis is not None, this function does the same thing as 'fancy' indexing
(indexing arrays using arrays); however, it can be easier to use if you need
elements along a given axis.
**Note** that when axis is none the flattened input array is used.
Examples::
a = [[ 1, 2],
[ 3, 4]]
indices = [3, 0, 2]
take(a, indices) = [ 4, 1, 3]
a = [[ 1., 2.],
[ 3., 4.]]
indices = [1, 0]
take(a, indices, axis=1) = [[ 2., 1.],
[ 4., 3.]]
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Array to be indexed")
.add_argument("indices", "Tensor", "The indices of the values to extract")
.add_arguments(TakeParam::__FIELDS__())
.set_attr_parser(ParamParser<TakeParam>)
.set_attr<FInferShape>("FInferShape", TakeInferShape)
.set_attr<FInferType>("FInferType", TakeInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", TakeCorrectLayout)
.set_num_inputs(2)
.set_num_outputs(1)
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
if (!param.axis) {
return Array<Tensor>{
topi::take(inputs[0], inputs[1]) };
} else {
return Array<Tensor>{
topi::take(inputs[0], inputs[1], param.axis.value()) };
}
});


// SliceLike
DMLC_REGISTER_PARAMETER(SliceLikeParam);

Expand Down
35 changes: 35 additions & 0 deletions nnvm/tests/python/compiler/test_top_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,40 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4])
verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3])

def verify_take(src_shape, indices_src, axis=None):
src_dtype = "float32"
indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype)
a = sym.Variable("a")
indices = sym.Variable("indices")
y = sym.take(a, indices, axis=axis)
for target, ctx in ctx_list():
# set input
shape_dict = {"a":src_shape, "indices":indices_src.shape}
type_dict = {"a":src_dtype, "indices":indices_dtype}
graph, lib, _ = nnvm.compiler.build(y, target, shape=shape_dict, dtype=type_dict)
m = graph_runtime.create(graph, lib, ctx)

shape_size = 1
for i in range(len(src_shape)):
shape_size = shape_size * src_shape[i]
a_src = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
out_np = np.take(a_src, indices_src, axis=axis)
m.run(a=a_src, indices=indices_src)
out = m.get_output(0, tvm.nd.empty(out_np.shape, dtype=src_dtype))
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

def test_take():
verify_take((4,), [1])
verify_take((4,), [[0,1,2,3]])
verify_take((3,3,3), [[11,25]])
verify_take((4,), [[0,1],[2,3]])
verify_take((4,), [1], 0)
verify_take((2,2), [[[1,0],[0,1]]], 0)
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)


def verify_squeeze(dshape, axis):
x = sym.Variable("x")
if axis:
Expand Down Expand Up @@ -481,6 +515,7 @@ def test_l2_normalize():
test_softmax()
test_squeeze()
test_pad()
test_take()
test_lrn()
test_l2_normalize()
test_strided_slice()

0 comments on commit 4f184db

Please sign in to comment.