[PT FE] Extend PyTorch Frontend with MaxPool-14 and AvgPool-14 #23027

Closed
wants to merge 7 commits into from
23 changes: 2 additions & 21 deletions src/frontends/pytorch/src/op/avg_poolnd.cpp
@@ -40,35 +40,16 @@ OutputVector translate_avg_poolnd(const NodeContext& context) {
}
ov::op::RoundingType rounding_type = ov::op::RoundingType::FLOOR;
if (!(context.input_is_none(4))) {
rounding_type = context.const_input<bool>(4) ? ov::op::RoundingType::CEIL : ov::op::RoundingType::FLOOR;
rounding_type = context.const_input<bool>(4) ? ov::op::RoundingType::CEIL_TORCH : ov::op::RoundingType::FLOOR;
}
if (!(context.input_is_none(5))) {
count_include_pad = context.const_input<bool>(5);
}
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(6),
"Translation for aten::avg_pool2d do not support divisor_override input.");
// Although ov::AvgPool provides exclude_pad=false,
// The corner case of Average Pooling with ceil_mode on
// PyTorch allows sliding window go off bound, which leads to this accommodation.
// More detail on https://github.com/pytorch/pytorch/issues/57178
if (count_include_pad) {
auto zero = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
zero = context.mark_node(std::make_shared<v1::ConvertLike>(zero, input));
auto zero_i32 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
Output<Node> rank;
std::tie(std::ignore, rank) = get_shape_rank(context, input);
auto pad_values = context.get_input(3);
auto pads_len = context.mark_node(v0::Constant::create(element::i32, Shape{}, {pads.size()}));
auto pads_diff = context.mark_node(std::make_shared<v1::Subtract>(rank, pads_len));
auto pads_remaining = context.mark_node(std::make_shared<v3::Broadcast>(zero_i32, pads_diff));
auto padding = context.mark_node(
std::make_shared<v0::Concat>(OutputVector{std::move(pads_remaining), std::move(pad_values)}, 0));
input = context.mark_node(std::make_shared<v1::Pad>(input, padding, padding, zero, ov::op::PadMode::CONSTANT));
pads = Shape(pads.size(), 0);
}

return {context.mark_node(
std::make_shared<v1::AvgPool>(input, strides, pads, pads, kernel, !count_include_pad, rounding_type))};
std::make_shared<v14::AvgPool>(input, strides, pads, pads, kernel, !count_include_pad, rounding_type))};

Contributor:
Do we have a downgrading transformation for lowering the version of AvgPool and MaxPool? That is, to support plugins without a v14 implementation?

Contributor @mitruska, Feb 26, 2024:
That's right, this work-around logic should be moved from the PyTorch frontend into the downgrade transformation (applied when CEIL_TORCH mode is detected) to keep backward compatibility and give plugins time to migrate. A plugin that supports the new version of MaxPool/AvgPool will disable such a downgrade transformation.

Contributor (author):
Downgrade transformations: #23381
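
For illustration, a minimal sketch of what such a downgrade pass could look like for the simple case (rounding type other than CEIL_TORCH). The class name, structure, and API usage here are assumptions made for the sketch; the actual transformations are implemented in #23381 and may differ:

```cpp
#include <memory>

#include "openvino/core/graph_util.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/avg_pool.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

// Hypothetical downgrade pass: AvgPool-14 -> AvgPool-1 whenever the attributes
// are representable in opset1, i.e. rounding_type is not CEIL_TORCH.
class ConvertAvgPool14ToAvgPool1Sketch : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ConvertAvgPool14ToAvgPool1Sketch", "0");
    ConvertAvgPool14ToAvgPool1Sketch() {
        const auto pool_pattern = ov::pass::pattern::wrap_type<ov::op::v14::AvgPool>();
        ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
            const auto pool_v14 = ov::as_type_ptr<ov::op::v14::AvgPool>(m.get_match_root());
            if (!pool_v14 || pool_v14->get_rounding_type() == ov::op::RoundingType::CEIL_TORCH) {
                // CEIL_TORCH has no opset1 equivalent; that case needs the padding
                // work-around which this PR removes from the frontend.
                return false;
            }
            const auto pool_v1 = std::make_shared<ov::op::v1::AvgPool>(pool_v14->input_value(0),
                                                                       pool_v14->get_strides(),
                                                                       pool_v14->get_pads_begin(),
                                                                       pool_v14->get_pads_end(),
                                                                       pool_v14->get_kernel(),
                                                                       pool_v14->get_exclude_pad(),
                                                                       pool_v14->get_rounding_type(),
                                                                       pool_v14->get_auto_pad());
            pool_v1->set_friendly_name(pool_v14->get_friendly_name());
            ov::copy_runtime_info(pool_v14, pool_v1);
            ov::replace_node(pool_v14, pool_v1);
            return true;
        };
        register_matcher(std::make_shared<ov::pass::pattern::Matcher>(pool_pattern, "ConvertAvgPool14ToAvgPool1Sketch"),
                         callback);
    }
};
```

A plugin that implements the opset-14 pooling operations natively would simply not register (or would disable) such a pass; the CEIL_TORCH case would additionally need the padding logic removed from the frontend here.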

};

} // namespace op
72 changes: 12 additions & 60 deletions src/frontends/pytorch/src/op/max_poolnd.cpp
@@ -32,8 +32,7 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
if (!context.input_is_none(2)) {
strides = context.const_input<Strides>(2);
}
const bool use_kernel = context.input_is_none(2) || (strides.size() == 0);
if (use_kernel) {
if (context.input_is_none(2) || strides.size() == 0) {
// In case strides are not provided default is kernel
strides = kernel;
}
@@ -51,66 +50,19 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
if (context.input_is_none(5)) {
rounding_type = RoundingType::FLOOR;
} else {
rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL : RoundingType::FLOOR;
rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL_TORCH : RoundingType::FLOOR;
}

auto input = context.get_input(0);
if (rounding_type == RoundingType::CEIL) {
// The corner case of Max Pooling with ceil_mode on
// PyTorch allows sliding window go off bound, which leads to this accommodation.
// More detail on https://github.com/pytorch/pytorch/issues/57178
const auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
const auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
const auto two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2}));

const auto padding =
context.input_is_none(3)
? context.mark_node(std::make_shared<v0::Constant>(element::i32, Shape{pads.size()}, 0))->output(0)
: context.get_input(3);
const auto pads_len = context.mark_node(v0::Constant::create(element::i32, Shape{}, {pads.size()}));
const auto pads_remaining = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0, 0}));

// gather input spatial dims and prepare for compare as values (in_dim + pad)
const auto input_shape_rank = get_shape_rank(context, input);
const auto end = context.mark_node(v0::Constant::create(element::i32, Shape{}, {pads.size() + 2}));
const auto dim_idxs = context.mark_node(std::make_shared<v4::Range>(two, end, one, element::i32));
const auto gth_in_dims =
context.mark_node(std::make_shared<v8::Gather>(std::get<0>(input_shape_rank), dim_idxs, zero));
const auto in_left_padded = context.mark_node(std::make_shared<v1::Add>(gth_in_dims, padding));

// gather output spatial dims and prepare it for compare as values (out_dim - 1) * stride
const auto mp = context.mark_node(
std::make_shared<v8::MaxPool>(input, strides, dilations, pads, pads, kernel, rounding_type));
const auto shape_of_mp = context.mark_node(std::make_shared<v3::ShapeOf>(mp, element::i32));
const auto gth_out_dims = context.mark_node(std::make_shared<v8::Gather>(shape_of_mp, dim_idxs, zero));
const auto out_sub_one = context.mark_node(std::make_shared<v1::Subtract>(gth_out_dims, one));
const auto stride_node = use_kernel ? context.get_input(1) : context.get_input(2);
const auto out_mul_stride = context.mark_node(std::make_shared<v1::Multiply>(out_sub_one, stride_node));

// if (in_dim + pad) > ((out_dim - 1) * stride) sliding window in bound use end padding.
const auto in_gt_out = context.mark_node(std::make_shared<v1::Greater>(in_left_padded, out_mul_stride));
const auto selected_pads = context.mark_node(std::make_shared<v1::Select>(in_gt_out, padding, zero));

// apply padding on input clear pads attribute
const auto pb = context.mark_node(std::make_shared<v0::Concat>(OutputVector{pads_remaining, padding}, 0));
const auto pe = context.mark_node(std::make_shared<v0::Concat>(OutputVector{pads_remaining, selected_pads}, 0));
auto minus_inf =
context.mark_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
minus_inf = context.mark_node(std::make_shared<v1::ConvertLike>(minus_inf, input));
input = context.mark_node(std::make_shared<v12::Pad>(input, pb, pe, minus_inf, op::PadMode::CONSTANT));
std::fill_n(pads.begin(), pads.size(), 0);
}

auto res = context.mark_node(std::make_shared<v8::MaxPool>(input,
strides,
dilations,
pads,
pads,
kernel,
rounding_type,
PadType::EXPLICIT,
element::i64,
2));
auto res = context.mark_node(std::make_shared<v14::MaxPool>(context.get_input(0),
strides,
dilations,
pads,
pads,
kernel,
rounding_type,
PadType::EXPLICIT,
element::i64,
2));
if (context.get_output_size() == 2) {
auto out1 = res->output(0);
auto out2 = res->output(1);
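For background on why CEIL is replaced with CEIL_TORCH above (and what the removed work-around in both translators emulated): with ceil_mode, PyTorch ignores a sliding window that would start entirely in the right/bottom padding (see https://github.com/pytorch/pytorch/issues/57178). Below is a rough sketch of the resulting per-dimension output-size rule; the helper name is hypothetical and dilation is ignored for brevity:

```cpp
#include <cstdint>

// Output size of one spatial dimension under PyTorch's ceil_mode semantics,
// i.e. what RoundingType::CEIL_TORCH is meant to capture. Illustrative only;
// assumes in + 2 * pad >= kernel.
int64_t pooled_dim_ceil_torch(int64_t in, int64_t kernel, int64_t stride, int64_t pad) {
    // Plain ceil rounding: ceil((in + 2 * pad - kernel) / stride) + 1.
    int64_t out = (in + 2 * pad - kernel + stride - 1) / stride + 1;
    // PyTorch drops the last window if it would start entirely in the
    // right/bottom padding, i.e. if its start position (out - 1) * stride - pad
    // lies at or beyond the end of the input.
    if ((out - 1) * stride >= in + pad) {
        --out;
    }
    return out;
}

// Example: in = 5, kernel = 2, stride = 2, pad = 1.
// Plain CEIL gives 4, but the 4th window would start at position 3 * 2 - 1 = 5,
// past the last input element (index 4), so CEIL_TORCH gives 3.
```

With only the CEIL rounding available in opset 1/opset 8, the translators had to emulate this by explicitly pre-padding the input (zeros for AvgPool when count_include_pad is set, -infinity for MaxPool) and zeroing the pads attribute; AvgPool-14 and MaxPool-14 accept CEIL_TORCH directly, which is why that code is deleted above.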
23 changes: 4 additions & 19 deletions tests/layer_tests/pytorch_tests/test_pooling.py
@@ -17,10 +17,9 @@
{'kernel_size': [2, 1], 'stride': [2, 1], 'padding': 0},
{'kernel_size': [2, 1], 'stride': None, 'padding': 0},
{'kernel_size': [2, 1], 'stride': [], 'padding': 0},
{'kernel_size': [8, 8], 'stride': [8, 4], 'padding': 1},
]

d2_params_corner_case = [{'kernel_size': [8, 8], 'stride': [8, 4], 'padding': 1}]

d1_params = [{'kernel_size': 3, 'stride': 1, 'padding': 0},
{'kernel_size': (4,), 'stride': 1, 'padding': 1},
{'kernel_size': 4, 'stride': (5,), 'padding': 2},
@@ -143,15 +142,7 @@ def test_avg_pool1d(self, params, ceil_mode, count_include_pad, ie_device, precision, ir_version):
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, trace_model=True,
dynamic_shapes=False)

@pytest.mark.parametrize(
"params",
d2_params
+ [
pytest.param(
{"kernel_size": [8, 8], "stride": [8, 4], "padding": 1},
marks=pytest.mark.xfail(reason="Sliding windows that would start in the right padded are ignored.")
)
])
@pytest.mark.parametrize("params", d2_params)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("count_include_pad", [True, False])
@pytest.mark.nightly
@@ -190,7 +181,7 @@ def test_max_pool1d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
self._test(*self.create_model("max_pool1d", **params, ceil_mode=ceil_mode, dilation=dilation),
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, dynamic_shapes=False)

@pytest.mark.parametrize("params", d2_params + d2_params_corner_case)
@pytest.mark.parametrize("params", d2_params)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.parametrize("dtype", [torch.float32, torch.int32])
@@ -224,12 +215,10 @@ def test_max_pool3d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
@pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64',
reason='Ticket - 122715')
def test_max_pool1d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
if ceil_mode and (np.array(params["padding"]).any() != 0):
pytest.skip("ticket 122418")
self._test(*self.create_model("max_pool1d_with_indices", **params, ceil_mode=ceil_mode, dilation=dilation),
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, dynamic_shapes=False)

@pytest.mark.parametrize("params", d2_params + d2_params_corner_case)
@pytest.mark.parametrize("params", d2_params)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.nightly
@@ -238,8 +227,6 @@ def test_max_pool1d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
@pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64',
reason='Ticket - 122715')
def test_max_pool2d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
if ceil_mode and (np.array(params["padding"]).any() != 0):
pytest.skip("ticket 122418")
to_trace = False
if params["stride"] == []:
to_trace = True
@@ -255,7 +242,5 @@ def test_max_pool2d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
@pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64',
reason='Ticket - 122715')
def test_max_pool3d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
if ceil_mode and (np.array(params["padding"]).any() != 0):
pytest.skip("ticket 122418")
self._test(*self.create_model("max_pool3d_with_indices", **params, ceil_mode=ceil_mode, dilation=dilation),
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 5}, dynamic_shapes=False)