-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOPI]Add where operator #1416
[TOPI]Add where operator #1416
Conversation
@kevinthesun please request reviews from reviewers |
const TShape& y_shape = in_attrs->at(2); | ||
CHECK_EQ(x_shape, y_shape) << "x and y must have the same shape: " | ||
<< x_shape << " vs " << y_shape; | ||
if (cond_shape != x_shape) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about proper broadcasting to allow arbitrary shape for condition?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Current implementation aligns with numpy to allow condition shape to be either same as x or 1-D. We can add more broadcasting later if necessary.
inline bool WhereInferType(const NodeAttrs &attrs, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(1)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also check for dtype of X, Y to be same.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added.
size is the same as x’s first dimension size. Each row of the output array | ||
is from x’s row if the corresponding element from condition is true, and | ||
from y’s row if false. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggest to add some basic example.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added.
topi/include/topi/transform.h
Outdated
return tvm::select(condition(indices) != 0, x(indices), y(indices)); | ||
}, name, tag); | ||
} else { | ||
CHECK_EQ(condition->shape.size(), 1) << "condition array must be either " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this check ever fail as you are checking the same condition in if statement?
topi/include/topi/transform.h
Outdated
"have the same shape as x or to be a 1-D array."; | ||
out = compute( | ||
oshape, [&](const Array<Var>& indices) { | ||
Array<Expr> condition_idx{indices[0]}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is condition index always from the first dimension? Is it 'where' operator's property or user can specify the axis?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are two cases for condition shape. 1. the same as x. 2. to be 1-D. This is for 1-D case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In second case of 1-D condition array, 'condition_idx{indices[0]}' always take first entry of indices array. It can be done 'condition_idx{indices[axis]}' where user can specify the axis along which output is taken from x or y (example: in 2D x and y, choose complete column of x or row of x if condition array has true entry. Current implementation always takes row of x if condition is true.).
But never mind, i checked TF and MxNet implementation, both are using first dimension. This looks okay.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, LGTM.
@kevinthesun In 1-D case of condition array, what will happen if the size of condition array does not match with size of first dimension of x array? If it is expected to throw error, you can add a check for it. |
@PariksheetPinjari909 Added. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. LGTM.
Thanks @PariksheetPinjari909 @srkreddy1238 for the reviews, thanks @kevinthesun for the contribution. this is now merged! |
where operator to support gluoncv SSD model. #1269