[Relay/TOPI][OP] Add clip and wrap mode support in take #2858

Merged · 6 commits · Apr 1, 2019
2 changes: 1 addition & 1 deletion 3rdparty/HalideIR
5 changes: 5 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -75,10 +75,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer axis;
std::string mode;

TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
.describe("The axis over which to select values.");
TVM_ATTR_FIELD(mode).set_default("clip")
.describe("Specify how out-of-bound indices will behave."
"clip - clip to the range (default)"
"wrap - wrap around the indices");
}
};

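For reference, the two modes follow numpy's take semantics, which the tests in this PR use as ground truth. A minimal illustration (the values are made up for the example):

import numpy as np

x = np.array([10, 20, 30, 40])
print(np.take(x, [-1, 5], mode="clip"))  # [10 40]: indices clamped into [0, 3]
print(np.take(x, [-1, 5], mode="wrap"))  # [40 20]: indices reduced modulo 4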
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -444,6 +444,15 @@ def _mx_tile(inputs, attrs):
return _op.tile(inputs[0], **new_attrs)


def _mx_take(inputs, attrs):
assert len(inputs) == 2
mode = attrs.get_str("mode", "clip")
if mode == "raise":
raise RuntimeError("take doesn't support raise mode")
Inline review comment (Member): enable the support?

axis = attrs.get_int("axis", 0)
return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)


def _mx_reverse(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
@@ -749,6 +758,7 @@ def _mx_deformable_convolution(inputs, attrs):
"_full" : _mx_full,
"repeat" : _mx_repeat,
"tile" : _mx_tile,
"take" : _mx_take,
"reverse" : _mx_reverse,
"squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis,
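A minimal conversion sketch of what the new _mx_take converter handles (shapes and variable names are illustrative); note that MXNet supplies indices as floats, which the converter casts to int32:

import mxnet as mx
from tvm import relay

x = mx.sym.var("x")
y = mx.sym.var("y")
sym = mx.sym.take(x, y, axis=0, mode="wrap")
# mode="raise" would be rejected by the converter, as shown above
func, _ = relay.frontend.from_mxnet(sym, {"x": (3, 4), "y": (2,)})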
9 changes: 7 additions & 2 deletions python/tvm/relay/op/transform.py
@@ -186,7 +186,7 @@ def reshape_like(data, shape_like):
return _make.reshape_like(data, shape_like)


def take(data, indices, axis=None):
def take(data, indices, axis=None, mode="clip"):
"""Take elements from an array along an axis.

Parameters
@@ -201,12 +201,17 @@ def take(data, indices, axis=None):
The axis over which to select values. By default,
the flattened input array is used.

mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices

Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.take(data, indices, axis)
return _make.take(data, indices, axis, mode)


def full(fill_value, shape=(), dtype=""):
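A short usage sketch of the extended Relay API (variable names are illustrative):

from tvm import relay

x = relay.var("x", shape=(3, 4), dtype="float32")
idx = relay.var("idx", shape=(2,), dtype="int32")
# out-of-bound entries are clamped under mode="clip" (the default)
# and reduced modulo the axis length under mode="wrap"
z = relay.take(x, idx, axis=1, mode="wrap")
func = relay.Function([x, idx], z)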
10 changes: 6 additions & 4 deletions src/relay/op/tensor/transform.cc
@@ -753,24 +753,26 @@ Array<Tensor> TakeCompute(const Attrs& attrs,
const auto* param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);
if (!param->axis.defined()) {
return Array<Tensor>{ topi::take(inputs[0], inputs[1]) };
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->mode) };
} else {
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis) };
return Array<Tensor>{ topi::take(inputs[0], inputs[1], param->axis, param->mode) };
}
}

Expr MakeTake(Expr data,
Expr indices,
Integer axis) {
Integer axis,
std::string mode) {
auto attrs = make_node<TakeAttrs>();
attrs->axis = std::move(axis);
attrs->mode = std::move(mode);
static const Op& op = Op::Get("take");
return CallNode::make(op, {data, indices}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.take")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeTake, args, rv);
runtime::detail::unpack_call<Expr, 4>(MakeTake, args, rv);
});

RELAY_REGISTER_OP("take")
22 changes: 21 additions & 1 deletion tests/python/frontend/mxnet/test_forward.py
@@ -464,14 +464,33 @@ def verify(data_shape, weight_shape):
verify((2, 2), (4, 5))
verify((2, 3, 4), (4, 5))


def test_forward_smooth_l1():
data = mx.sym.var('data')
mx_sym = mx.sym.smooth_l1(data)
verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
mx_sym = mx.sym.smooth_l1(data, scalar=1.0)
verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))

def test_forward_take():
def verify(shape, indices_src, axis, mode="clip"):
x_np = np.random.uniform(size=shape).astype("float32")
indices_np = np.array(indices_src, dtype="float32")
ref_res = mx.nd.take(mx.nd.array(x_np), mx.nd.array(indices_np), axis, mode)
mx_sym = mx.sym.take(mx.sym.var("x"), mx.sym.var("y"), axis, mode)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_np, indices_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((2,2), [[[1,0],[0,1]]], 0)
verify((2,2), [[[1,0],[0,1]]], 1)
verify((4,3,5,6), [[2,1,0,0]], -2)
verify((3,4), [-1, 5], 0)
verify((3,4), [-1, 5], 0, mode="wrap")
verify((3,4), [-1, 5], 1)
verify((3,4), [-1, 5], 1, mode="wrap")

if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
@@ -507,3 +526,4 @@ def test_forward_smooth_l1():
test_forward_full()
test_forward_embedding()
test_forward_smooth_l1()
test_forward_take()
12 changes: 9 additions & 3 deletions tests/python/relay/test_op_level3.py
@@ -243,17 +243,17 @@ def verify_take(dshape, indices_shape, oshape, axis=None):
verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)

def test_take():
def verify_take(src_shape, indices_src, axis=None):
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
src_dtype = "float32"
indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype)
x = relay.var("x", relay.TensorType(src_shape, src_dtype))
indices = relay.var("indices", relay.TensorType(indices_src.shape, indices_dtype))
z = relay.take(x, indices, axis=axis)
z = relay.take(x, indices, axis=axis, mode=mode)

func = relay.Function([x, indices], z)
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
ref_res = np.take(x_data, indices=indices_src, axis=axis)
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
@@ -269,6 +269,12 @@ def verify_take(src_shape, indices_src, axis=None):
verify_take((2,2), [[[1,0],[0,1]]], 0)
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)
verify_take((3,4), [-5, 20])
verify_take((3,4), [-5, 20], mode="wrap")
verify_take((3,4), [-1, 2], axis=0)
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")


def test_split_infer_type():
5 changes: 5 additions & 0 deletions tests/python/unittest/test_arith_simplify.py
@@ -39,6 +39,11 @@ def test_simplify_mod():
stmt = tvm.ir_pass.CanonicalSimplify(body)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16)
assert diff.value == 0
# if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16)
assert index != j
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)})
assert index == j
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6)})
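Why the range constraint matters: TVM's integer % follows C's truncated-division semantics, so rewriting (j + 16) % 16 to j % 16 is only sound once j is known to be non-negative. A small standalone illustration, with c_mod as a stand-in for those semantics:

def c_mod(a, b):
    # C-style (truncated) modulo: the result takes the sign of a
    return a - int(a / b) * b

j = -5
print(c_mod(j + 16, 16))  # 11
print(c_mod(j, 16))       # -5 -> the two differ, so no simplification
# with j constrained to [0, 6), e.g. j = 3, both sides agree:
print(c_mod(3 + 16, 16), c_mod(3, 16))  # 3 3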
54 changes: 43 additions & 11 deletions topi/include/topi/transform.h
@@ -604,22 +604,29 @@ inline Array<Tensor> split_sections(const Tensor& x,
*/
inline Tensor take(const Tensor& a,
const Tensor& indices,
std::string mode = "clip",
std::string name = "tensor",
std::string tag = kInjective) {
Array<Expr> a_shape = a->shape;
Array<Expr> out_shape;
for (size_t j = 0; j < indices->shape.size(); ++j) {
out_shape.push_back(indices->shape[j]);
Array<Expr> out_shape = indices->shape;
Expr a_size = 1;
for (size_t i = 0; i < a_shape.size(); ++i) {
a_size = a_size * a_shape[i];
}

return compute(
if (mode == "clip") {
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = 0; j < indices->shape.size(); ++j) {
indices_position.push_back(out_index[j]);
}
return a(UnravelIndex(indices(indices_position), a_shape));
auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
return a(UnravelIndex(idx, a_shape));
}, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
auto idx = (indices(out_index) % a_size + a_size) % a_size;
return a(UnravelIndex(idx, a_shape));
}, name, tag);
}
}

/*!
Expand All @@ -637,12 +644,15 @@ inline Tensor take(const Tensor& a,
inline Tensor take(const Tensor& a,
const Tensor& indices,
int axis,
std::string mode = "clip",
std::string name = "tensor",
std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(a->shape.size());
}
CHECK_GE(axis, 0) << "axis out of bounds";
CHECK_LT(axis, a->shape.size()) << "axis out of bounds";
auto axis_dim = a->shape[axis];

int indices_len = static_cast<int>(indices->shape.size());
Array<Expr> out_shape;
@@ -655,7 +665,27 @@ inline Tensor take(const Tensor& a,
out_shape.push_back(a->shape[i]);
}
}
return compute(
if (mode == "clip") {
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<Expr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = tvm::min(tvm::max(0, indices(indices_position)),
axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
}, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
@@ -665,12 +695,14 @@ inline Tensor take(const Tensor& a,
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
real_indices.push_back(indices(indices_position));
auto idx = (indices(indices_position) % axis_dim + axis_dim) % axis_dim;
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
}, name, tag);
}
}

/*!
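The wrap branch computes ((i % n) + n) % n rather than a plain i % n because the % in the generated code is a truncated modulo, which stays negative for negative i; adding n before the second reduction folds the result into [0, n). A quick check under that assumption:

def c_mod(a, b):
    # truncated modulo, matching the % semantics assumed above
    return a - int(a / b) * b

def wrap(i, n):
    return c_mod(c_mod(i, n) + n, n)

assert c_mod(-1, 4) == -1  # plain truncated mod keeps the sign of i
assert wrap(-1, 4) == 3    # double-mod lands in [0, 4)
assert wrap(5, 4) == 1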
11 changes: 8 additions & 3 deletions topi/python/topi/transform.py
@@ -228,7 +228,7 @@ def split(ary, indices_or_sections, axis=0):
return cpp.split(ary, indices_or_sections, axis)


def take(a, indices, axis=None):
def take(a, indices, axis=None, mode="clip"):
"""Take elements from an array along an axis.
Parameters
@@ -243,13 +243,18 @@ def take(a, indices, axis=None):
The axis over which to select values. By default,
the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
Returns
-------
ret : tvm.Tensor
"""
if axis is None:
return cpp.take(a, indices)
return cpp.take(a, indices, int(axis))
return cpp.take(a, indices, mode)
return cpp.take(a, indices, int(axis), mode)


def gather_nd(a, indices):
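A quick sketch of the TOPI-level entry point under the same setup as the tests below; with axis=None the input is flattened before indexing:

import tvm
import topi

A = tvm.placeholder((3, 4), name="A", dtype="float32")
idx = tvm.placeholder((2,), name="idx", dtype="int32")
out_flat = topi.take(A, idx, mode="wrap")          # take over all 12 flattened elements
out_axis = topi.take(A, idx, axis=1, mode="clip")  # take along axis 1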
8 changes: 5 additions & 3 deletions topi/src/topi.cc
@@ -297,11 +297,13 @@ TVM_REGISTER_GLOBAL("topi.layout_transform")

TVM_REGISTER_GLOBAL("topi.take")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 2) {
*rv = take(args[0], args[1]);
if (args.size() == 3) {
std::string mode = args[2];
*rv = take(args[0], args[1], mode);
} else {
int axis = args[2];
*rv = take(args[0], args[1], axis);
std::string mode = args[3];
*rv = take(args[0], args[1], axis, mode);
}
});

16 changes: 11 additions & 5 deletions topi/tests/python/test_topi_transform.py
@@ -232,16 +232,16 @@ def check_device(device):
for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(device)

def verify_take(src_shape, indices_src, axis=None):
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
src_dtype = "float32"
indices_dtype = "int32"
indices_src = np.array(indices_src, dtype=indices_dtype)
A = tvm.placeholder(shape=src_shape, dtype=src_dtype, name="A")
indices = tvm.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
if axis is None:
out_tensor = topi.take(a=A, indices=indices)
out_tensor = topi.take(a=A, indices=indices, mode=mode)
else:
out_tensor = topi.take(a=A, indices=indices, axis=axis)
out_tensor = topi.take(a=A, indices=indices, axis=axis, mode=mode)

def check_device(device):
ctx = tvm.context(device, 0)
@@ -259,9 +259,9 @@ def check_device(device):
data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))

if axis is None:
out_npys = np.take(data_npy, indices_src)
out_npys = np.take(data_npy, indices_src, mode=mode)
else:
out_npys = np.take(data_npy, indices_src, axis=axis)
out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
data_nd = tvm.nd.array(data_npy, ctx)
indices_nd = tvm.nd.array(indices_src, ctx)
out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
@@ -498,6 +498,12 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 0)
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)
verify_take((3,4), [-5, 20])
verify_take((3,4), [-5, 20], mode="wrap")
verify_take((3,4), [-1, 2], axis=0)
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")

def test_gather_nd():
for indices_dtype in ['int32', 'float32']: