Skip to content

Commit

Permalink
Add where operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Wang committed Jul 10, 2018
1 parent afd2b9b commit 473f603
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 2 deletions.
4 changes: 4 additions & 0 deletions nnvm/python/nnvm/top/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,7 @@ def compute_reshape_like(attrs, inputs, out_info):
# slice_like
reg.register_pattern("slice_like", OpPattern.INJECTIVE)
reg.register_schedule("slice_like", _fschedule_injective)

# where
reg.register_pattern("where", OpPattern.INJECTIVE)
reg.register_schedule("where", _fschedule_injective)
85 changes: 83 additions & 2 deletions nnvm/src/top/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1125,8 +1125,8 @@ Examples::
DMLC_REGISTER_PARAMETER(SliceLikeParam);

inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed);
Expand Down Expand Up @@ -1221,5 +1221,86 @@ NNVM_REGISTER_OP(slice_like)
})
.set_support_level(4);

// where
inline bool WhereShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
const TShape& cond_shape = in_attrs->at(0);
const TShape& x_shape = in_attrs->at(1);
const TShape& y_shape = in_attrs->at(2);
CHECK_EQ(x_shape, y_shape) << "x and y must have the same shape: "
<< x_shape << " vs " << y_shape;
if (cond_shape != x_shape) {
CHECK_EQ(cond_shape.ndim(), 1)
<< "Shape of condition " << cond_shape
<< " must be either equal to x or has dimension of 1.";
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, x_shape);
return true;
}

inline bool WhereInferType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(1));
return true;
}

inline bool WhereCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), last_ilayouts->size());
CHECK_EQ(olayouts->size(), 1U);

for (size_t i = 0; i < ilayouts->size(); ++i) {
const Layout& input = last_ilayouts->at(i).defined() ?
last_ilayouts->at(i) : ilayouts->at(i);
NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
}

return true;
}

NNVM_REGISTER_OP(where)
.describe(R"code(
Return the elements, either from x or y, depending on the condition.
Given three ndarrays, condition, x, and y, return an ndarray with the elements
from x or y, depending on the elements from condition are true or false.
x and y must have the same shape. If condition has the same shape as x,
each element in the output array is from x if the corresponding element
in the condition is true, and from y if false.
If condition does not have the same shape as x, it must be a 1D array whose
size is the same as x’s first dimension size. Each row of the output array
is from x’s row if the corresponding element from condition is true, and
from y’s row if false.
Note that all non-zero values are interpreted as True in condition.
)code" NNVM_ADD_FILELINE)
.add_argument("condition", "Tensor", "Condition array")
.add_argument("x", "Tensor", "First array to be selected")
.add_argument("y", "Tensor", "Second array to be selected")
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", WhereShape)
.set_attr<FInferType>("FInferType", WhereInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", WhereCorrectLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{
topi::where(inputs[0], inputs[1], inputs[2])
};
})
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"condition", "x", "y"};
})
.set_support_level(4);

} // namespace top
} // namespace nnvm
31 changes: 31 additions & 0 deletions nnvm/tests/python/compiler/test_top_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,36 @@ def test_slice_like():
axis = (2, 3)
verify_slice_like(np_data, np_shape_like, axis)

def verify_where(condition, x, y):
dtype = "float32"
if len(condition.shape) == 1:
np_out = np.array([xv if c else yv for (c,xv,yv) in zip(condition,x,y)])
else:
np_out = np.where(condition, x, y)
cond_var = sym.Variable("condition")
x_var = sym.Variable("x")
y_var = sym.Variable("y")
net = sym.where(cond_var, x_var, y_var)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(net, target, {"condition": condition.shape,
"x": x.shape, "y": y.shape})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"condition": condition, "x": x, "y": y})
m.run()
out = m.get_output(0, tvm.nd.empty(x.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)

def test_where():
shape = (13, 8, 224, 224, 6)
condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)


