-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Op] Adaptive pooling #3085
Conversation
include/tvm/relay/attrs/nn.h
Outdated
@@ -332,6 +332,22 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> { | |||
} | |||
}; | |||
|
|||
/*! \brief Attributes for global pool operator */ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
global -> adaptive?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good
python/tvm/relay/op/nn/nn.py
Outdated
layout="NCHW"): | ||
r"""2D adaptive max pooling operator. | ||
|
||
This operator takes data as input and does 2D average value calculation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
average->max
auto i_end_h = end_index(output[height_axis], out_height, height); | ||
auto i_start_w = start_index(output[width_axis], out_width, width); | ||
auto i_end_w = end_index(output[width_axis], out_width, width); | ||
auto dheight = tvm::reduce_axis(Range(0, i_end_h - i_start_h), "rv1"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of curiosity, is naming the IterVar rv1 necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is optional.
const Expr& odim, | ||
const Expr& idim) { | ||
Expr tmp = (out_index + 1) * idim / odim; | ||
return tvm::ir::Select::make((out_index + 1) * idim % odim == 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason not to explicitly take the ceil here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ceil will involve several casts and causing result inaccurate.
topi/python/topi/cuda/pooling.py
Outdated
|
||
Parameters | ||
---------- | ||
outs: Array of Tensor | ||
The computation graph description of global_pool | ||
The computation graph description of adaptive_poo |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
adaptive_poo->adaptive_pool
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just a few nits.
src/relay/op/nn/pooling.cc
Outdated
const auto* data = types[0].as<TensorTypeNode>(); | ||
if (data == nullptr) { return false; } | ||
const auto dshape = data->shape; | ||
CHECK_NE(dshape.size(), 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove? since dshape.size() >= 2
src/relay/op/nn/pooling.cc
Outdated
CHECK_NE(dshape.size(), 0); | ||
CHECK_GE(dshape.size(), 2U) | ||
<< "Pool2D only support input >= 2-D: input must have height and width"; | ||
const auto param = attrs.as<AdaptivePool2DAttrs>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const auto*
to reduce confusion since as
returns a pointer?
c23b218
to
060309b
Compare
Please take some time to discuss the API naming and docs, as adaptive_pooling seems to be something that is not conventional. |
@tqchen This naming comes from pytorch pooling family: https://pytorch.org/docs/stable/_modules/torch/nn/modules/pooling.html MXNet has a similar API for adaptive avg pooing: https://mxnet.apache.org/api/python/ndarray/contrib.html?highlight=adap#mxnet.ndarray.contrib.AdaptiveAvgPooling2D For now I don't see a more appropriate name in mainsteam DL frameworks. Also adaptive describes the behavior of this kind of pooling. |
let us make it contrib.adaptive_pooling then and mark as experimental |
c357607
to
7fdb3f2
Compare
abbd9dd
to
2de44ca
Compare
@tqchen This is ready to be reviewed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also fix the mxnet converter for adaptive pooling?
@kevinthesun Could you also update the mxnet converter for adaptive pooling? |
Already updated. |
Sorry, didn't notice that... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Notice Tianqi's comment. Could you change the nn.contrib_adaptive_avg_pool2d
and nn.contrib_adaptive_max_pool2d
to contrib.adaptive_avg_pool2d
and contrib.adaptive_avg_pool2d
?
Please also update the doc at docs/langref/relay_op.rst
and docs/api/python/topi.rst
.
I will leave the PR management to @icemelon9 |
Thanks everyone. This is now merged. |
* Add topi adaptive_pool * Use adaptive_pool to compute global_pool * Add relay adaptive pool2d * Fix lint * Fix typo * Minor change * Change support level to 10 * Add contrib * Remove global pool schedule * Add contrib module * Fix lint * Update doc * Update doc
* Add topi adaptive_pool * Use adaptive_pool to compute global_pool * Add relay adaptive pool2d * Fix lint * Fix typo * Minor change * Change support level to 10 * Add contrib * Remove global pool schedule * Add contrib module * Fix lint * Update doc * Update doc
Adaptive pooling operator.