-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOPI]Add where operator #1416
[TOPI]Add where operator #1416
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1125,8 +1125,8 @@ Examples:: | |
DMLC_REGISTER_PARAMETER(SliceLikeParam); | ||
|
||
inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, | ||
std::vector<TShape>* in_attrs, | ||
std::vector<TShape>* out_attrs) { | ||
std::vector<TShape>* in_attrs, | ||
std::vector<TShape>* out_attrs) { | ||
CHECK_EQ(in_attrs->size(), 2U); | ||
CHECK_EQ(out_attrs->size(), 1U); | ||
const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed); | ||
|
@@ -1221,5 +1221,86 @@ NNVM_REGISTER_OP(slice_like) | |
}) | ||
.set_support_level(4); | ||
|
||
// where | ||
inline bool WhereShape(const nnvm::NodeAttrs& attrs, | ||
std::vector<TShape>* in_attrs, | ||
std::vector<TShape>* out_attrs) { | ||
CHECK_EQ(in_attrs->size(), 3U); | ||
CHECK_EQ(out_attrs->size(), 1U); | ||
const TShape& cond_shape = in_attrs->at(0); | ||
const TShape& x_shape = in_attrs->at(1); | ||
const TShape& y_shape = in_attrs->at(2); | ||
CHECK_EQ(x_shape, y_shape) << "x and y must have the same shape: " | ||
<< x_shape << " vs " << y_shape; | ||
if (cond_shape != x_shape) { | ||
CHECK_EQ(cond_shape.ndim(), 1) | ||
<< "Shape of condition " << cond_shape | ||
<< " must be either equal to x or has dimension of 1."; | ||
} | ||
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, x_shape); | ||
return true; | ||
} | ||
|
||
inline bool WhereInferType(const NodeAttrs &attrs, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(1)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also check for dtype of X, Y to be same. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added. |
||
return true; | ||
} | ||
|
||
inline bool WhereCorrectLayout(const NodeAttrs& attrs, | ||
std::vector<Layout> *ilayouts, | ||
const std::vector<Layout> *last_ilayouts, | ||
std::vector<Layout> *olayouts) { | ||
CHECK_EQ(ilayouts->size(), last_ilayouts->size()); | ||
CHECK_EQ(olayouts->size(), 1U); | ||
|
||
for (size_t i = 0; i < ilayouts->size(); ++i) { | ||
const Layout& input = last_ilayouts->at(i).defined() ? | ||
last_ilayouts->at(i) : ilayouts->at(i); | ||
NNVM_ASSIGN_LAYOUT(*ilayouts, i, input); | ||
} | ||
|
||
return true; | ||
} | ||
|
||
NNVM_REGISTER_OP(where) | ||
.describe(R"code( | ||
Return the elements, either from x or y, depending on the condition. | ||
|
||
Given three ndarrays, condition, x, and y, return an ndarray with the elements | ||
from x or y, depending on the elements from condition are true or false. | ||
x and y must have the same shape. If condition has the same shape as x, | ||
each element in the output array is from x if the corresponding element | ||
in the condition is true, and from y if false. | ||
|
||
If condition does not have the same shape as x, it must be a 1D array whose | ||
size is the same as x’s first dimension size. Each row of the output array | ||
is from x’s row if the corresponding element from condition is true, and | ||
from y’s row if false. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggest to add some basic example. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added. |
||
Note that all non-zero values are interpreted as True in condition. | ||
)code" NNVM_ADD_FILELINE) | ||
.add_argument("condition", "Tensor", "Condition array") | ||
.add_argument("x", "Tensor", "First array to be selected") | ||
.add_argument("y", "Tensor", "Second array to be selected") | ||
.set_num_inputs(3) | ||
.set_num_outputs(1) | ||
.set_attr<FInferShape>("FInferShape", WhereShape) | ||
.set_attr<FInferType>("FInferType", WhereInferType) | ||
.set_attr<FCorrectLayout>("FCorrectLayout", WhereCorrectLayout) | ||
.set_attr<FTVMCompute>( | ||
"FTVMCompute", [](const NodeAttrs& attrs, | ||
const Array<Tensor>& inputs, | ||
const Array<Tensor>& out_info) { | ||
return Array<Tensor>{ | ||
topi::where(inputs[0], inputs[1], inputs[2]) | ||
}; | ||
}) | ||
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { | ||
return std::vector<std::string>{"condition", "x", "y"}; | ||
}) | ||
.set_support_level(4); | ||
|
||
} // namespace top | ||
} // namespace nnvm |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -575,5 +575,50 @@ inline Tensor take(const Tensor& a, | |
}, name, tag); | ||
} | ||
|
||
/*! | ||
* \brief Return the elements, either from x or y, depending on the condition. | ||
* | ||
* \param condition The condition array. | ||
* \param x First array to be selected. | ||
* \param y Second array to be selected. | ||
* \param name The name of the operation. | ||
* \param tag The tag to mark the operation. | ||
* | ||
* \return A Tensor selected from x or y depending on condition. | ||
*/ | ||
inline Tensor where(const Tensor& condition, | ||
const Tensor& x, | ||
const Tensor& y, | ||
std::string name = "tensor", | ||
std::string tag = kInjective) { | ||
CHECK_EQ(x->shape.size(), y->shape.size()) | ||
<< "x and y must have the same shape.Got different number of dimension: " | ||
<< x->shape.size() << " vs " << y->shape.size(); | ||
Array<Expr> oshape = x->shape; | ||
Tensor out; | ||
|
||
if (condition->shape.size() != 1) { | ||
CHECK_EQ(condition->shape.size(), x->shape.size()) | ||
<< "condition and x must have the same shape.Got " | ||
"different number of dimension: " | ||
<< condition->shape.size() << " vs " << x->shape.size(); | ||
out = compute( | ||
oshape, [&](const Array<Var>& indices) { | ||
return tvm::select(condition(indices) != 0, x(indices), y(indices)); | ||
}, name, tag); | ||
} else { | ||
CHECK_EQ(condition->shape.size(), 1) << "condition array must be either " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this check ever fail as you are checking the same condition in if statement? |
||
"have the same shape as x or to be a 1-D array."; | ||
out = compute( | ||
oshape, [&](const Array<Var>& indices) { | ||
Array<Expr> condition_idx{indices[0]}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is condition index always from the first dimension? Is it 'where' operator's property or user can specify the axis? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are two cases for condition shape. 1. the same as x. 2. to be 1-D. This is for 1-D case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In second case of 1-D condition array, 'condition_idx{indices[0]}' always take first entry of indices array. It can be done 'condition_idx{indices[axis]}' where user can specify the axis along which output is taken from x or y (example: in 2D x and y, choose complete column of x or row of x if condition array has true entry. Current implementation always takes row of x if condition is true.). But never mind, i checked TF and MxNet implementation, both are using first dimension. This looks okay. |
||
return tvm::select(condition(condition_idx) != 0, | ||
x(indices), y(indices)); | ||
}, name, tag); | ||
} | ||
return out; | ||
} | ||
|
||
|
||
} // namespace topi | ||
#endif // TOPI_TRANSFORM_H_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about proper broadcasting to allow arbitrary shape for condition?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Current implementation aligns with numpy to allow condition shape to be either same as x or 1-D. We can add more broadcasting later if necessary.