diff --git a/src/frontends/pytorch/src/op/im2col.cpp b/src/frontends/pytorch/src/op/im2col.cpp index 4884c813794be3..1347a3d72d15ad 100644 --- a/src/frontends/pytorch/src/op/im2col.cpp +++ b/src/frontends/pytorch/src/op/im2col.cpp @@ -60,32 +60,31 @@ OutputVector translate_im2col(const NodeContext& context) { num_inputs_check(context, 5, 5); auto input = context.get_input(0); auto kernel_size = context.const_input>(1); - PYTORCH_OP_CONVERSION_CHECK(kernel_size.size() == 2, "kernel size should contains 2 elements"); + PYTORCH_OP_CONVERSION_CHECK(kernel_size.size() == 2, "kernel size should contain 2 elements"); auto dilation = context.const_input>(2); - PYTORCH_OP_CONVERSION_CHECK(kernel_size.size() == 2, "dilation should contains 2 elements"); + PYTORCH_OP_CONVERSION_CHECK(dilation.size() == 2, "dilation should contain 2 elements"); auto padding = context.const_input>(3); - PYTORCH_OP_CONVERSION_CHECK(kernel_size.size() == 2, "padding should contains 2 elements"); + PYTORCH_OP_CONVERSION_CHECK(padding.size() == 2, "padding should contain 2 elements"); auto stride = context.const_input>(4); - PYTORCH_OP_CONVERSION_CHECK(kernel_size.size() == 2, "stride should contains 2 elements"); + PYTORCH_OP_CONVERSION_CHECK(stride.size() == 2, "stride should contain 2 elements"); auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); auto input_shape = context.mark_node(std::make_shared(input, element::i32)); - auto zero_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0})); auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); auto two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2})); auto four = context.mark_node(v0::Constant::create(element::i32, Shape{}, {4})); auto input_shape_split = context.mark_node(std::make_shared(input_shape, zero, 4)); - auto input_b = input_shape_split->output(0); - auto input_c = input_shape_split->output(1); - auto input_h = input_shape_split->output(2); - auto input_w = input_shape_split->output(3); - auto stride_h = stride[0]; - auto stride_w = stride[1]; - auto padding_h = padding[0]; - auto padding_w = padding[1]; - auto dilation_h = dilation[0]; - auto dilation_w = dilation[1]; - auto kernel_h = kernel_size[0]; - auto kernel_w = kernel_size[1]; + const auto& input_b = input_shape_split->output(0); + const auto& input_c = input_shape_split->output(1); + const auto& input_h = input_shape_split->output(2); + const auto& input_w = input_shape_split->output(3); + const auto& stride_h = stride[0]; + const auto& stride_w = stride[1]; + const auto& padding_h = padding[0]; + const auto& padding_w = padding[1]; + const auto& dilation_h = dilation[0]; + const auto& dilation_w = dilation[1]; + const auto& kernel_h = kernel_size[0]; + const auto& kernel_w = kernel_size[1]; auto blocks_row_indices = get_im2col_indices_along_dim(context, input_h, kernel_h, dilation_h, padding_h, stride_h); auto blocks_col_indices = get_im2col_indices_along_dim(context, input_w, kernel_w, dilation_w, padding_w, stride_w); auto kernel_window = context.mark_node(v0::Constant::create(element::i32, Shape{}, {kernel_h * kernel_w})); @@ -96,6 +95,7 @@ OutputVector translate_im2col(const NodeContext& context) { std::make_shared(OutputVector{input_b, channel_unfolded_unsqueezed, minus_one}, 0)); auto pads = context.mark_node( v0::Constant::create(element::i32, Shape{4}, std::vector{0, 0, padding_h, padding_w})); + auto zero_f = context.mark_node(std::make_shared(zero, input)); auto padded_input = context.mark_node(std::make_shared(input, pads, pads, zero_f, ov::op::PadMode::CONSTANT)); auto output = context.mark_node(std::make_shared(padded_input, blocks_row_indices, two));