Fix pool2d issues. Add test case. (#22)
mangguo321 authored May 13, 2021
1 parent 8e8b374 commit f01ceb8
Showing 4 changed files with 60 additions and 41 deletions.
ngraph/frontend/paddlepaddle/src/op/pool2d.cpp (32 changes: 4 additions & 28 deletions)
@@ -1,17 +1,7 @@
-//*****************************************************************************
-// Copyright 2017-2021 Intel Corporation
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//*****************************************************************************

#include <ngraph/opsets/opset6.hpp>
@@ -47,30 +37,16 @@ static void get_paddings(const NodeContext& node, ngraph::Shape& pad_begin, ngra
auto paddings = node.get_attribute<std::vector<int32_t>>("paddings");
auto data_format = node.get_attribute<std::string>("data_format");

+// TODO: need to support NHWC input
switch (paddings.size())
{
case 1:
    pad_begin = Shape(2, paddings[0]);
    pad_end = pad_begin;
    break;
case 2:
    pad_begin = Shape{static_cast<uint64_t>(paddings[0]), static_cast<uint64_t>(paddings[1])};
    pad_end = pad_begin;
    break;
case 4:
    pad_begin = Shape{static_cast<uint64_t>(paddings[0]), static_cast<uint64_t>(paddings[2])};
-    pad_end = Shape(static_cast<uint64_t>(paddings[1]), static_cast<uint64_t>(paddings[3]));
-    break;
-case 8:
-    if (data_format == "NCHW") {
-        pad_begin = Shape{static_cast<uint64_t>(paddings[4]), static_cast<uint64_t>(paddings[6])};
-        pad_end = Shape(static_cast<uint64_t>(paddings[5]), static_cast<uint64_t>(paddings[7]));
-    } else if (data_format == "NHWC") {
-        pad_begin = Shape{static_cast<uint64_t>(paddings[2]), static_cast<uint64_t>(paddings[4])};
-        pad_end = Shape(static_cast<uint64_t>(paddings[3]), static_cast<uint64_t>(paddings[5]));
-    } else {
-        throw std::runtime_error("Unsupported pooling data_format " + data_format);
-    }
+    pad_end = Shape{static_cast<uint64_t>(paddings[1]), static_cast<uint64_t>(paddings[3])};
    break;
default:
    throw std::runtime_error("Unsupported pooling paddings " + paddings.size());
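The surviving switch now covers the paddings layouts that actually reach the converter: a single value (case 1), a [pad_h, pad_w] pair (case 2), and an explicit [top, bottom, left, right] list (case 4); the deleted case 8 branch handled a per-(N, C, H, W)-dimension form that the new tests no longer produce. As a reading aid, here is a minimal Python sketch of the same normalization; normalize_paddings is an illustrative name, not part of the commit:

def normalize_paddings(paddings):
    # Mirror of get_paddings: return (pad_begin, pad_end) as (height, width) pairs.
    if len(paddings) == 1:
        # One value shared by both spatial dims, begin and end.
        return (paddings[0], paddings[0]), (paddings[0], paddings[0])
    if len(paddings) == 2:
        # [pad_h, pad_w], symmetric begin/end.
        return (paddings[0], paddings[1]), (paddings[0], paddings[1])
    if len(paddings) == 4:
        # [top, bottom, left, right] -> begin = (top, left), end = (bottom, right).
        return (paddings[0], paddings[2]), (paddings[1], paddings[3])
    raise ValueError("Unsupported pooling paddings " + str(len(paddings)))

assert normalize_paddings([2]) == ((2, 2), (2, 2))           # test 7 style attribute
assert normalize_paddings([2, 1, 2, 1]) == ((2, 2), (1, 1))  # tests 8/9 style attribute

In ngraph, Shape(2, v) is the fill constructor (two elements of value v) while Shape{a, b} is an initializer list, which appears to be the distinction behind the one-line case 4 fix above.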
ngraph/frontend/paddlepaddle/src/op/pool2d.hpp (14 changes: 2 additions & 12 deletions)
@@ -1,17 +1,7 @@
-//*****************************************************************************
-// Copyright 2017-2021 Intel Corporation
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//*****************************************************************************

#pragma once
ngraph/test/files/paddlepaddle/gen_scripts/generate_pool2d.py (49 changes: 48 additions & 1 deletion)
@@ -168,7 +168,54 @@ def main():
}
# shape of out_6: [2, 4, 3, 3] which is different from out_1
pool2d(pooling_type+'Pool_test6', data_NHWC, pdpd_attrs)
-#

+# example 7:
+# pool_padding size is 1
+pdpd_attrs = {
+    'pool_size':[3,3],
+    'pool_type' : pooling_type,
+    'pool_stride' : [3,3],
+    'pool_padding':2,
+    'global_pooling':False,
+    'ceil_mode':False,
+    'exclusive':True,
+    'data_format':"NCHW"
+}
+pool2d(pooling_type+'Pool_test7', data_NCHW, pdpd_attrs)
+
+#input data for test8 and test9
+N_data1, C_data1, H_data1, W_data1 = 2, 3, 8, 8
+data1 = np.arange(N_data1*C_data1*H_data1*W_data1).astype(data_type)
+data1_NCHW = data1.reshape(N_data1, C_data1, H_data1, W_data1)
+data1_NHWC = data1.reshape(N_data1, H_data1, W_data1, C_data1)
+# example 8:
+# pool_padding size is 4: [pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]
+pdpd_attrs = {
+    'pool_size':[3,3],
+    'pool_type' : pooling_type,
+    'pool_stride' : [3,3],
+    'pool_padding':[2, 1, 2, 1],
+    'global_pooling':False,
+    'ceil_mode':False,
+    'exclusive':True,
+    'data_format':"NCHW"
+}
+pool2d(pooling_type+'Pool_test8', data1_NCHW, pdpd_attrs)
+
+# example 9:
+# input=data_NCHW and pool_padding is [[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]
+pdpd_attrs = {
+    'pool_size':[3,3],
+    'pool_type' : pooling_type,
+    'pool_stride' : [3,3],
+    'pool_padding':[[0,0], [0,0], [2, 1], [2, 1]],
+    'global_pooling':False,
+    'ceil_mode':False,
+    'exclusive':True,
+    'data_format':"NCHW"
+}
+pool2d(pooling_type+'Pool_test9', data1_NCHW, pdpd_attrs)
+
+
# adaptive_pool2d
for i, pooling_type in enumerate(pooling_types):
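Tests 8 and 9 pool an 8x8 spatial input with asymmetric padding, so the expected output shape can be checked by hand with the usual pooling formula out = floor((in + pad_begin + pad_end - kernel) / stride) + 1, switching floor to ceil when ceil_mode is set. A quick sanity check; pool_out_dim is an illustrative helper, not part of the test harness:

import math

def pool_out_dim(size, pad_begin, pad_end, kernel, stride, ceil_mode=False):
    # floor (or ceil, when ceil_mode) of (size + pads - kernel) / stride, plus one.
    n = (size + pad_begin + pad_end - kernel) / stride
    return int(math.ceil(n) if ceil_mode else math.floor(n)) + 1

# Tests 8 and 9: 8x8 input, 3x3 kernel, stride 3, pad top/left = 2, bottom/right = 1.
assert pool_out_dim(8, 2, 1, kernel=3, stride=3) == 3  # 3x3 spatial output

This also shows why test 9's nested [[0,0], [0,0], [2,1], [2,1]] form should exercise the same code path as test 8's flat list: both describe identical per-edge pads.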
ngraph/test/files/paddlepaddle/models/models.csv (6 changes: 6 additions & 0 deletions)
@@ -6,6 +6,9 @@ avgPool_test3,
avgPool_test4,
avgPool_test5,
#avgPool_test6,
+avgPool_test7,
+avgPool_test8,
+avgPool_test9,
batch_norm,
bilinear_downsample_false_0,
bilinear_downsample_false_1,
@@ -53,6 +56,9 @@ maxPool_test3,
maxPool_test4,
maxPool_test5,
#maxPool_test6,
+maxPool_test7,
+maxPool_test8,
+maxPool_test9,
nearest_downsample_false_0,
nearest_upsample_false_0,
pad3d_test1,