Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generalize conv2d, avg_pool2d, max_pool2d to support 1d and 3d cases #76

Merged
merged 4 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions src/frontends/pytorch/src/op/avg_pool2d.cpp

This file was deleted.

50 changes: 50 additions & 0 deletions src/frontends/pytorch/src/op/avg_poolnd.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset8.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_avg_poolnd(NodeContext& context) {
auto input = context.get_input(0);
auto kernel = context.const_input<Shape>(1);
auto strides = context.const_input<Strides>(2);
auto pads = context.const_input<Shape>(3); // pytorch supports only symmetric padding
auto rounding_type = context.const_input<bool>(4) ? ov::op::RoundingType::CEIL : ov::op::RoundingType::FLOOR;
auto count_include_pad = context.const_input<bool>(5);
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(6),
"Translation for aten::avg_pool2d do not support divisor_override input.");
// Although ov::AvgPool provides exclude_pad=false,
// The corner case of Average Pooling with ceil_mode on
// PyTorch allows sliding window go off bound, which leads to this accommodation.
// More detail on https://github.com/pytorch/pytorch/issues/57178
if (count_include_pad) {
auto zero = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
auto zero_i32 = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {0}));
auto shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input, element::i32));
auto rank = context.mark_node(std::make_shared<opset8::ShapeOf>(shape, element::i32));
auto pad_values = context.get_input(3);
auto pads_len = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {pads.size()}));
auto pads_diff = context.mark_node(std::make_shared<opset8::Subtract>(rank, pads_len));
auto pads_remaining = context.mark_node(std::make_shared<opset8::Broadcast>(zero_i32, pads_diff));
auto padding = context.mark_node(
std::make_shared<opset8::Concat>(NodeVector{pads_remaining, pad_values.get_node_shared_ptr()}, 0));
input =
context.mark_node(std::make_shared<opset8::Pad>(input, padding, padding, zero, ov::op::PadMode::CONSTANT));
pads = Shape(pads.size(), 0);
}

return {context.mark_node(
std::make_shared<opset8::AvgPool>(input, strides, pads, pads, kernel, !count_include_pad, rounding_type))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_conv2d(NodeContext& context) {
OutputVector translate_convnd(NodeContext& context) {
auto strides = context.const_input<Strides>(3);
// In torch pads at beginning are same as at end
auto pads = CoordinateDiff(strides.size(), 0);
Expand Down Expand Up @@ -49,8 +49,16 @@ OutputVector translate_conv2d(NodeContext& context) {
dilations,
pad_type);
}
if (!context.input_is_none(2)) {
auto bias = context.get_input(2);
auto bias_rank = bias.get_partial_shape().rank();
if (bias_rank == 1) {
bias = reshape_conv_bias(context, bias, conv);
}
conv = context.mark_node(std::make_shared<opset8::Add>(conv, bias));
}

return {context.mark_output(make_optional_bias(conv, context, 2, {-2, -1}))};
return {conv};
};

} // namespace op
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,18 @@ namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_max_pool2d(NodeContext& context) {
OutputVector translate_max_poolnd(NodeContext& context) {
auto kernel = context.const_input<Shape>(1);
auto strides = context.const_input<Strides>(2);
auto pads_begin = context.const_input<Shape>(3); // FIXME: The same 3 is used twice
auto pads_end = context.const_input<Shape>(3); // FIXME: The same 3 is used twice
auto pads = context.const_input<Shape>(3); // pytorch supports only symmetric paddings
auto dilations = context.const_input<Strides>(4);
auto rounding_type = context.const_input<bool>(5) ? ov::op::RoundingType::CEIL : ov::op::RoundingType::FLOOR;

return {context.mark_node(std::make_shared<opset8::MaxPool>(context.get_input(0),
strides,
dilations,
pads_begin,
pads_end,
pads,
pads,
kernel,
rounding_type))};
};
Expand Down
18 changes: 12 additions & 6 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ OP_CONVERTER(translate_add);
OP_CONVERTER(translate_addcmul);
OP_CONVERTER(translate_addmm);
OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_avg_pool2d);
OP_CONVERTER(translate_avg_poolnd);
OP_CONVERTER(translate_batch_norm);
OP_CONVERTER(translate_clamp);
OP_CONVERTER(translate_constant);
OP_CONVERTER(translate_conv2d);
OP_CONVERTER(translate_convnd);
OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_convolution_mode);
OP_CONVERTER(translate_dim);
Expand All @@ -45,7 +45,7 @@ OP_CONVERTER(translate_int);
OP_CONVERTER(translate_layer_norm);
OP_CONVERTER(translate_linear);
OP_CONVERTER(translate_loop);
OP_CONVERTER(translate_max_pool2d);
OP_CONVERTER(translate_max_poolnd);
OP_CONVERTER(translate_max);
OP_CONVERTER(translate_masked_fill);
OP_CONVERTER(translate_mean);
Expand Down Expand Up @@ -112,7 +112,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::atan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Atan>>},
{"aten::atanh", op::translate_1to1_match_1_inputs<opset8::Atanh>},
{"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Atanh>>},
{"aten::avg_pool2d", op::translate_avg_pool2d},
{"aten::avg_pool1d", op::translate_avg_poolnd},
{"aten::avg_pool2d", op::translate_avg_poolnd},
{"aten::avg_pool3d", op::translate_avg_poolnd},
{"aten::batch_norm", op::translate_batch_norm},
// {"aten::cat", done as transformation},
{"aten::clamp", op::translate_clamp},
Expand All @@ -123,7 +125,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::clone", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd
{"aten::contiguous", op::skip_node}, // In openvino how tensors are stored in memory is internal plugin detail,
// we assume all tensors are contiguous
{"aten::conv2d", op::translate_conv2d},
{"aten::conv1d", op::translate_convnd},
{"aten::conv2d", op::translate_convnd},
{"aten::conv3d", op::translate_convnd},
{"aten::convolution", op::translate_convolution},
{"aten::cos", op::translate_1to1_match_1_inputs<opset8::Cos>},
{"aten::cos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Cos>>},
Expand Down Expand Up @@ -165,7 +169,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::matmul", op::translate_1to1_match_2_inputs<opset8::MatMul>},
{"aten::masked_fill", op::translate_masked_fill},
{"aten::masked_fill_", op::inplace_op<op::translate_masked_fill>},
{"aten::max_pool2d", op::translate_max_pool2d},
{"aten::max_pool1d", op::translate_max_poolnd},
{"aten::max_pool2d", op::translate_max_poolnd},
{"aten::max_pool3d", op::translate_max_poolnd},
{"aten::max", op::translate_max},
{"aten::mean", op::translate_mean},
{"aten::min", op::translate_min},
Expand Down
53 changes: 0 additions & 53 deletions tests/layer_tests/pytorch_tests/test_conv2d.py

This file was deleted.

Loading