diff --git a/src/frontends/pytorch/src/op/as_tensor.cpp b/src/frontends/pytorch/src/op/as_tensor.cpp
index 01ce5ef9decf1e..c0292764c730cf 100644
--- a/src/frontends/pytorch/src/op/as_tensor.cpp
+++ b/src/frontends/pytorch/src/op/as_tensor.cpp
@@ -15,17 +15,19 @@ namespace op {
 OutputVector translate_as_tensor(NodeContext& context) {
     auto dtype = element::f32;
     Output<Node> cast;
-    if (!context.input_is_none(1)){
+    if (!context.input_is_none(1)) {
         auto dtype_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
         auto dtype_fw_node = std::dynamic_pointer_cast<PtFrameworkNode>(dtype_ext_node);
         if (dtype_fw_node && dtype_fw_node->get_op_type() == "prim::dtype") {
             auto type_input = dtype_fw_node->input_value(0);
             return {context.mark_node(std::make_shared<opset8::ConvertLike>(context.get_input(0), type_input))};
         }
-        if (auto dtype_const = std::dynamic_pointer_cast<opset8::Constant>(dtype_ext_node)){
-            auto pt_type = dtype_const->cast_vector<int64_t>()[0];
-            FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::as_tensor: ", pt_type);
-            dtype = TORCH_TO_OV_TYPE.at(pt_type);
+        if (auto dtype_const = std::dynamic_pointer_cast<opset8::Constant>(dtype_ext_node)) {
+            auto pt_type = dtype_const->cast_vector<int64_t>()[0];
+            FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type),
+                                          "Unknown type in aten::as_tensor: ",
+                                          pt_type);
+            dtype = TORCH_TO_OV_TYPE.at(pt_type);
         }
     }
     cast = context.mark_node(std::make_shared<opset8::Convert>(context.get_input(0), dtype));
diff --git a/src/frontends/pytorch/src/op/full.cpp b/src/frontends/pytorch/src/op/full.cpp
index e2c997329d967f..085f4e2285b2df 100644
--- a/src/frontends/pytorch/src/op/full.cpp
+++ b/src/frontends/pytorch/src/op/full.cpp
@@ -18,14 +18,14 @@ OutputVector translate_full(NodeContext& context) {
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
 
     if (num_inputs < 6) {
-        size_t out_id = num_inputs == 3 ? 2: 3;
-        if (!context.input_is_none(out_id)){
-            auto out = context.get_input(out_id);
-            return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
+        size_t out_id = num_inputs == 3 ? 2 : 3;
+        if (!context.input_is_none(out_id)) {
+            auto out = context.get_input(out_id);
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
         }
     }
-    size_t dtype_id = num_inputs == 6 ? 2: 3;
-    if (!context.input_is_none(dtype_id)){
+    size_t dtype_id = num_inputs == 6 ? 2 : 3;
+    if (!context.input_is_none(dtype_id)) {
         auto pt_type = context.const_input<int64_t>(dtype_id);
         FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::full: ", pt_type);
         auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
@@ -39,13 +39,13 @@ OutputVector translate_full_like(NodeContext& context) {
     auto value = context.get_input(1);
     auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
-    if (context.get_input_size() == 7 && !context.input_is_none(2)){
+    if (context.get_input_size() == 7 && !context.input_is_none(2)) {
         auto pt_type = context.const_input<int64_t>(2);
         FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::full_like: ", pt_type);
         auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
         filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
     } else {
-        auto out_dtype = context.input_is_none(3)? input : context.get_input(3);
+        auto out_dtype = context.input_is_none(3) ? input : context.get_input(3);
         filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
     }
     return {filled_tensor};
@@ -71,15 +71,15 @@ OutputVector translate_zeros(NodeContext& context) {
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
     int num_inputs = context.get_input_size();
     if (num_inputs < 5) {
-        size_t out_id = num_inputs == 2 ? 1: 2;
-        if (!context.input_is_none(out_id)){
-            auto out = context.get_input(out_id);
-            return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
+        size_t out_id = num_inputs == 2 ? 1 : 2;
+        if (!context.input_is_none(out_id)) {
+            auto out = context.get_input(out_id);
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
         }
         return {filled_tensor};
     }
-    size_t dtype_id = num_inputs == 5 ? 1: 2;
-    if (!context.input_is_none(dtype_id)){
+    size_t dtype_id = num_inputs == 5 ? 1 : 2;
+    if (!context.input_is_none(dtype_id)) {
         std::cout << dtype_id << std::endl;
         auto pt_type = context.const_input<int64_t>(dtype_id);
         std::cout << pt_type << std::endl;
@@ -95,14 +95,13 @@ OutputVector translate_zeros_like(NodeContext& context) {
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
     auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
-    if (context.get_input_size() == 6 && !context.input_is_none(1)){
+    if (context.get_input_size() == 6 && !context.input_is_none(1)) {
         auto pt_type = context.const_input<int64_t>(1);
         FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::zeros_like: ", pt_type);
         auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
         filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
-    }
-    else {
-        auto out_dtype = context.input_is_none(2)? input : context.get_input(2);
+    } else {
+        auto out_dtype = context.input_is_none(2) ? input : context.get_input(2);
         filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
     }
     return {filled_tensor};
@@ -113,7 +112,7 @@ OutputVector translate_new_zeros(NodeContext& context) {
     auto sizes = context.get_input(1);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
-    if (context.get_input_size() == 6 && !context.input_is_none(2)){
+    if (context.get_input_size() == 6 && !context.input_is_none(2)) {
         auto pt_type = context.const_input<int64_t>(2);
         FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::new_zeros: ", pt_type);
         auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
@@ -128,14 +127,14 @@ OutputVector translate_ones(NodeContext& context) {
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
     int num_inputs = context.get_input_size();
     if (num_inputs < 5) {
-        size_t out_id = num_inputs == 2 ? 1: 2;
-        if (!context.input_is_none(out_id)){
-            auto out = context.get_input(out_id);
-            return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
+        size_t out_id = num_inputs == 2 ? 1 : 2;
+        if (!context.input_is_none(out_id)) {
+            auto out = context.get_input(out_id);
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
         }
     }
-    size_t dtype_id = num_inputs == 5 ? 1: 2;
-    if (!context.input_is_none(dtype_id)){
+    size_t dtype_id = num_inputs == 5 ? 1 : 2;
+    if (!context.input_is_none(dtype_id)) {
         auto pt_type = context.const_input<int64_t>(dtype_id);
         FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::ones: ", pt_type);
         auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
@@ -149,14 +148,13 @@ OutputVector translate_ones_like(NodeContext& context) {
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
     auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
-    if (context.get_input_size() == 6 && !context.input_is_none(1)){
+    if (context.get_input_size() == 6 && !context.input_is_none(1)) {
         auto pt_type = context.const_input<int64_t>(1);
         FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::ones_like: ", pt_type);
         auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
         filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
-    }
-    else {
-        auto out_dtype = context.input_is_none(2)? input : context.get_input(2);
+    } else {
+        auto out_dtype = context.input_is_none(2) ? input : context.get_input(2);
         filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
     }
     return {filled_tensor};
@@ -167,7 +165,7 @@ OutputVector translate_new_ones(NodeContext& context) {
     auto sizes = context.get_input(1);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
-    if (context.get_input_size() == 6 && !context.input_is_none(2)){
+    if (context.get_input_size() == 6 && !context.input_is_none(2)) {
         auto pt_type = context.const_input<int64_t>(2);
         FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::new_zeros: ", pt_type);
         auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
diff --git a/src/frontends/pytorch/src/op/upsample.cpp b/src/frontends/pytorch/src/op/upsample.cpp
index f7b64d0c989828..ad1c10d432ce90 100644
--- a/src/frontends/pytorch/src/op/upsample.cpp
+++ b/src/frontends/pytorch/src/op/upsample.cpp
@@ -16,7 +16,7 @@ OutputVector translate_upsample2d(NodeContext& context, opset8::Interpolate::Int
     auto size_mode = opset8::Interpolate::ShapeCalcMode::SIZES;
     bool align_corners = false;
     int scale_id = 2;
-    if (interpolate_mode == opset8::Interpolate::InterpolateMode::LINEAR_ONNX) {
+    if (interpolate_mode != opset8::Interpolate::InterpolateMode::NEAREST) {
         scale_id = 3;
         if (!context.input_is_none(2)) {
             align_corners = context.const_input<bool>(2);
@@ -38,7 +38,7 @@ OutputVector translate_upsample2d(NodeContext& context, opset8::Interpolate::Int
     auto attrs = opset8::Interpolate::InterpolateAttrs(interpolate_mode, size_mode, pad, pad);
     attrs.coordinate_transformation_mode = opset8::Interpolate::CoordinateTransformMode::ASYMMETRIC;
     attrs.nearest_mode = opset8::Interpolate::NearestMode::FLOOR;
-    if (attrs.mode == opset8::Interpolate::InterpolateMode::LINEAR_ONNX) {
+    if (attrs.mode != opset8::Interpolate::InterpolateMode::NEAREST) {
         if (align_corners) {
             attrs.coordinate_transformation_mode = opset8::Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
         }
@@ -54,6 +54,10 @@ OutputVector translate_upsample_nearest2d(NodeContext& context) {
     return translate_upsample2d(context, opset8::Interpolate::InterpolateMode::NEAREST);
 };
 
+OutputVector translate_upsample_bicubic2d(NodeContext& context) {
+    return translate_upsample2d(context, opset8::Interpolate::InterpolateMode::CUBIC);
+};
+
 }  // namespace op
 }  // namespace pytorch
 }  // namespace frontend
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index a70c998fd1d855..be7207d7feb81a 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -78,6 +78,7 @@ OP_CONVERTER(translate_sum);
 OP_CONVERTER(translate_to);
 OP_CONVERTER(translate_transpose);
 OP_CONVERTER(translate_tuple_construct);
+OP_CONVERTER(translate_upsample_bicubic2d);
 OP_CONVERTER(translate_upsample_bilinear2d);
 OP_CONVERTER(translate_upsample_nearest2d);
 OP_CONVERTER(translate_var);
@@ -92,6 +93,10 @@ const std::map get_supported_ops() {
         {"aten::_convolution", op::translate_convolution},
        {"aten::_convolution_mode", op::translate_convolution_mode},
         {"aten::abs", op::translate_1to1_match_1_inputs<opset8::Abs>},
+        {"aten::acos", op::translate_1to1_match_1_inputs<opset8::Acos>},
+        {"aten::acos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Acos>>},
+        {"aten::acosh", op::translate_1to1_match_1_inputs<opset8::Acosh>},
+        {"aten::acosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Acosh>>},
         {"aten::adaptive_avg_pool2d", op::translate_1to1_match_2_inputs<opset8::AdaptiveAvgPool>},
         {"aten::adaptive_avg_pool3d", op::translate_adaptive_avg_pool3d},
         {"aten::adaptive_max_pool2d", op::translate_adaptive_max_pool2d},
@@ -100,18 +105,32 @@ const std::map get_supported_ops() {
         {"aten::addcmul", op::translate_addcmul},
         {"aten::addmm", op::translate_addmm},
         {"aten::arange", op::translate_arange},
+        {"aten::asin", op::translate_1to1_match_1_inputs<opset8::Asin>},
+        {"aten::asin_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Asin>>},
+        {"aten::asinh", op::translate_1to1_match_1_inputs<opset8::Asinh>},
+        {"aten::asinh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Asinh>>},
         {"aten::as_tensor", op::translate_as_tensor},
+        {"aten::atan", op::translate_1to1_match_1_inputs<opset8::Atan>},
+        {"aten::atan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Atan>>},
+        {"aten::atanh", op::translate_1to1_match_1_inputs<opset8::Atanh>},
+        {"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Atanh>>},
         {"aten::avg_pool2d", op::translate_avg_pool2d},
         {"aten::batch_norm", op::translate_batch_norm},
         // {"aten::cat", done as transformation},
         {"aten::clamp", op::translate_clamp},
         {"aten::clamp_min", op::translate_1to1_match_2_inputs<opset8::Maximum>},
         {"aten::clamp_max", op::translate_1to1_match_2_inputs<opset8::Minimum>},
+        {"aten::ceil", op::translate_1to1_match_1_inputs<opset8::Ceiling>},
+        {"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Ceiling>>},
         {"aten::clone", op::skip_node},  // ignore clone operators that are inserted by PyTorch autograd
         {"aten::contiguous", op::skip_node},  // In openvino how tensors are stored in memory is internal plugin detail,
                                               // we assume all tensors are contiguous
         {"aten::conv2d", op::translate_conv2d},
         {"aten::convolution", op::translate_convolution},
+        {"aten::cos", op::translate_1to1_match_1_inputs<opset8::Cos>},
+        {"aten::cos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Cos>>},
+        {"aten::cosh", op::translate_1to1_match_1_inputs<opset8::Cosh>},
+        {"aten::cosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Cosh>>},
         {"aten::cumsum", op::translate_1to1_match_2_inputs<opset8::CumSum>},
         {"aten::dim", op::translate_dim},
         {"aten::div", op::translate_div},
@@ -123,11 +142,14 @@ const std::map get_supported_ops() {
         {"aten::expand", op::translate_expand},
         {"aten::expand_as", op::translate_expand_as},
         {"aten::flatten", op::translate_flatten},
+        {"aten::floor", op::translate_1to1_match_1_inputs<opset8::Floor>},
+        {"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Floor>>},
         {"aten::floordiv", op::translate_floordiv},
         {"aten::full", op::translate_full},
         {"aten::full_like", op::translate_full_like},
         {"aten::gelu", op::translate_gelu},
         {"aten::group_norm", op::translate_group_norm},
+        {"aten::ge", op::translate_1to1_match_2_inputs<opset8::GreaterEqual>},
         {"aten::gt", op::translate_1to1_match_2_inputs<opset8::Greater>},
         {"aten::hardsigmoid", op::translate_1to1_match_1_inputs<opset8::HSigmoid>},
         {"aten::hardswish", op::translate_1to1_match_1_inputs<opset8::HSwish>},
@@ -140,6 +162,7 @@ const std::map get_supported_ops() {
         {"aten::leaky_relu", op::translate_1to1_match_2_inputs<opset8::PRelu>},
         {"aten::leaky_relu_", op::inplace_op<op::translate_1to1_match_2_inputs<opset8::PRelu>>},
         {"aten::linear", op::translate_linear},
+        {"aten::le", op::translate_1to1_match_2_inputs<opset8::LessEqual>},
         {"aten::lt", op::translate_1to1_match_2_inputs<opset8::Less>},
         {"aten::matmul", op::translate_1to1_match_2_inputs<opset8::MatMul>},
         {"aten::masked_fill", op::translate_masked_fill},
@@ -178,6 +201,10 @@ const std::map get_supported_ops() {
         {"aten::sigmoid", op::translate_1to1_match_1_inputs<opset8::Sigmoid>},
         {"aten::silu", op::translate_1to1_match_1_inputs<opset8::Swish>},
         {"aten::silu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Swish>>},
+        {"aten::sin", op::translate_1to1_match_1_inputs<opset8::Sin>},
+        {"aten::sin_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Sin>>},
+        {"aten::sinh", op::translate_1to1_match_1_inputs<opset8::Sinh>},
+        {"aten::sinh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Sinh>>},
         {"aten::size", op::translate_size},
         {"aten::slice", op::translate_slice},
         {"aten::softmax", op::translate_softmax},
@@ -186,13 +213,18 @@ const std::map get_supported_ops() {
         {"aten::squeeze", op::translate_squeeze},
         {"aten::sub", op::translate_sub},
         {"aten::sum", op::translate_sum},
+        {"aten::tan", op::translate_1to1_match_1_inputs<opset8::Tan>},
+        {"aten::tan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Tan>>},
         {"aten::tanh", op::translate_1to1_match_1_inputs<opset8::Tanh>},
+        {"aten::tanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Tanh>>},
         {"aten::tensor", op::translate_as_tensor},
         {"aten::type_as",
          op::translate_1to1_match_2_inputs<opset8::ConvertLike>},  // TODO: overflow semantics is different
         {"aten::to", op::translate_to},
         {"aten::transpose", op::translate_transpose},
         {"aten::unsqueeze", op::translate_1to1_match_2_inputs<opset8::Unsqueeze>},
+        {"aten::unsqueeze_", op::inplace_op<op::translate_1to1_match_2_inputs<opset8::Unsqueeze>>},
+        {"aten::upsample_bicubic2d", op::translate_upsample_bicubic2d},
         {"aten::upsample_bilinear2d", op::translate_upsample_bilinear2d},
         {"aten::upsample_nearest2d", op::translate_upsample_nearest2d},
         {"aten::var", op::translate_var},
diff --git a/tests/layer_tests/pytorch_tests/test_ceil.py b/tests/layer_tests/pytorch_tests/test_ceil.py
new file mode 100644
index 00000000000000..86c7e04e9373a7
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_ceil.py
@@ -0,0 +1,31 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestCeil(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
+
+    def create_model(self, inplace):
+        import torch
+
+        class aten_ceil(torch.nn.Module):
+            def __init__(self, inplace):
+                super(aten_ceil, self).__init__()
+                self.op = torch.ceil_ if inplace else torch.ceil
+
+            def forward(self, x):
+                return x, self.op(x)
+
+        ref_net = None
+
+        return aten_ceil(inplace), ref_net, "aten::ceil" if not inplace else "aten::ceil_"
+
+    @pytest.mark.parametrize("inplace", [False, True])
+    @pytest.mark.nightly
+    def test_ceil(self, inplace, ie_device, precision, ir_version):
+        self._test(*self.create_model(inplace), ie_device, precision, ir_version)
diff --git a/tests/layer_tests/pytorch_tests/test_comparision.py b/tests/layer_tests/pytorch_tests/test_comparision.py
new file mode 100644
index 00000000000000..1fc801a399ad51
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_comparision.py
@@ -0,0 +1,58 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestComp(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(1, 3, 24, 24).astype(np.float32), np.random.randn(1, 3, 24, 24).astype(np.float32))
+
+    def create_model(self, op_type):
+
+        import torch
+
+        class aten_eq(torch.nn.Module):
+            def forward(self, x, y):
+                return x == y
+
+        class aten_ne(torch.nn.Module):
+            def forward(self, x, y):
+                return x != y
+
+        class aten_lt(torch.nn.Module):
+            def forward(self, x, y):
+                return x < y
+
+        class aten_gt(torch.nn.Module):
+            def forward(self, x, y):
+                return x > y
+
+        class aten_le(torch.nn.Module):
+            def forward(self, x, y):
+                return x <= y
+
+        class aten_ge(torch.nn.Module):
+            def forward(self, x, y):
+                return x >= y
+
+        ops = {
+            "eq": aten_eq,
+            "ne": aten_ne,
+            "lt": aten_lt,
+            "gt": aten_gt,
+            "ge": aten_ge,
+            "le": aten_le
+        }
+        model_cls = ops[op_type]
+
+        ref_net = None
+
+        return model_cls(), ref_net, f"aten::{op_type}"
+
+    @pytest.mark.parametrize("op", ["eq", "ne", "lt", "gt", "le", "ge"])
+    @pytest.mark.nightly
+    def test_comp(self, op, ie_device, precision, ir_version):
+        self._test(*self.create_model(op), ie_device, precision, ir_version)
\ No newline at end of file
diff --git a/tests/layer_tests/pytorch_tests/test_floor.py b/tests/layer_tests/pytorch_tests/test_floor.py
new file mode 100644
index 00000000000000..8e26dd0c296c37
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_floor.py
@@ -0,0 +1,31 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestFloor(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
+
+    def create_model(self, inplace):
+        import torch
+
+        class aten_floor(torch.nn.Module):
+            def __init__(self, inplace):
+                super(aten_floor, self).__init__()
+                self.op = torch.floor_ if inplace else torch.floor
+
+            def forward(self, x):
+                return x, self.op(x)
+
+        ref_net = None
+
+        return aten_floor(inplace), ref_net, "aten::floor" if not inplace else "aten::floor_"
+
+    @pytest.mark.parametrize("inplace", [False, True])
+    @pytest.mark.nightly
+    def test_floor(self, inplace, ie_device, precision, ir_version):
+        self._test(*self.create_model(inplace), ie_device, precision, ir_version)
diff --git a/tests/layer_tests/pytorch_tests/test_trigonometry.py b/tests/layer_tests/pytorch_tests/test_trigonometry.py
new file mode 100644
index 00000000000000..f1bd68fec0a52a
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_trigonometry.py
@@ -0,0 +1,63 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestTrigonom(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(1, 2, 3, 4).astype(np.float32), )
+
+    def create_model(self, op_type):
+
+        import torch
+        ops = {
+            "cos": torch.cos,
+            "cos_": torch.cos_,
+            "sin": torch.sin,
+            "sin_": torch.sin_,
+            "tan": torch.tan,
+            "tan_": torch.tan_,
+            "cosh": torch.cosh,
+            "cosh_": torch.cosh_,
+            "sinh": torch.sinh,
+            "sinh_": torch.sinh_,
+            "tanh": torch.tanh,
+            "tanh_": torch.tanh_,
+            "acos": torch.acos,
+            "acos_": torch.acos_,
+            "asin": torch.asin,
+            "asin_": torch.asin_,
+            "atan": torch.atan,
+            "atan_": torch.atan_,
+            "acosh": torch.acosh,
+            "acosh_": torch.acosh_,
+            "asinh": torch.asinh,
+            "asinh_": torch.asinh_,
+            "atanh": torch.atanh,
+            "atanh_": torch.atanh_,
+        }
+
+        class aten_op(torch.nn.Module):
+            def __init__(self, op):
+                super(aten_op, self).__init__()
+                self.op = op
+
+            def forward(self, x):
+                return self.op(x)
+        ref_net = None
+
+        return aten_op(ops[op_type]), ref_net, f'aten::{op_type}'
+
+    @pytest.mark.parametrize("op", [
+        "acos", "acos_", "acosh", "acosh_",
+        "asin", "asin_", "asinh", "asinh_",
+        "atan", "atan_", "atanh", "atanh_",
+        "cos", "cos_", "cosh", "cosh_",
+        "sin", "sin_", "sinh", "sinh_",
+        "tan", "tan_", "tanh", "tanh_"])
+    @pytest.mark.nightly
+    def test_trigonometry(self, op, ie_device, precision, ir_version):
+        self._test(*self.create_model(op), ie_device, precision, ir_version)
diff --git a/tests/layer_tests/pytorch_tests/test_unsqueeze.py b/tests/layer_tests/pytorch_tests/test_unsqueeze.py
new file mode 100644
index 00000000000000..361710766ae8ec
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_unsqueeze.py
@@ -0,0 +1,44 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestUnsqueeze(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(5, 10).astype(np.float32),)
+
+    def create_model(self, inplace=False, dim=0):
+
+        import torch
+        import torch.nn.functional as F
+
+        class aten_unsqueeze(torch.nn.Module):
+            def __init__(self, dim):
+                super(aten_unsqueeze, self).__init__()
+                self.op = torch.unsqueeze
+                self.dim = dim
+
+            def forward(self, x):
+                return x, self.op(x, self.dim)
+
+        class aten_unsqueeze_(torch.nn.Module):
+            def __init__(self, dim):
+                super(aten_unsqueeze_, self).__init__()
+                self.dim = dim
+
+            def forward(self, x):
+                return x, x.unsqueeze_(self.dim)
+
+        ref_net = None
+        model_class, op = (aten_unsqueeze, "aten::unsqueeze") if not inplace else (aten_unsqueeze_, "aten::unsqueeze_")
+
+        return model_class(dim), ref_net, op
+
+    @pytest.mark.parametrize("inplace", [False, True])
+    @pytest.mark.parametrize("dim", [0, 1, -1])
+    @pytest.mark.nightly
+    def test_unsqueeze(self, inplace, dim, ie_device, precision, ir_version):
+        self._test(*self.create_model(inplace, dim), ie_device, precision, ir_version)
\ No newline at end of file
diff --git a/tests/layer_tests/pytorch_tests/test_upsample.py b/tests/layer_tests/pytorch_tests/test_upsample.py
index 8cced0213a2b05..0bf806208ca451 100644
--- a/tests/layer_tests/pytorch_tests/test_upsample.py
+++ b/tests/layer_tests/pytorch_tests/test_upsample.py
@@ -43,8 +43,14 @@ def forward(self, x):
                               ('bilinear', (128, 480), None),
                               ('bilinear', None, 2.5,),
                               ('bilinear', None, 0.75),
-                              ('bilinear', None, (1.2, 0.8))]
+                              ('bilinear', None, (1.2, 0.8)),
+                              ('bicubic', 300, None),
+                              ('bicubic', 200, None),
+                              ('bicubic', (128, 480), None),
+                              ('bicubic', None, 2.5,),
+                              ('bicubic', None, 0.75),
+                              ('bicubic', None, (1.2, 0.8))]
                              )
     @pytest.mark.nightly
-    def test_upsample_nearest2d(self, mode, size, scale, ie_device, precision, ir_version):
+    def test_upsample(self, mode, size, scale, ie_device, precision, ir_version):
         self._test(*self.create_model(size, scale, mode), ie_device, precision, ir_version, trace_model=True)
\ No newline at end of file