if __name__ == "__main__":
test_reshape()
Expand All @@ -665,4 +695,5 @@ def test_slice_like():
test_multibox_transform_loc()
test_nms()
test_slice_like()
test_where()
print(nnvm.compiler.engine.dump())
45 changes: 45 additions & 0 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -575,5 +575,50 @@ inline Tensor take(const Tensor& a,
}, name, tag);
}

/*!
* \brief Return the elements, either from x or y, depending on the condition.
*
* \param condition The condition array.
* \param x First array to be selected.
* \param y Second array to be selected.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor selected from x or y depending on condition.
*/
inline Tensor where(const Tensor& condition,
const Tensor& x,
const Tensor& y,
std::string name = "tensor",
std::string tag = kInjective) {
CHECK_EQ(x->shape.size(), y->shape.size())
<< "x and y must have the same shape.Got different number of dimension: "
<< x->shape.size() << " vs " << y->shape.size();
Array<Expr> oshape = x->shape;
Tensor out;

if (condition->shape.size() != 1) {
CHECK_EQ(condition->shape.size(), x->shape.size())
<< "condition and x must have the same shape.Got "
"different number of dimension: "
<< condition->shape.size() << " vs " << x->shape.size();
out = compute(
oshape, [&](const Array<Var>& indices) {
return tvm::select(condition(indices) != 0, x(indices), y(indices));
}, name, tag);
} else {
CHECK_EQ(condition->shape.size(), 1) << "condition array must be either "
"have the same shape as x or to be a 1-D array.";
out = compute(
oshape, [&](const Array<Var>& indices) {
Array<Expr> condition_idx{indices[0]};
return tvm::select(condition(condition_idx) != 0,
x(indices), y(indices));
}, name, tag);
}
return out;
}


} // namespace topi
#endif // TOPI_TRANSFORM_H_
5 changes: 5 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,11 @@ TVM_REGISTER_GLOBAL("topi.take")
}
});

TVM_REGISTER_GLOBAL("topi.where")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = where(args[0], args[1], args[2]);
});

TVM_REGISTER_GLOBAL("topi.strided_slice")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3]);
Expand Down
42 changes: 42 additions & 0 deletions topi/tests/python_cpp/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,35 @@ def check_device(device):
for device in ["llvm", "opencl"]:
check_device(device)

def verify_where(condition, x, y):
dtype = "float32"
if len(condition.shape) == 1:
np_out = np.array([xv if c else yv for (c,xv,yv) in zip(condition,x,y)])
else:
np_out = np.where(condition, x, y)
A = tvm.placeholder(shape=condition.shape, dtype=dtype, name="condition")
B = tvm.placeholder(shape=x.shape, dtype=dtype, name="x")
C = tvm.placeholder(shape=y.shape, dtype=dtype, name="y")
out_tensor = topi.cpp.where(A, B, C)

def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(out_tensor)

foo = tvm.build(s, [A, B, C, out_tensor], device, name="where")
tvm_out = tvm.nd.empty(x.shape, ctx=ctx, dtype=dtype)
foo(tvm.nd.array(condition), tvm.nd.array(x),
tvm.nd.array(y), tvm_out)
np.testing.assert_allclose(tvm_out.asnumpy(), np_out)

for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device(device)

def verify_concatenate_split(shapes, axis, indices_or_sections):
tensor_l_concatenate = []
for i, shape in enumerate(shapes):
Expand Down Expand Up @@ -324,6 +353,18 @@ def test_take():
verify_take((2,2), [[[1,0],[0,1]]], 1)
verify_take((4,3,5,6), [[2,1,0,0]], -2)

def test_where():
shape = (10, 3, 7, 13)
condition = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)
condition = np.random.uniform(low=-1, high=1, size=(shape[0],)).astype("float32")
x = np.random.uniform(size=shape).astype("float32")
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)


def test_regression_1():
verify_concatenate_split([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1, [3, 7])
verify_concatenate_split([(3, 4), (2, 4), (3, 4)], 0, [1, 2, 3, 4])
Expand All @@ -340,5 +381,6 @@ def test_regression_2():
test_squeeze()
test_split()
test_take()
test_where()
test_regression_1()
test_regression_2()

0 comments on commit 473f603

Please sign in to comment.