-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOPI]Add where operator #1416
[TOPI]Add where operator #1416
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1125,8 +1125,8 @@ Examples:: | |
DMLC_REGISTER_PARAMETER(SliceLikeParam); | ||
|
||
inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, | ||
std::vector<TShape>* in_attrs, | ||
std::vector<TShape>* out_attrs) { | ||
std::vector<TShape>* in_attrs, | ||
std::vector<TShape>* out_attrs) { | ||
CHECK_EQ(in_attrs->size(), 2U); | ||
CHECK_EQ(out_attrs->size(), 1U); | ||
const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed); | ||
|
@@ -1221,5 +1221,98 @@ NNVM_REGISTER_OP(slice_like) | |
}) | ||
.set_support_level(4); | ||
|
||
// where | ||
inline bool WhereShape(const nnvm::NodeAttrs& attrs, | ||
std::vector<TShape>* in_attrs, | ||
std::vector<TShape>* out_attrs) { | ||
CHECK_EQ(in_attrs->size(), 3U); | ||
CHECK_EQ(out_attrs->size(), 1U); | ||
const TShape& cond_shape = in_attrs->at(0); | ||
const TShape& x_shape = in_attrs->at(1); | ||
const TShape& y_shape = in_attrs->at(2); | ||
CHECK_EQ(x_shape, y_shape) << "x and y must have the same shape: " | ||
<< x_shape << " vs " << y_shape; | ||
if (cond_shape != x_shape) { | ||
CHECK_EQ(cond_shape.ndim(), 1) | ||
<< "Shape of condition " << cond_shape | ||
<< " must be either equal to x or has dimension of 1."; | ||
} | ||
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, x_shape); | ||
return true; | ||
} | ||
|
||
inline bool WhereInferType(const NodeAttrs &attrs, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(1)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also check for dtype of X, Y to be same. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added. |
||
return true; | ||
} | ||
|
||
inline bool WhereCorrectLayout(const NodeAttrs& attrs, | ||
std::vector<Layout> *ilayouts, | ||
const std::vector<Layout> *last_ilayouts, | ||
std::vector<Layout> *olayouts) { | ||
CHECK_EQ(ilayouts->size(), last_ilayouts->size()); | ||
CHECK_EQ(olayouts->size(), 1U); | ||
|
||
for (size_t i = 0; i < ilayouts->size(); ++i) { | ||
const Layout& input = last_ilayouts->at(i).defined() ? | ||
last_ilayouts->at(i) : ilayouts->at(i); | ||
NNVM_ASSIGN_LAYOUT(*ilayouts, i, input); | ||
} | ||
|
||
return true; | ||
} | ||
|
||
NNVM_REGISTER_OP(where) | ||
.describe(R"code( | ||
Return the elements, either from x or y, depending on the condition. | ||
|
||
Given three ndarrays, condition, x, and y, return an ndarray with the elements | ||
from x or y, depending on the elements from condition are true or false. | ||
x and y must have the same shape. If condition has the same shape as x, | ||
each element in the output array is from x if the corresponding element | ||
in the condition is true, and from y if false. | ||
|
||
If condition does not have the same shape as x, it must be a 1D array whose | ||
size is the same as x’s first dimension size. Each row of the output array | ||
is from x’s row if the corresponding element from condition is true, and | ||
from y’s row if false. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggest to add some basic example. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added. |
||
Note that all non-zero values are interpreted as True in condition. | ||
|
||
Examples:: | ||
|
||
x = [[1, 2], [3, 4]] | ||
y = [[5, 6], [7, 8]] | ||
cond = [[0, 1], [-1, 0]] | ||
where(cond, x, y) = [[5, 2], [3, 8]] | ||
|
||
|
||
cond = [1, 0] | ||
where(cond, x, y) = [[1, 2], [7, 8]] | ||
|
||
)code" NNVM_ADD_FILELINE) | ||
.add_argument("condition", "Tensor", "Condition array") | ||
.add_argument("x", "Tensor", "First array to be selected") | ||
.add_argument("y", "Tensor", "Second array to be selected") | ||
.set_num_inputs(3) | ||
.set_num_outputs(1) | ||
.set_attr<FInferShape>("FInferShape", WhereShape) | ||
.set_attr<FInferType>("FInferType", WhereInferType) | ||
.set_attr<FCorrectLayout>("FCorrectLayout", WhereCorrectLayout) | ||
.set_attr<FTVMCompute>( | ||
"FTVMCompute", [](const NodeAttrs& attrs, | ||
const Array<Tensor>& inputs, | ||
const Array<Tensor>& out_info) { | ||
return Array<Tensor>{ | ||
topi::where(inputs[0], inputs[1], inputs[2]) | ||
}; | ||
}) | ||
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { | ||
return std::vector<std::string>{"condition", "x", "y"}; | ||
}) | ||
.set_support_level(4); | ||
|
||
} // namespace top | ||
} // namespace nnvm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about proper broadcasting to allow arbitrary shape for condition?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Current implementation aligns with numpy to allow condition shape to be either same as x or 1-D. We can add more broadcasting later if necessary.