Skip to content

Commit

Permalink
[Relay/TOPI][OP] Add clip and wrap mode support in take (#2858)
Browse files Browse the repository at this point in the history
* Update take

* Add special case for canonical simplify and fix test cases

* Use lower case for wrap and clip

* remove unnecssary lower

* Fix mxnet converter for take

* fix
  • Loading branch information
icemelon authored and yzhliu committed Apr 1, 2019
1 parent 7cc9240 commit 3746d90
Show file tree
Hide file tree
Showing 12 changed files with 131 additions and 33 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/HalideIR
5 changes: 5 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer axis;
std::string mode;

TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
.describe("The axis over which to select values.");
TVM_ATTR_FIELD(mode).set_default("clip")
.describe("Specify how out-of-bound indices will behave."
"clip - clip to the range (default)"
"wrap - wrap around the indices");
}
};

Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,15 @@ def _mx_tile(inputs, attrs):
return _op.tile(inputs[0], **new_attrs)


def _mx_take(inputs, attrs):
assert len(inputs) == 2
mode = attrs.get_str("mode", "clip")
if mode == "raise":
raise RuntimeError("take doesn't support raise mode")
axis = attrs.get_int("axis", 0)
return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)


def _mx_reverse(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
Expand Down Expand Up @@ -749,6 +758,7 @@ def _mx_deformable_convolution(inputs, attrs):
"_full" : _mx_full,
"repeat" : _mx_repeat,
"tile" : _mx_tile,
"take" : _mx_take,
"reverse" : _mx_reverse,
"squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis,
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def reshape_like(data, shape_like):
return _make.reshape_like(data, shape_like)


def take(data, indices, axis=None):
def take(data, indices, axis=None, mode="clip"):
"""Take elements from an array along an axis.
Parameters
Expand All @@ -201,12 +201,17 @@ def take(data, indices, axis=None):
The axis over which to select values. By default,
the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.take(data, indices, axis)
return _make.take(data, indices, axis, mode)


def full(fill_value, shape=(), dtype=""):
Expand Down
10 changes: 6 additions & 4 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -753,24 +753,26 @@ Array<Tensor> TakeCompute(const Attrs& attrs,
const auto* param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);
if (!param->axis.defined()) {
return Array<Tensor>{ topi::take(inputs[0], inputs[1]) };
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
} else {
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis) };
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
}
}

Expr MakeTake(Expr data,
Expr indices,
Integer axis) {
Integer axis,
std::string mode) {
auto attrs = make_node<TakeAttrs>();
attrs->axis = std::move(axis);
attrs->mode = std::move(mode);
static const Op& op = Op::Get("take");
return CallNode::make(op, {data, indices}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.take")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeTake, args, rv);
runtime::detail::unpack_call<Expr, 4>(MakeTake, args, rv);
});

RELAY_REGISTER_OP("take")
Expand Down
22 changes: 21 additions & 1 deletion tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,14 +464,33 @@ def verify(data_shape, weight_shape):
verify((2, 2), (4, 5))
verify((2, 3, 4), (4, 5))


def test_forward_smooth_l1():
data = mx.sym.var('data')
mx_sym = mx.sym.smooth_l1(data)
verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
mx_sym = mx.sym.smooth_l1(data, scalar=1.0)
verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))

def test_forward_take():
def verify(shape, indices_src, axis, mode="clip"):
x_np = np.random.uniform(size=shape).astype("float32")
indices_np = np.array(indices_src, dtype="float32")
ref_res = mx.nd.take(mx.nd.array(x_np), mx.nd.array(indices_np), axis, mode)
mx_sym = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_np, indices_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((2,2), [[[1,0],[0,1]]], 0)
verify((2,2), [[[1,0],[0,1]]], 1)
verify((4,3,5,6), [[2,1,0,0]], -2)
verify((3,4), [-1, 5], 0)
verify((3,4), [-1, 5], 0, mode="wrap")
verify((3,4), [-1, 5], 1)
verify((3,4), [-1, 5], 1, mode="wrap")

if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
Expand Down Expand Up @@ -507,3 +526,4 @@ def test_forward_smooth_l1():
test_forward_full()
test_forward_embedding()
test_forward_smooth_l1()
test_forward_take()
12 changes: 9 additions & 3 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,17 +243,17 @@ def verify_take(dshape, indices_shape, oshape, axis=None):
verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)

def test_take():
def verify_take(src_shape, indices_src, axis=None):
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
src_dtype = "float32"
indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype)
x = relay.var("x", relay.TensorType(src_shape, src_dtype))
indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype))
z = relay.take(x, indices, axis=axis)
z = relay.take(x, indices, axis=axis, mode=mode)

func = relay.Function([x, indices], z)
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
ref_res = np.take(x_data, indices=indices_src, axis=axis)
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
Expand All @@ -269,6 +269,12 @@ def verify_take(src_shape, indices_src, axis=None):
verify_take((2,2), [[[1,0],[0,1]]], 0)
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)
verify_take((3,4), [-5, 20])
verify_take((3,4), [-5, 20], mode="wrap")
verify_take((3,4), [-1, 2], axis=0)
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")


def test_split_infer_type():
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_arith_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def test_simplify_mod():
stmt = tvm.ir_pass.CanonicalSimplify(body)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16)
assert diff.value == 0
# if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16)
assert index != j
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)})
assert index == j
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6)})
Expand Down
54 changes: 43 additions & 11 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,22 +604,29 @@ inline Array<Tensor> split_sections(const Tensor& x,
*/
inline Tensor take(const Tensor& a,
const Tensor& indices,
std::string mode = "clip",
std::string name = "tensor",
std::string tag = kInjective) {
Array<Expr> a_shape = a->shape;
Array<Expr> out_shape;
for (size_t j = 0; j < indices->shape.size(); ++j) {
out_shape.push_back(indices->shape[j]);
Array<Expr> out_shape = indices->shape;
Expr a_size = 1;
for (size_t i = 0; i < a_shape.size(); ++i) {
a_size = a_size * a_shape[i];
}

return compute(
if (mode == "clip") {
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = 0; j < indices->shape.size(); ++j) {
indices_position.push_back(out_index[j]);
}
return a(UnravelIndex(indices(indices_position), a_shape));
auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
return a(UnravelIndex(idx, a_shape));
}, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
auto idx = (indices(out_index) % a_size + a_size) % a_size;
return a(UnravelIndex(idx, a_shape));
}, name, tag);
}
}

/*!
Expand All @@ -637,12 +644,15 @@ inline Tensor take(const Tensor& a,
inline Tensor take(const Tensor& a,
const Tensor& indices,
int axis,
std::string mode = "clip",
std::string name = "tensor",
std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(a->shape.size());
}
CHECK_GE(axis, 0) << "axis out of bounds";
CHECK_LT(axis, a->shape.size()) << "axis out of bounds";
auto axis_dim = a->shape[axis];

int indices_len = static_cast<int>(indices->shape.size());
Array<Expr> out_shape;
Expand All @@ -655,7 +665,27 @@ inline Tensor take(const Tensor& a,
out_shape.push_back(a->shape[i]);
}
}
return compute(
if (mode == "clip") {
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<Expr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = tvm::min(tvm::max(0, indices(indices_position)),
axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
}, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
Expand All @@ -665,12 +695,14 @@ inline Tensor take(const Tensor& a,
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
real_indices.push_back(indices(indices_position));
auto idx = (indices(indices_position) % axis_dim + axis_dim) % axis_dim;
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
}, name, tag);
}
}

/*!
Expand Down
11 changes: 8 additions & 3 deletions topi/python/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def split(ary, indices_or_sections, axis=0):
return cpp.split(ary, indices_or_sections, axis)


def take(a, indices, axis=None):
def take(a, indices, axis=None, mode="clip"):
"""Take elements from an array along an axis.
Parameters
Expand All @@ -243,13 +243,18 @@ def take(a, indices, axis=None):
The axis over which to select values. By default,
the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
Returns
-------
ret : tvm.Tensor
"""
if axis is None:
return cpp.take(a, indices)
return cpp.take(a, indices, int(axis))
return cpp.take(a, indices, mode)
return cpp.take(a, indices, int(axis), mode)


def gather_nd(a, indices):
Expand Down
8 changes: 5 additions & 3 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,13 @@ TVM_REGISTER_GLOBAL("topi.layout_transform")

TVM_REGISTER_GLOBAL("topi.take")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 2) {
*rv = take(args[0], args[1]);
if (args.size() == 3) {
std::string mode = args[2];
*rv = take(args[0], args[1], mode);
} else {
int axis = args[2];
*rv = take(args[0], args[1], axis);
std::string mode = args[3];
*rv = take(args[0], args[1], axis, mode);
}
});

Expand Down
16 changes: 11 additions & 5 deletions topi/tests/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,16 @@ def check_device(device):
for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(device)

def verify_take(src_shape, indices_src, axis=None):
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
src_dtype = "float32"
indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype)
A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
if axis is None:
out_tensor = topi.take(a=A, indices=indices)
out_tensor = topi.take(a=A, indices=indices, mode=mode)
else:
out_tensor = topi.take(a=A, indices=indices, axis=axis)
out_tensor = topi.take(a=A, indices=indices, axis=axis, mode=mode)

def check_device(device):
ctx = tvm.context(device, 0)
Expand All @@ -259,9 +259,9 @@ def check_device(device):
data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))

if axis is None:
out_npys = np.take(data_npy, indices_src)
out_npys = np.take(data_npy, indices_src, mode=mode)
else:
out_npys = np.take(data_npy, indices_src, axis=axis)
out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
data_nd = tvm.nd.array(data_npy, ctx)
indices_nd = tvm.nd.array(indices_src, ctx)
out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
Expand Down Expand Up @@ -498,6 +498,12 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 0)
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)
verify_take((3,4), [-5, 20])
verify_take((3,4), [-5, 20], mode="wrap")
verify_take((3,4), [-1, 2], axis=0)
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")

def test_gather_nd():
for indices_dtype in ['int32', 'float32']:
Expand Down

0 comments on commit 3746d90

Please sign in to comment.