[TOPI][RELAY] Add op Size #3094
Conversation
Also please fix the code style: 2-space indent for c++
src/relay/op/tensor/unary.cc
Outdated
CHECK(tt != nullptr);
const auto* param = attrs.as<SizeAttrs>();
CHECK(param != nullptr);
reporter->Assign(types[1], TensorTypeNode::make({1}, param->dtype));
looks like the type of size in numpy is scalar?
Yeah, both numpy.size and tf.size are scalar
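For reference, numpy's behavior is easy to check directly: np.size returns a plain Python int, i.e. a true scalar rather than a rank-1 tensor.

```python
import numpy as np

a = np.zeros((2, 3, 4))

# numpy's size is a plain Python int (a scalar), not an array
assert np.size(a) == 24
assert isinstance(np.size(a), int)
assert a.size == 24  # the ndarray.size attribute agrees
```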
but {1} here is a tensor? I think we need to keep it the same as np
Good catch, sure, I'll update the code.
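To illustrate the distinction being discussed, in numpy terms: a shape-{1} output corresponds to a rank-1 tensor with a single element, whereas numpy's size is a rank-0 scalar.

```python
import numpy as np

scalar = np.array(24)    # rank-0 (scalar): shape ()
vector = np.array([24])  # rank-1 tensor with one element: shape (1,)

assert scalar.shape == ()
assert vector.shape == (1,)
assert scalar.ndim == 0 and vector.ndim == 1
```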
Some nits
topi/include/topi/transform.h
Outdated
const std::string name = "size",
const std::string tag = kInjective) {
int ndim = static_cast<int>(src->shape.size());
Array<Expr> out_size;
use out_size{1} directly?
topi/include/topi/transform.h
Outdated
 * \return Tensor of input shape.
 */
inline Tensor size(const Tensor& src,
                   Type dtype,
const Type& ?
topi/src/topi.cc
Outdated
@@ -300,6 +300,11 @@ TVM_REGISTER_GLOBAL("topi.shape")
  *rv = shape(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.size")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = size(args[0], args[1]);
2 space indent
return
tvm_input = tvm.nd.array(input, ctx=ctx)
tvm_output = tvm.nd.empty((1,), ctx=ctx, dtype=B.dtype)
print("Running on target: %s" % device)
remove or log?
Simply using print here to keep it consistent with all the other tests in topi; all test cases are using print.
topi/include/topi/transform.h
Outdated
 */
inline Tensor size(const Tensor& src,
                   Type dtype,
                   const std::string name = "size",
either std::string or const std::string& ?
src/relay/op/tensor/unary.cc
Outdated
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                               ElemwiseArbitraryLayout)
.set_support_level(3)
Change support level to 10.
tests/python/relay/test_op_level3.py
Outdated
@@ -607,6 +607,24 @@ def verify_gather_nd(xshape, yshape, y_data):
verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]])

def test_size():
Move the test to test_op_level10.py
We need to debate a bit about the API. Given that numpy's size is ndarray.size, we may not want to directly use …
@icemelon9 need to address comments, especially the returned scalar representation. I am working on it.
I think we should name …
@jroesch the naming convention is a bit debatable; given numpy uses ndarray and we also use NDArray as our storage type, my guess is that array_size is fine if we want to be relatively consistent with numpy.
relay.contrib.num_elements seems to be a clear name.
69ca070 to 48d69ac (Compare)
@yzhliu @jroesch @zhiics @icemelon9 @tqchen @kevinthesun I have incorporated the comments except for the scalar representation. I tried to add scalar support in my own branch, but it failed to run on the graph runtime since the schedule was mutated in Halide improperly. Filed an issue.
Can you rebase?
Can you just address Zhi's comments?
I wonder if this commit will continue?
@yongwww please rebase, rename to ndarray_size, and let us merge this.
@tqchen renamed. Please take a look.
Please rename the TOPI API name as well.
@tqchen updated. Seems CI hangs.
Add op Size. Size is used in numpy, TensorFlow, and MXNet, and is needed for TF object detection models like Mask R-CNN.
@zhiics @kevinthesun @icemelon9 @jroesch @Laurawly
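For clarity, the semantics of the op under review are simply the total element count of the input, i.e. the product of its shape extents. A minimal sketch in numpy terms (the helper name is illustrative, not TVM's API):

```python
import numpy as np
from functools import reduce
from operator import mul

def ndarray_size(shape):
    # total number of elements = product of all dimension extents
    return reduce(mul, shape, 1)

a = np.zeros((2, 3, 4))
assert ndarray_size(a.shape) == a.size == 24
assert ndarray_size(()) == 1  # a rank-0 (scalar) array has one element
```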