Skip to content

Commit

Permalink
[Topi] Fast mode in take op (#3325)
Browse files Browse the repository at this point in the history
  • Loading branch information
hlu1 authored and icemelon committed Jun 11, 2019
1 parent d4ca627 commit 2c41fd2
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 5 deletions.
3 changes: 2 additions & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
TVM_ATTR_FIELD(mode).set_default("clip")
.describe("Specify how out-of-bound indices will behave."
"clip - clip to the range (default)"
"wrap - wrap around the indices");
"wrap - wrap around the indices"
"fast - no clip or wrap around (user must make sure indices are in-bound)");
}
};

Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,10 @@ def take(data, indices, axis=None, mode="clip"):
the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave [clip, wrap].
Specifies how out-of-bound indices will behave [clip, wrap, fast].
clip: clip to the range (default).
wrap: wrap around the indices.
fast: no clip or wrap around (user must make sure indices are in-bound).
Returns
-------
Expand Down
6 changes: 5 additions & 1 deletion tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"):

func = relay.Function([x, indices], z)
x_data = np.random.uniform(low=-1, high=1, size=src_shape).astype(src_dtype)
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=mode)
np_mode = "raise" if mode == "fast" else mode
ref_res = np.take(x_data, indices=indices_src, axis=axis, mode=np_mode)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
Expand All @@ -291,6 +292,9 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"):
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
verify_take((3,3,3), [[11,25]], mode="fast")
verify_take((3,4), [0, 2], axis=0, mode="fast")
verify_take((3,4), [0, 2], axis=1, mode="fast")


def test_split_infer_type():
Expand Down
26 changes: 26 additions & 0 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,13 @@ inline Tensor take(const Tensor& a,
auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1);
return a(UnravelIndex(idx, a_shape));
}, name, tag);
} else if (mode == "fast") {
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
"Make sure input indices are in bound";
return compute(
out_shape, [&](const Array<Var>& out_index) {
return a(UnravelIndex(indices(out_index), a_shape));
}, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
Expand Down Expand Up @@ -706,6 +713,25 @@ inline Tensor take(const Tensor& a,
}
return a(real_indices);
}, name, tag);
} else if (mode == "fast") {
LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. "
"Make sure input indices are in bound";
return compute(
out_shape, [&](const Array<Var>& out_index) {
Array<Expr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
Array<Expr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
real_indices.push_back(indices(indices_position));
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
}, name, tag);
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def take(a, indices, axis=None, mode="clip"):
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
fast - no clip or wrap around (user must make sure indices are in-bound)
Returns
-------
Expand Down
9 changes: 7 additions & 2 deletions topi/tests/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,11 @@ def check_device(device):
data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))

if axis is None:
out_npys = np.take(data_npy, indices_src, mode=mode)
np_mode = "raise" if mode == "fast" else mode
out_npys = np.take(data_npy, indices_src, mode=np_mode)
else:
out_npys = np.take(data_npy, indices_src, axis=axis, mode=mode)
np_mode = "raise" if mode == "fast" else mode
out_npys = np.take(data_npy, indices_src, axis=axis, mode=np_mode)
data_nd = tvm.nd.array(data_npy, ctx)
indices_nd = tvm.nd.array(indices_src, ctx)
out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
Expand Down Expand Up @@ -521,6 +523,9 @@ def test_take():
verify_take((3,4), [-1, 2], axis=0, mode="wrap")
verify_take((3,4), [-1, 2], axis=1)
verify_take((3,4), [-1, 2], axis=1, mode="wrap")
verify_take((3,3,3), [[11,25]], mode="fast")
verify_take((3,4), [0, 2], axis=0, mode="fast")
verify_take((3,4), [0, 2], axis=1, mode="fast")

def test_gather_nd():
for indices_dtype in ['int32', 'float32']:
Expand Down

0 comments on commit 2c41fd2

Please sign in to comment.