Skip to content

Commit

Permalink
[PT FE] Fix typos in im2col convertor (openvinotoolkit#25256)
Browse files Browse the repository at this point in the history
### Details:
 - *item1*
 - *...*

### Tickets:
 - *openvinotoolkit#25247*
  • Loading branch information
mvafin authored Jun 27, 2024
1 parent 1e9645c commit 6e62275
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/frontends/pytorch/src/op/im2col.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,32 +60,31 @@ OutputVector translate_im2col(const NodeContext& context) {
num_inputs_check(context, 5, 5);
auto input = context.get_input(0);
auto kernel_size = context.const_input<std::vector<int64_t>>(1);
PYTORCH_OP_CONVERSION_CHECK(kernel_size.size() == 2, "kernel size should contains 2 elements");
PYTORCH_OP_CONVERSION_CHECK(kernel_size.size() == 2, "kernel size should contain 2 elements");
auto dilation = context.const_input<std::vector<int64_t>>(2);
PYTORCH_OP_CONVERSION_CHECK(kernel_size.size() == 2, "dilation should contains 2 elements");
PYTORCH_OP_CONVERSION_CHECK(dilation.size() == 2, "dilation should contain 2 elements");
auto padding = context.const_input<std::vector<int64_t>>(3);
PYTORCH_OP_CONVERSION_CHECK(kernel_size.size() == 2, "padding should contains 2 elements");
PYTORCH_OP_CONVERSION_CHECK(padding.size() == 2, "padding should contain 2 elements");
auto stride = context.const_input<std::vector<int64_t>>(4);
PYTORCH_OP_CONVERSION_CHECK(kernel_size.size() == 2, "stride should contains 2 elements");
PYTORCH_OP_CONVERSION_CHECK(stride.size() == 2, "stride should contain 2 elements");
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
auto zero_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2}));
auto four = context.mark_node(v0::Constant::create(element::i32, Shape{}, {4}));
auto input_shape_split = context.mark_node(std::make_shared<v1::Split>(input_shape, zero, 4));
auto input_b = input_shape_split->output(0);
auto input_c = input_shape_split->output(1);
auto input_h = input_shape_split->output(2);
auto input_w = input_shape_split->output(3);
auto stride_h = stride[0];
auto stride_w = stride[1];
auto padding_h = padding[0];
auto padding_w = padding[1];
auto dilation_h = dilation[0];
auto dilation_w = dilation[1];
auto kernel_h = kernel_size[0];
auto kernel_w = kernel_size[1];
const auto& input_b = input_shape_split->output(0);
const auto& input_c = input_shape_split->output(1);
const auto& input_h = input_shape_split->output(2);
const auto& input_w = input_shape_split->output(3);
const auto& stride_h = stride[0];
const auto& stride_w = stride[1];
const auto& padding_h = padding[0];
const auto& padding_w = padding[1];
const auto& dilation_h = dilation[0];
const auto& dilation_w = dilation[1];
const auto& kernel_h = kernel_size[0];
const auto& kernel_w = kernel_size[1];
auto blocks_row_indices = get_im2col_indices_along_dim(context, input_h, kernel_h, dilation_h, padding_h, stride_h);
auto blocks_col_indices = get_im2col_indices_along_dim(context, input_w, kernel_w, dilation_w, padding_w, stride_w);
auto kernel_window = context.mark_node(v0::Constant::create(element::i32, Shape{}, {kernel_h * kernel_w}));
Expand All @@ -96,6 +95,7 @@ OutputVector translate_im2col(const NodeContext& context) {
std::make_shared<v0::Concat>(OutputVector{input_b, channel_unfolded_unsqueezed, minus_one}, 0));
auto pads = context.mark_node(
v0::Constant::create(element::i32, Shape{4}, std::vector<int64_t>{0, 0, padding_h, padding_w}));
auto zero_f = context.mark_node(std::make_shared<v1::ConvertLike>(zero, input));
auto padded_input =
context.mark_node(std::make_shared<v1::Pad>(input, pads, pads, zero_f, ov::op::PadMode::CONSTANT));
auto output = context.mark_node(std::make_shared<v8::Gather>(padded_input, blocks_row_indices, two));
Expand Down

0 comments on commit 6e62275

Please sign in to comment.