Skip to content

Commit

Permalink
[Fix Bug] fix the bug of pool_impl_nd when computing avgpool_nd with…
Browse files Browse the repository at this point in the history
… ceil_mode and count_include_pad are True. (apache#9835)

* Added the offset[i] for getting the correct boundary
* Added corresponding test case
  • Loading branch information
xiaolong18 authored Jan 13, 2022
1 parent f9d8c2b commit cc9d2f4
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
21 changes: 18 additions & 3 deletions include/tvm/topi/nn/pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
std::vector<PrimExpr> dilation(k_size);
std::vector<PrimExpr> pad_head(k_size);
std::vector<PrimExpr> pad_tail(k_size);
std::vector<PrimExpr> offset(k_size, 0);
Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
Array<PrimExpr> data_shape = x->shape;
Expand All @@ -539,9 +540,13 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]);

if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
// The offset[i] is an additional padding to ensure we do ceil instead of floor when
// dividing by stride.
pad_tail[i] += stride[i] - 1;
// In the case of ceil_mode=True and count_include_pad=True,
// in order to obtain the correct boundary,
// we also need to use the offset[i] to eliminate this extra padding.
offset[i] = stride[i] - 1;
pad_tail[i] += offset[i];
}

const int64_t* padding0 = as_const_int(pad_head[i]);
Expand Down Expand Up @@ -602,9 +607,19 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
if (count_include_pad) {
std::vector<PrimExpr> start(k_size);
std::vector<PrimExpr> end(k_size);
auto num_el = make_const(DataType::Int(32), 1);
for (int i = 0; i < k_size; i++) {
num_el *= kernel[i];
int ii = axis[i];
start[i] = output[ii] * stride[i] - pad_head[i];
// When computing the output shape in ceil_mode,
// we have added the extra padding of offset[i],
// so now, in order to calculate the correct boundary,
// we need to subtract the offset[i].
end[i] = start[i] + (kernel[i] - 1) * dilation[i];
end[i] = min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
num_el *= (end[i] - start[i]) / dilation[i] + 1;
}
return div(pool_sum(indices), num_el);
} else {
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,15 @@ def forward(self, *args):
torch.nn.AvgPool2d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data
)

input_shape = [1, 1, 1, 9]
input_data = torch.rand(input_shape).float()
verify_model(
torch.nn.AvgPool2d(
kernel_size=[1, 2], stride=[1, 2], ceil_mode=True, count_include_pad=True
).eval(),
input_data=input_data,
)


@tvm.testing.uses_gpu
def test_forward_avgpool3d():
Expand Down

0 comments on commit cc9d2f4

Please sign in to comment.