From e433486380ed4aae8b58994065eec2e56bb73216 Mon Sep 17 00:00:00 2001
From: eaidova
Date: Wed, 28 Dec 2022 13:04:19 +0400
Subject: [PATCH 1/3] generalize conv2d implementation for conv1d and conv3d

---
 .../pytorch/src/op/{conv2d.cpp => convnd.cpp} |  12 +-
 src/frontends/pytorch/src/op_table.cpp        |   6 +-
 .../layer_tests/pytorch_tests/test_conv2d.py  |  53 ------
 .../layer_tests/pytorch_tests/test_convnd.py  | 151 ++++++++++++++++++
 4 files changed, 165 insertions(+), 57 deletions(-)
 rename src/frontends/pytorch/src/op/{conv2d.cpp => convnd.cpp} (84%)
 delete mode 100644 tests/layer_tests/pytorch_tests/test_conv2d.py
 create mode 100644 tests/layer_tests/pytorch_tests/test_convnd.py

diff --git a/src/frontends/pytorch/src/op/conv2d.cpp b/src/frontends/pytorch/src/op/convnd.cpp
similarity index 84%
rename from src/frontends/pytorch/src/op/conv2d.cpp
rename to src/frontends/pytorch/src/op/convnd.cpp
index 91277f88ec6199..8e17ce3a285275 100644
--- a/src/frontends/pytorch/src/op/conv2d.cpp
+++ b/src/frontends/pytorch/src/op/convnd.cpp
@@ -11,7 +11,7 @@ namespace frontend {
 namespace pytorch {
 namespace op {
 
-OutputVector translate_conv2d(NodeContext& context) {
+OutputVector translate_convnd(NodeContext& context) {
     auto strides = context.const_input<Strides>(3);
     // In torch pads at beginning are same as at end
     auto pads = CoordinateDiff(strides.size(), 0);
@@ -49,8 +49,16 @@ OutputVector translate_conv2d(NodeContext& context) {
                                                       dilations,
                                                       pad_type);
     }
+    if (!context.input_is_none(2)) {
+        auto bias = context.get_input(2);
+        auto bias_rank = bias.get_partial_shape().rank();
+        if (bias_rank == 1) {
+            bias = reshape_conv_bias(context, bias, conv);
+        }
+        conv = context.mark_node(std::make_shared<opset8::Add>(conv, bias));
+    }
 
-    return {context.mark_output(make_optional_bias(conv, context, 2, {-2, -1}))};
+    return {conv};
 };
 
 } // namespace op
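
Note: the deleted make_optional_bias call hard-coded the two trailing unsqueeze axes ({-2, -1}), which only fits the 2d case; the new branch reshapes a rank-1 bias so it broadcasts over the channel axis for any spatial rank. A rough numpy illustration of the pitfall the reshape_conv_bias helper guards against (the helper itself lives in utils and is not shown in this hunk, so its exact output shape is inferred here; OpenVINO's Add broadcasts numpy-style against trailing axes by default):

    import numpy as np

    conv_out = np.zeros((2, 8, 5), dtype=np.float32)  # N, C, L: a conv1d output
    bias = np.arange(8, dtype=np.float32)             # shape (8,) == (C,)

    # Trailing-axis broadcasting tries to align (8,) with L == 5 and fails
    # (and for L == C it would silently add along the wrong axis instead):
    try:
        conv_out + bias
    except ValueError as e:
        print(e)

    # Reshaping pins the bias to the channel axis for any spatial rank:
    out = conv_out + bias.reshape(1, -1, 1)  # (1, C, 1); for conv2d it would be (1, C, 1, 1)
    print(out.shape)
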
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 087457b4330d51..a624fdd75dd1a8 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -24,7 +24,7 @@ OP_CONVERTER(translate_avg_pool2d);
 OP_CONVERTER(translate_batch_norm);
 OP_CONVERTER(translate_clamp);
 OP_CONVERTER(translate_constant);
-OP_CONVERTER(translate_conv2d);
+OP_CONVERTER(translate_convnd);
 OP_CONVERTER(translate_convolution);
 OP_CONVERTER(translate_convolution_mode);
 OP_CONVERTER(translate_dim);
@@ -108,7 +108,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
     {"aten::clone", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd
     {"aten::contiguous", op::skip_node}, // In openvino how tensors are stored in memory is internal plugin detail,
                                          // we assume all tensors are contiguous
-    {"aten::conv2d", op::translate_conv2d},
+    {"aten::conv1d", op::translate_convnd},
+    {"aten::conv3d", op::translate_convnd},
+    {"aten::conv2d", op::translate_convnd},
     {"aten::convolution", op::translate_convolution},
     {"aten::cumsum", op::translate_1to1_match_2_inputs<opset8::CumSum>},
     {"aten::dim", op::translate_dim},
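
Registering all three aten entries against the single translate_convnd works because the three ops share one schema. A quick way to see this from Python (a sketch, assuming a torch build where these functional calls trace to the corresponding aten ops):

    import torch
    import torch.nn.functional as F

    # aten::conv1d, aten::conv2d and aten::conv3d all take
    # (input, weight, bias, stride, padding, dilation, groups),
    # which is what lets one rank-agnostic translator serve every entry above.
    w = torch.randn(4, 3, 3)
    traced = torch.jit.trace(lambda x: F.conv1d(x, w, None, 1, 0, 1, 1),
                             torch.randn(1, 3, 10))
    print(traced.graph)  # the printed graph contains an aten::conv1d node
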
diff --git a/tests/layer_tests/pytorch_tests/test_conv2d.py b/tests/layer_tests/pytorch_tests/test_conv2d.py
deleted file mode 100644
index 2c76b65e4c18d5..00000000000000
--- a/tests/layer_tests/pytorch_tests/test_conv2d.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright (C) 2018-2022 Intel Corporation
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-from pytorch_layer_test_class import PytorchLayerTest
-
-
-class TestConv2D(PytorchLayerTest):
-    def _prepare_input(self):
-        import numpy as np
-        return (np.random.randn(2, 3, 25, 25).astype(np.float32),)
-
-    def create_model(self, weights_shape, strides, pads, dilations, groups, bias):
-
-        import torch
-        import torch.nn.functional as F
-
-        class aten_conv2d(torch.nn.Module):
-            def __init__(self):
-                super(aten_conv2d, self).__init__()
-                self.weight = torch.randn(weights_shape)
-                self.bias = None
-                if bias:
-                    self.bias = torch.randn(weights_shape[0])
-                self.strides = strides
-                self.pads = pads
-                self.dilations = dilations
-                self.groups = groups
-
-            def forward(self, x):
-                return F.conv2d(x, self.weight, self.bias, self.strides, self.pads, self.dilations, self.groups)
-
-        ref_net = None
-
-        return aten_conv2d(), ref_net, "aten::conv2d"
-
-    @pytest.mark.parametrize("params",
-                             [{'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 1},
-                              {'weights_shape': [1, 3, 3, 3], 'strides': 2, 'pads': 0, 'dilations': 1, 'groups': 1},
-                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 1, 'dilations': 1, 'groups': 1},
-                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 2, 'groups': 1},
-                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': [0, 1], 'dilations': 1, 'groups': 1},
-                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': [1, 0], 'dilations': 1, 'groups': 1},
-                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 'same', 'dilations': 1, 'groups': 1},
-                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 'valid', 'dilations': 1, 'groups': 1},
-                              # doesn't work because input shape is dynamic which makes kernel shape dynamic
-                              # {'weights_shape': [2, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 2},
-                              ])
-    @pytest.mark.parametrize("bias", [True, False])
-    @pytest.mark.nightly
-    def test_conv2d(self, params, bias, ie_device, precision, ir_version):
-        self._test(*self.create_model(**params, bias=bias),
-                   ie_device, precision, ir_version)
diff --git a/tests/layer_tests/pytorch_tests/test_convnd.py b/tests/layer_tests/pytorch_tests/test_convnd.py
new file mode 100644
index 00000000000000..62afc05a93c360
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_convnd.py
@@ -0,0 +1,151 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestConv2D(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(2, 3, 25, 25).astype(np.float32),)
+
+    def create_model(self, weights_shape, strides, pads, dilations, groups, bias):
+
+        import torch
+        import torch.nn.functional as F
+
+        class aten_conv2d(torch.nn.Module):
+            def __init__(self):
+                super(aten_conv2d, self).__init__()
+                self.weight = torch.randn(weights_shape)
+                self.bias = None
+                if bias:
+                    self.bias = torch.randn(weights_shape[0])
+                self.strides = strides
+                self.pads = pads
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                return F.conv2d(x, self.weight, self.bias, self.strides, self.pads, self.dilations, self.groups)
+
+        ref_net = None
+
+        return aten_conv2d(), ref_net, "aten::conv2d"
+
+    @pytest.mark.parametrize("params",
+                             [{'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3], 'strides': 2, 'pads': 0, 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 1, 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 2, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': [0, 1], 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': [1, 0], 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 'same', 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3], 'strides': 1, 'pads': 'valid', 'dilations': 1, 'groups': 1},
+                              # doesn't work because input shape is dynamic which makes kernel shape dynamic
+                              # {'weights_shape': [2, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 2},
+                              ])
+    @pytest.mark.parametrize("bias", [True, False])
+    @pytest.mark.nightly
+    def test_conv2d(self, params, bias, ie_device, precision, ir_version):
+        self._test(*self.create_model(**params, bias=bias),
+                   ie_device, precision, ir_version)
+
+
+class TestConv1D(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(2, 3, 25).astype(np.float32),)
+
+    def create_model(self, weights_shape, strides, pads, dilations, groups, bias):
+
+        import torch
+        import torch.nn.functional as F
+
+        class aten_conv1d(torch.nn.Module):
+            def __init__(self):
+                super(aten_conv1d, self).__init__()
+                self.weight = torch.randn(weights_shape)
+                self.bias = None
+                if bias:
+                    self.bias = torch.randn(weights_shape[0])
+                self.strides = strides
+                self.pads = pads
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                return F.conv1d(x, self.weight, self.bias, self.strides, self.pads, self.dilations, self.groups)
+
+        ref_net = None
+
+        return aten_conv1d(), ref_net, "aten::conv1d"
+
+    @pytest.mark.parametrize("params",
+                             [{'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [3, 3, 3], 'strides': 2, 'pads': 0, 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 1, 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 2, 'groups': 1},
+                              {'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 'same', 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 'valid', 'dilations': 1, 'groups': 1},
+                              # doesn't work because input shape is dynamic which makes kernel shape dynamic
+                              # {'weights_shape': [3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 2},
+                              ])
+    @pytest.mark.parametrize("bias", [True, False])
+    @pytest.mark.nightly
+    def test_conv1d(self, params, bias, ie_device, precision, ir_version):
+        self._test(*self.create_model(**params, bias=bias),
+                   ie_device, precision, ir_version)
+
+
+class TestConv3D(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(2, 3, 25, 25, 25).astype(np.float32),)
+
+    def create_model(self, weights_shape, strides, pads, dilations, groups, bias):
+
+        import torch
+        import torch.nn.functional as F
+
+        class aten_conv3d(torch.nn.Module):
+            def __init__(self):
+                super(aten_conv3d, self).__init__()
+                self.weight = torch.randn(weights_shape)
+                self.bias = None
+                if bias:
+                    self.bias = torch.randn(weights_shape[0])
+                self.strides = strides
+                self.pads = pads
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                return F.conv3d(x, self.weight, self.bias, self.strides, self.pads, self.dilations, self.groups)
+
+        ref_net = None
+
+        return aten_conv3d(), ref_net, "aten::conv3d"
+
+    @pytest.mark.parametrize("params",
+                             [{'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3, 3], 'strides': 2, 'pads': 0, 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': 1, 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 2, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [0, 1, 0], 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [1, 0, 0], 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [0, 0, 1], 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [1, 1, 0], 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [0, 1, 1], 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': [1, 0, 1], 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': 'same', 'dilations': 1, 'groups': 1},
+                              {'weights_shape': [1, 3, 3, 3, 3], 'strides': 1, 'pads': 'valid', 'dilations': 1, 'groups': 1},
+                              # doesn't work because input shape is dynamic which makes kernel shape dynamic
+                              # {'weights_shape': [2, 3, 3, 3], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 2},
+                              ])
+    @pytest.mark.parametrize("bias", [True, False])
+    @pytest.mark.nightly
+    def test_conv3d(self, params, bias, ie_device, precision, ir_version):
+        self._test(*self.create_model(**params, bias=bias),
+                   ie_device, precision, ir_version)
\ No newline at end of file
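
For readers scanning the parametrizations above: torch convolution weights are laid out as (out_channels, in_channels / groups, *kernel_size), so [1, 3, 3, 3] is a 2d kernel mapping 3 input channels to 1 output channel over a 3x3 window, and [3, 3, 3] is its 1d counterpart. A quick sketch:

    import torch

    # weight layout is (out_channels, in_channels / groups, *kernel_size)
    conv = torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3)
    print(tuple(conv.weight.shape))  # (1, 3, 3, 3), the weights_shape used in the 2d cases
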
From 67d465ebf275af13f230b3ba66a191a6069acbb4 Mon Sep 17 00:00:00 2001
From: eaidova
Date: Wed, 28 Dec 2022 20:00:37 +0400
Subject: [PATCH 2/3] extend and add tests for avg_pool and max_pool

---
 src/frontends/pytorch/src/op/avg_pool2d.cpp   |  36 -----
 src/frontends/pytorch/src/op/avg_poolnd.cpp   |  53 +++++++
 .../src/op/{max_pool2d.cpp => max_poolnd.cpp} |   9 +-
 src/frontends/pytorch/src/op_table.cpp        |  14 +-
 .../layer_tests/pytorch_tests/test_pooling.py | 146 ++++++++++++++++++
 5 files changed, 212 insertions(+), 46 deletions(-)
 delete mode 100644 src/frontends/pytorch/src/op/avg_pool2d.cpp
 create mode 100644 src/frontends/pytorch/src/op/avg_poolnd.cpp
 rename src/frontends/pytorch/src/op/{max_pool2d.cpp => max_poolnd.cpp} (81%)
 create mode 100644 tests/layer_tests/pytorch_tests/test_pooling.py

diff --git a/src/frontends/pytorch/src/op/avg_pool2d.cpp b/src/frontends/pytorch/src/op/avg_pool2d.cpp
deleted file mode 100644
index e26e0b52875b02..00000000000000
--- a/src/frontends/pytorch/src/op/avg_pool2d.cpp
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright (C) 2018-2022 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include "openvino/frontend/pytorch/node_context.hpp"
-#include "openvino/opsets/opset8.hpp"
-#include "utils.hpp"
-
-namespace ov {
-namespace frontend {
-namespace pytorch {
-namespace op {
-
-OutputVector translate_avg_pool2d(NodeContext& context) {
-    auto kernel = context.const_input<Shape>(1);
-    auto strides = context.const_input<Strides>(2);
-    auto pads_begin = context.const_input<Shape>(3);  // FIXME: The same 3 is used twice
-    auto pads_end = context.const_input<Shape>(3);    // FIXME: The same 3 is used twice
-    auto rounding_type = context.const_input<bool>(4) ? ov::op::RoundingType::CEIL : ov::op::RoundingType::FLOOR;
-    auto exclude_pad = !context.const_input<bool>(5);
-    FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(6),
-                                  "Translation for aten::avg_pool2d do not support divisor_override input.");
-
-    return {context.mark_node(std::make_shared<opset8::AvgPool>(context.get_input(0),
-                                                                strides,
-                                                                pads_begin,
-                                                                pads_end,
-                                                                kernel,
-                                                                exclude_pad,
-                                                                rounding_type))};
-};
-
-} // namespace op
-} // namespace pytorch
-} // namespace frontend
-} // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/op/avg_poolnd.cpp b/src/frontends/pytorch/src/op/avg_poolnd.cpp
new file mode 100644
index 00000000000000..b4f030ebafe772
--- /dev/null
+++ b/src/frontends/pytorch/src/op/avg_poolnd.cpp
@@ -0,0 +1,53 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/opsets/opset8.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_avg_poolnd(NodeContext& context) {
+    auto input = context.get_input(0);
+    auto kernel = context.const_input<Shape>(1);
+    auto strides = context.const_input<Strides>(2);
+    auto pads = context.const_input<Shape>(3);  // pytorch supports only symmetric padding
+    auto rounding_type = context.const_input<bool>(4) ? ov::op::RoundingType::CEIL : ov::op::RoundingType::FLOOR;
+    auto count_include_pad = context.const_input<bool>(5);
+    FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(6),
+                                  "Translation for aten::avg_pool does not support divisor_override input.");
+    // Although ov::AvgPool provides exclude_pad=false,
+    // the corner case of average pooling with ceil_mode on
+    // PyTorch allows the sliding window to go off bounds, which leads to this accommodation.
+    // More detail on https://github.com/pytorch/pytorch/issues/57178
+    if (count_include_pad){
+        auto zero = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
+        auto zero_i32 = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {0}));
+        auto shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input, element::i32));
+        auto rank = context.mark_node(std::make_shared<opset8::ShapeOf>(shape, element::i32));
+        auto pad_values = context.get_input(3);
+        auto pads_len = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {pads.size()}));
+        auto pads_diff = context.mark_node(std::make_shared<opset8::Subtract>(rank, pads_len));
+        auto pads_remaining = context.mark_node(std::make_shared<opset8::Broadcast>(zero_i32, pads_diff));
+        auto padding = context.mark_node(std::make_shared<opset8::Concat>(NodeVector{pads_remaining, pad_values.get_node_shared_ptr()}, 0));
+        input = context.mark_node(std::make_shared<opset8::Pad>(input, padding, padding, zero, ov::op::PadMode::CONSTANT));
+        pads = Shape(pads.size(), 0);
+    }
+
+    return {context.mark_node(std::make_shared<opset8::AvgPool>(input,
+                                                                strides,
+                                                                pads,
+                                                                pads,
+                                                                kernel,
+                                                                !count_include_pad,
+                                                                rounding_type))};
+};
+
+} // namespace op
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
\ No newline at end of file
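
The Pad-plus-exclude_pad lowering above leans on an identity that can be sanity-checked directly in PyTorch (a sketch for a shape where ceil_mode creates no off-bounds window; the off-bounds case from the linked issue is precisely what the explicit Pad sidesteps):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 5, 5)

    # count_include_pad=True counts the zero padding in every window's divisor...
    a = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1,
                     ceil_mode=True, count_include_pad=True)

    # ...which matches zero-padding explicitly and then pooling with no padding,
    # the same rewrite translate_avg_poolnd performs with opset8::Pad + exclude_pad.
    xp = F.pad(x, (1, 1, 1, 1), mode='constant', value=0.0)
    b = F.avg_pool2d(xp, kernel_size=3, stride=2, padding=0,
                     ceil_mode=True, count_include_pad=False)

    print(torch.allclose(a, b))  # True
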
diff --git a/src/frontends/pytorch/src/op/max_pool2d.cpp b/src/frontends/pytorch/src/op/max_poolnd.cpp
similarity index 81%
rename from src/frontends/pytorch/src/op/max_pool2d.cpp
rename to src/frontends/pytorch/src/op/max_poolnd.cpp
index 9b4a12ca49b4f9..049d842c3fd866 100644
--- a/src/frontends/pytorch/src/op/max_pool2d.cpp
+++ b/src/frontends/pytorch/src/op/max_poolnd.cpp
@@ -11,19 +11,18 @@ namespace frontend {
 namespace pytorch {
 namespace op {
 
-OutputVector translate_max_pool2d(NodeContext& context) {
+OutputVector translate_max_poolnd(NodeContext& context) {
     auto kernel = context.const_input<Shape>(1);
     auto strides = context.const_input<Strides>(2);
-    auto pads_begin = context.const_input<Shape>(3);  // FIXME: The same 3 is used twice
-    auto pads_end = context.const_input<Shape>(3);    // FIXME: The same 3 is used twice
+    auto pads = context.const_input<Shape>(3); // pytorch supports only symmetric paddings
     auto dilations = context.const_input<Strides>(4);
     auto rounding_type = context.const_input<bool>(5) ? ov::op::RoundingType::CEIL : ov::op::RoundingType::FLOOR;
 
     return {context.mark_node(std::make_shared<opset8::MaxPool>(context.get_input(0),
                                                                 strides,
                                                                 dilations,
-                                                                pads_begin,
-                                                                pads_end,
+                                                                pads,
+                                                                pads,
                                                                 kernel,
                                                                 rounding_type))};
 };
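
Collapsing pads_begin/pads_end into a single pads value is safe because torch's padding argument is applied to both sides of every spatial dimension; a quick sketch:

    import torch
    import torch.nn.functional as F

    x = torch.arange(16.0).reshape(1, 1, 4, 4)

    # padding=1 pads one element on *both* sides of each spatial axis,
    # so output size is (4 + 2*1 - 3) / 1 + 1 = 4, hence pads_begin == pads_end.
    y = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
    print(y.shape)  # torch.Size([1, 1, 4, 4])
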
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index a624fdd75dd1a8..30ad01a8332833 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -20,7 +20,7 @@ OP_CONVERTER(translate_add);
 OP_CONVERTER(translate_addcmul);
 OP_CONVERTER(translate_addmm);
 OP_CONVERTER(translate_as_tensor);
-OP_CONVERTER(translate_avg_pool2d);
+OP_CONVERTER(translate_avg_poolnd);
 OP_CONVERTER(translate_batch_norm);
 OP_CONVERTER(translate_clamp);
 OP_CONVERTER(translate_constant);
@@ -45,7 +45,7 @@ OP_CONVERTER(translate_int);
 OP_CONVERTER(translate_layer_norm);
 OP_CONVERTER(translate_linear);
 OP_CONVERTER(translate_loop);
-OP_CONVERTER(translate_max_pool2d);
+OP_CONVERTER(translate_max_poolnd);
 OP_CONVERTER(translate_max);
 OP_CONVERTER(translate_masked_fill);
 OP_CONVERTER(translate_mean);
@@ -99,7 +99,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
     {"aten::addcmul", op::translate_addcmul},
     {"aten::addmm", op::translate_addmm},
     {"aten::as_tensor", op::translate_as_tensor},
-    {"aten::avg_pool2d", op::translate_avg_pool2d},
+    {"aten::avg_pool1d", op::translate_avg_poolnd},
+    {"aten::avg_pool2d", op::translate_avg_poolnd},
+    {"aten::avg_pool3d", op::translate_avg_poolnd},
     {"aten::batch_norm", op::translate_batch_norm},
     // {"aten::cat", done as transformation},
     {"aten::clamp", op::translate_clamp},
@@ -109,8 +111,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
     {"aten::contiguous", op::skip_node}, // In openvino how tensors are stored in memory is internal plugin detail,
                                          // we assume all tensors are contiguous
     {"aten::conv1d", op::translate_convnd},
-    {"aten::conv3d", op::translate_convnd},
     {"aten::conv2d", op::translate_convnd},
+    {"aten::conv3d", op::translate_convnd},
     {"aten::convolution", op::translate_convolution},
     {"aten::cumsum", op::translate_1to1_match_2_inputs<opset8::CumSum>},
     {"aten::dim", op::translate_dim},
@@ -144,7 +146,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
     {"aten::matmul", op::translate_1to1_match_2_inputs<opset8::MatMul>},
     {"aten::masked_fill", op::translate_masked_fill},
     {"aten::masked_fill_", op::inplace_op},
-    {"aten::max_pool2d", op::translate_max_pool2d},
+    {"aten::max_pool1d", op::translate_max_poolnd},
+    {"aten::max_pool2d", op::translate_max_poolnd},
+    {"aten::max_pool3d", op::translate_max_poolnd},
     {"aten::max", op::translate_max},
     {"aten::mean", op::translate_mean},
     {"aten::min", op::translate_min},
diff --git a/tests/layer_tests/pytorch_tests/test_pooling.py b/tests/layer_tests/pytorch_tests/test_pooling.py
new file mode 100644
index 00000000000000..c69b0576f4c65e
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_pooling.py
@@ -0,0 +1,146 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+d2_avg_params = [{'kernel_size': [3, 3], 'stride': 1, 'padding': 0},
+                 {'kernel_size': [3, 3], 'stride': [1, 1], 'padding': 1},
+                 {'kernel_size': [3, 3], 'stride': [1, 1], 'padding': [0, 1]},
+                 {'kernel_size': [3, 3], 'stride': [1, 1], 'padding': [1, 0]},
+                 {'kernel_size': [3, 3], 'stride': [2, 1], 'padding': 0},
+                 {'kernel_size': [2, 1], 'stride': [2, 1], 'padding': 0},
+                 ]
+
+d1_avg_params = [{'kernel_size': 3, 'stride': 1, 'padding': 0},
+                 {'kernel_size': (4,), 'stride': 1, 'padding': 1},
+                 {'kernel_size': 4, 'stride': (5,), 'padding': 2},
+                 ]
+d3_avg_params = [{'kernel_size': [3, 3, 3], 'stride': 1, 'padding': 0},
+                 {'kernel_size': [3, 3, 3], 'stride': [1, 1, 1], 'padding': 1},
+                 {'kernel_size': [3, 3, 3], 'stride': [3, 3, 3], 'padding': [0, 0, 0]},
+                 {'kernel_size': [3, 2, 1], 'stride': [3, 1, 1], 'padding': [0, 0, 0]},
+                 ]
+
+
+
+class TestPooling(PytorchLayerTest):
+    def _prepare_input(self, ndim=4):
+        import numpy as np
+        shape = (1, 3, 15, 15, 15)
+        return (np.random.randn(*shape[:ndim]).astype(np.float32),)
+
+    def create_model(self, op_type, kernel_size, stride, padding, dilation=1, ceil_mode=True, count_include_pad=True):
+
+        import torch
+
+        class aten_avg_pooling_base(torch.nn.Module):
+            def __init__(self):
+                super(aten_avg_pooling_base, self).__init__()
+                self.kernel_size = kernel_size
+                self.stride = stride
+                self.padding = padding
+                self.ceil_mode = ceil_mode
+                self.count_include_pad = count_include_pad
+
+            def forward(self, x):
+                pass
+
+        class aten_max_pooling_base(torch.nn.Module):
+            def __init__(self):
+                super(aten_max_pooling_base, self).__init__()
+                self.kernel_size = kernel_size
+                self.stride = stride
+                self.padding = padding
+                self.dilation = dilation
+                self.ceil_mode = ceil_mode
+
+            def forward(self, x):
+                pass
+
+        class aten_avg_pool2d(aten_avg_pooling_base):
+            def forward(self, x):
+                return torch.nn.functional.avg_pool2d(x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
+
+        class aten_avg_pool3d(aten_avg_pooling_base):
+            def forward(self, x):
+                return torch.nn.functional.avg_pool3d(x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
+
+        class aten_avg_pool1d(aten_avg_pooling_base):
+            def forward(self, x):
+                return torch.nn.functional.avg_pool1d(x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
+
+        class aten_max_pool2d(aten_max_pooling_base):
+            def forward(self, x):
+                return torch.nn.functional.max_pool2d(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode)
+
+        class aten_max_pool3d(aten_max_pooling_base):
+            def forward(self, x):
+                return torch.nn.functional.max_pool3d(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode)
+
+        class aten_max_pool1d(aten_max_pooling_base):
+            def forward(self, x):
+                return torch.nn.functional.max_pool1d(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode)
+
+        ops = {
+            "max_pool1d": aten_max_pool1d,
+            "max_pool2d": aten_max_pool2d,
+            "max_pool3d": aten_max_pool3d,
+            "avg_pool1d": aten_avg_pool1d,
+            "avg_pool2d": aten_avg_pool2d,
+            "avg_pool3d": aten_avg_pool3d
+        }
+
+        ref_net = None
+        aten_pooling = ops[op_type]
+
+        return aten_pooling(), ref_net, f"aten::{op_type}"
+
+    @pytest.mark.parametrize("params", d2_avg_params)
+    @pytest.mark.parametrize("ceil_mode", [True, False])
+    @pytest.mark.parametrize("count_include_pad", [True, False])
+    @pytest.mark.nightly
+    def test_avg_pool2d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version):
+        self._test(*self.create_model("avg_pool2d", **params, ceil_mode=ceil_mode, count_include_pad=count_include_pad),
+                   ie_device, precision, ir_version, trace_model=True, dynamic_shapes=False)
+
+    @pytest.mark.parametrize("params", d1_avg_params)
+    @pytest.mark.parametrize("ceil_mode", [True, False])
+    @pytest.mark.parametrize("count_include_pad", [True, False])
+    @pytest.mark.nightly
+    def test_avg_pool1d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version):
+        self._test(*self.create_model("avg_pool1d", **params, ceil_mode=ceil_mode, count_include_pad=count_include_pad),
+                   ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, trace_model=True, dynamic_shapes=False)
+
+    @pytest.mark.parametrize("params", d3_avg_params)
+    @pytest.mark.parametrize("ceil_mode", [True, False])
+    @pytest.mark.parametrize("count_include_pad", [True, False])
+    @pytest.mark.nightly
+    def test_avg_pool3d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version):
+        self._test(*self.create_model("avg_pool3d", **params, ceil_mode=ceil_mode, count_include_pad=count_include_pad),
+                   ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 5}, trace_model=True, dynamic_shapes=False)
+
+    @pytest.mark.parametrize("params", d2_avg_params)
+    @pytest.mark.parametrize("ceil_mode", [True, False])
+    @pytest.mark.parametrize("dilation", [1, 2])
+    @pytest.mark.nightly
+    def test_max_pool2d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
+        self._test(*self.create_model("max_pool2d", **params, ceil_mode=ceil_mode, dilation=dilation),
+                   ie_device, precision, ir_version, dynamic_shapes=False)
+
+    @pytest.mark.parametrize("params", d1_avg_params)
+    @pytest.mark.parametrize("ceil_mode", [True, False])
+    @pytest.mark.parametrize("dilation", [1, 2])
+    @pytest.mark.nightly
+    def test_max_pool1d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
+        self._test(*self.create_model("max_pool1d", **params, ceil_mode=ceil_mode, dilation=dilation),
+                   ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, dynamic_shapes=False)
+
+    @pytest.mark.parametrize("params", d3_avg_params)
+    @pytest.mark.parametrize("ceil_mode", [True, False])
+    @pytest.mark.parametrize("dilation", [1, 2])
+    @pytest.mark.nightly
+    def test_max_pool3d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
+        self._test(*self.create_model("max_pool3d", **params, ceil_mode=ceil_mode, dilation=dilation),
+                   ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 5}, dynamic_shapes=False)
\ No newline at end of file
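
One _prepare_input serves all six suites: the 5-D base shape is sliced down to the rank each test requests through kwargs_to_prepare_input, e.g.:

    import numpy as np

    shape = (1, 3, 15, 15, 15)
    for ndim in (3, 4, 5):  # 1d, 2d (the default) and 3d pooling inputs
        x = np.random.randn(*shape[:ndim]).astype(np.float32)
        print(ndim, x.shape)
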
self._test(*self.create_model("avg_pool1d", **params, ceil_mode=ceil_mode, count_include_pad=count_include_pad), + ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, trace_model=True, dynamic_shapes=False) + + @pytest.mark.parametrize("params", d3_avg_params) + @pytest.mark.parametrize("ceil_mode", [True, False]) + @pytest.mark.parametrize("count_include_pad", [True, False]) + @pytest.mark.nightly + def test_avg_pool3d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version): + self._test(*self.create_model("avg_pool3d", **params, ceil_mode=ceil_mode, count_include_pad=count_include_pad), + ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 5}, trace_model=True, dynamic_shapes=False) + + @pytest.mark.parametrize("params", d2_avg_params) + @pytest.mark.parametrize("ceil_mode", [True, False]) + @pytest.mark.parametrize("dilation", [1, 2]) + @pytest.mark.nightly + def test_max_pool2d(self, params, ceil_mode, dilation, ie_device, precision, ir_version): + self._test(*self.create_model("max_pool2d", **params, ceil_mode=ceil_mode, dilation=dilation), + ie_device, precision, ir_version, dynamic_shapes=False) + + @pytest.mark.parametrize("params", d1_avg_params) + @pytest.mark.parametrize("ceil_mode", [True, False]) + @pytest.mark.parametrize("dilation", [1, 2]) + @pytest.mark.nightly + def test_max_pool1d(self, params, ceil_mode, dilation, ie_device, precision, ir_version): + self._test(*self.create_model("max_pool1d", **params, ceil_mode=ceil_mode, dilation=dilation), + ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, dynamic_shapes=False) + + @pytest.mark.parametrize("params", d3_avg_params) + @pytest.mark.parametrize("ceil_mode", [True, False]) + @pytest.mark.parametrize("dilation", [1, 2]) + @pytest.mark.nightly + def test_max_pool3d(self, params, ceil_mode, dilation, ie_device, precision, ir_version): + self._test(*self.create_model("max_pool3d", **params, ceil_mode=ceil_mode, dilation=dilation), + ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 5}, dynamic_shapes=False) \ No newline at end of file From fb6bb576bfac6d8849b404bda3d869dc9399e832 Mon Sep 17 00:00:00 2001 From: eaidova Date: Fri, 30 Dec 2022 09:09:55 +0400 Subject: [PATCH 3/3] fix code style --- src/frontends/pytorch/src/op/avg_poolnd.cpp | 19 ++++++++----------- src/frontends/pytorch/src/op/max_poolnd.cpp | 2 +- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/frontends/pytorch/src/op/avg_poolnd.cpp b/src/frontends/pytorch/src/op/avg_poolnd.cpp index b4f030ebafe772..c2106a1e6bfeb2 100644 --- a/src/frontends/pytorch/src/op/avg_poolnd.cpp +++ b/src/frontends/pytorch/src/op/avg_poolnd.cpp @@ -24,7 +24,7 @@ OutputVector translate_avg_poolnd(NodeContext& context) { // The corner case of Average Pooling with ceil_mode on // PyTorch allows sliding window go off bound, which leads to this accommodation. 
diff --git a/src/frontends/pytorch/src/op/max_poolnd.cpp b/src/frontends/pytorch/src/op/max_poolnd.cpp
index 049d842c3fd866..c0676677090892 100644
--- a/src/frontends/pytorch/src/op/max_poolnd.cpp
+++ b/src/frontends/pytorch/src/op/max_poolnd.cpp
@@ -14,7 +14,7 @@ namespace op {
 OutputVector translate_max_poolnd(NodeContext& context) {
     auto kernel = context.const_input<Shape>(1);
     auto strides = context.const_input<Strides>(2);
-    auto pads = context.const_input<Shape>(3); // pytorch supports only symmetric paddings
+    auto pads = context.const_input<Shape>(3);  // pytorch supports only symmetric paddings
     auto dilations = context.const_input<Strides>(4);
     auto rounding_type = context.const_input<bool>(5) ? ov::op::RoundingType::CEIL : ov::op::RoundingType::FLOOR;