From 6594e1c34b999b050cd98fbbb8bc4402f763b71e Mon Sep 17 00:00:00 2001
From: eaidova
Date: Mon, 26 Dec 2022 18:36:53 +0400
Subject: [PATCH 1/5] aten::arange

---
 src/frontends/pytorch/src/op/arange.cpp       | 96 +++++++++++++++++++
 src/frontends/pytorch/src/op_table.cpp        |  2 +
 .../pytorch_tests/pytorch_layer_test_class.py |  2 +-
 .../layer_tests/pytorch_tests/test_arange.py  | 80 ++++++++++++++++
 4 files changed, 179 insertions(+), 1 deletion(-)
 create mode 100644 src/frontends/pytorch/src/op/arange.cpp
 create mode 100644 tests/layer_tests/pytorch_tests/test_arange.py

diff --git a/src/frontends/pytorch/src/op/arange.cpp b/src/frontends/pytorch/src/op/arange.cpp
new file mode 100644
index 00000000000000..5203a49fc96264
--- /dev/null
+++ b/src/frontends/pytorch/src/op/arange.cpp
@@ -0,0 +1,96 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/opsets/opset8.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_arange(NodeContext& context) {
+    auto zero = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {0}));
+    auto one = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {1}));
+    auto dtype = element::f32;
+    bool dtype_applied = false;
+    int num_inputs = context.get_input_size();
+    // aten::arange(Scalar end, Tensor out)
+    if (num_inputs == 2) {
+        auto end = context.get_input(0);
+        auto range = context.mark_node(std::make_shared<opset8::Range>(zero, end, one, dtype));
+        return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
+    }
+    // aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
+    if (num_inputs == 4) {
+        auto start = context.get_input(0);
+        auto end = context.get_input(1);
+        auto step = context.get_input(2);
+        auto range = context.mark_node(std::make_shared<opset8::Range>(start, end, step, dtype));
+        return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
+    }
+    // aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
+    if (num_inputs == 5) {
+        auto end = context.get_input(0);
+        if (!context.input_is_none(1)) {
+            auto pt_type = context.const_input<int64_t>(1);
+            FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::arange: ", pt_type);
+            dtype = TORCH_TO_OV_TYPE.at(pt_type);
+            end = context.mark_node(std::make_shared<opset8::Convert>(end, dtype));
+            zero = context.mark_node(std::make_shared<opset8::Convert>(zero, dtype));
+            one = context.mark_node(std::make_shared<opset8::Convert>(one, dtype));
+            dtype_applied = true;
+        }
+        auto range = context.mark_node(std::make_shared<opset8::Range>(zero, end, one, dtype));
+        if (!dtype_applied) {
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
+        }
+        return {range};
+    }
+    // aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
+    if (num_inputs == 6) {
+        auto start = context.get_input(0);
+        auto end = context.get_input(1);
+        if (!context.input_is_none(2)) {
+            auto pt_type = context.const_input<int64_t>(2);
+            FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::arange: ", pt_type);
+            dtype = TORCH_TO_OV_TYPE.at(pt_type);
+            dtype_applied = true;
+            end = context.mark_node(std::make_shared<opset8::Convert>(end, dtype));
+            start = context.mark_node(std::make_shared<opset8::Convert>(start, dtype));
+            one = context.mark_node(std::make_shared<opset8::Convert>(one, dtype));
+        }
+        auto range = context.mark_node(std::make_shared<opset8::Range>(start, end, one, dtype));
+        if (!dtype_applied) {
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
+        }
+        return {range};
+    }
+    // aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
+    if (num_inputs == 7) {
+        auto start = context.get_input(0);
+        auto end = context.get_input(1);
+        auto step = context.get_input(2);
+        if (!context.input_is_none(3)) {
+            auto pt_type = context.const_input<int64_t>(3);
+            FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::arange: ", pt_type);
+            dtype = TORCH_TO_OV_TYPE.at(pt_type);
+            end = context.mark_node(std::make_shared<opset8::Convert>(end, dtype));
+            start = context.mark_node(std::make_shared<opset8::Convert>(start, dtype));
+            step = context.mark_node(std::make_shared<opset8::Convert>(step, dtype));
+            dtype_applied = true;
+        }
+        auto range = context.mark_node(std::make_shared<opset8::Range>(start, end, step, dtype));
+        if (!dtype_applied) {
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
+        }
+        return {range};
+    }
+};
+
+} // namespace op
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
\ No newline at end of file
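For reference, the input counts the converter dispatches on correspond to the aten::arange overloads as they surface in TorchScript. A minimal Python sketch of the calls behind each branch (the count-to-call mapping is inferred from the schema comments above):

    import torch

    torch.arange(5)                             # 2 inputs:  end, out
    torch.arange(0, 5, 1)                       # 4 inputs:  start, end, step, out
    torch.arange(5, dtype=torch.float64)        # 5 inputs:  end, dtype, layout, device, pin_memory
    torch.arange(0, 5, dtype=torch.float64)     # 6 inputs:  start, end, dtype, layout, device, pin_memory
    torch.arange(0, 5, 1, dtype=torch.float64)  # 7 inputs:  start, end, step, dtype, layout, device, pin_memory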
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 087457b4330d51..9974c716363352 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -19,6 +19,7 @@ OP_CONVERTER(translate_adaptive_max_pool2d);
 OP_CONVERTER(translate_add);
 OP_CONVERTER(translate_addcmul);
 OP_CONVERTER(translate_addmm);
+OP_CONVERTER(translate_arange);
 OP_CONVERTER(translate_as_tensor);
 OP_CONVERTER(translate_avg_pool2d);
 OP_CONVERTER(translate_batch_norm);
@@ -98,6 +99,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::add_", op::inplace_op<op::translate_add>},
         {"aten::addcmul", op::translate_addcmul},
         {"aten::addmm", op::translate_addmm},
+        {"aten::arange", op::translate_arange},
         {"aten::as_tensor", op::translate_as_tensor},
         {"aten::avg_pool2d", op::translate_avg_pool2d},
         {"aten::batch_norm", op::translate_batch_norm},
diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
index 96e3b8effe48cd..1d5e3b97fad499 100644
--- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
+++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -102,7 +102,7 @@ def _test(self, model, ref_net, kind, ie_device, precision, ir_version, infer_ti
                 else:
                     assert type(fw_tensor) == type(ov_tensor)
                 continue
-            assert torch.tensor(np.array(ov_tensor)).dtype == fw_tensor.dtype
+            assert torch.tensor(np.array(ov_tensor)).dtype == fw_tensor.dtype, f"dtype validation failed: {torch.tensor(np.array(ov_tensor)).dtype} != {fw_tensor.dtype}"
 
         if 'custom_eps' in kwargs and kwargs['custom_eps'] is not None:
             custom_eps = kwargs['custom_eps']
diff --git a/tests/layer_tests/pytorch_tests/test_arange.py b/tests/layer_tests/pytorch_tests/test_arange.py
new file mode 100644
index 00000000000000..64e18d924e7052
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_arange.py
@@ -0,0 +1,80 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestArange(PytorchLayerTest):
+    def _prepare_input(self, end, start=None, step=None, dtype="int64"):
+        import numpy as np
+        if start is None and step is None:
+            return (np.array(end).astype(dtype), )
+        if step is None:
+            return (np.array(start).astype(dtype), np.array(end).astype(dtype))
+        return (np.array(start).astype(dtype), np.array(end).astype(dtype), np.array(step).astype(dtype))
+
+    def create_model(self, dtype, num_inputs):
+        import torch
+
+        dtype_map = {
+            "float32": torch.float32,
+            "float64": torch.float64,
+            "int64": torch.int64,
+            "int32": torch.int32,
+            "uint8": torch.uint8,
+            "int8": torch.int8
+        }
+
+        class aten_arange_end(torch.nn.Module):
+            def __init__(self, dtype) -> None:
+                super(aten_arange_end, self).__init__()
+                self.dtype = dtype
+
+            def forward(self, x: int):
+                return torch.arange(x, dtype=self.dtype)
+
+        class aten_arange_start_end(torch.nn.Module):
+            def __init__(self, dtype) -> None:
+                super(aten_arange_start_end, self).__init__()
+                self.dtype = dtype
+
+            def forward(self, x: float, y: float):
+                return torch.arange(start=x, end=y, dtype=self.dtype)
+
+        class aten_arange_start_end_step(torch.nn.Module):
+            def __init__(self, dtype) -> None:
+                super(aten_arange_start_end_step, self).__init__()
+                self.dtype = dtype
+
+            def forward(self, x: float, y: float, z: float):
+                return torch.arange(start=x, end=y, step=z, dtype=self.dtype)
+
+        model_classes = {
+            1: aten_arange_end,
+            2: aten_arange_start_end,
+            3: aten_arange_start_end_step
+        }
+        dtype = dtype_map.get(dtype)
+        model_class = model_classes[num_inputs]
+
+        ref_net = None
+
+        return model_class(dtype), ref_net, "aten::arange"
+
+    @pytest.mark.nightly
+    @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8", "uint8"])
+    @pytest.mark.parametrize("end", [1, 2, 3])
+    def test_arange_end_only(self, dtype, end, ie_device, precision, ir_version):
+        self._test(*self.create_model(dtype, 1), ie_device, precision, ir_version, kwargs_to_prepare_input={"end": end})
+
+    @pytest.mark.nightly
+    @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"])
+    @pytest.mark.parametrize("start,end", [(0, 1), (-1, 1), (1, 5), (0.5, 2.5)])
+    def test_arange_start_end(self, dtype, end, start, ie_device, precision, ir_version):
+        self._test(*self.create_model(dtype, 2), ie_device, precision, ir_version, kwargs_to_prepare_input={"end": end, "start": start, "dtype": "float32"})
+
+    @pytest.mark.nightly
+    @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"])
+    @pytest.mark.parametrize("start,end,step", [(0, 1, 1), (-2, 1, 1.25), (1, -5, -1), (1, 10, 2), (-1, -5, -2)])
+    def test_arange_start_end_step(self, dtype, end, start, step, ie_device, precision, ir_version):
+        self._test(*self.create_model(dtype, 3), ie_device, precision, ir_version, kwargs_to_prepare_input={"end": end, "start": start, "step": step, "dtype": "float32"})
\ No newline at end of file

From 2fc8873b1a7a06c5316e20e60f39957acec058c6 Mon Sep 17 00:00:00 2001
From: eaidova
Date: Thu, 29 Dec 2022 11:57:50 +0400
Subject: [PATCH 2/5] extend tests and constant filling ops

---
 src/frontends/pytorch/src/op/arange.cpp        |   8 ++
 src/frontends/pytorch/src/op/full.cpp          | 102 +++++++++++++++++-
 .../layer_tests/pytorch_tests/test_arange.py   |  54 +++++++---
 3 files changed, 147 insertions(+), 17 deletions(-)

diff --git a/src/frontends/pytorch/src/op/arange.cpp b/src/frontends/pytorch/src/op/arange.cpp
index 5203a49fc96264..5e6b1dd3b023d9 100644
--- a/src/frontends/pytorch/src/op/arange.cpp
+++ b/src/frontends/pytorch/src/op/arange.cpp
@@ -21,6 +21,10 @@ OutputVector translate_arange(NodeContext& context) {
     if (num_inputs == 2) {
         auto end = context.get_input(0);
         auto range = context.mark_node(std::make_shared<opset8::Range>(zero, end, one, dtype));
+        if (!context.input_is_none(1)) {
+            auto out_tensor = context.get_input(1);
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, out_tensor))};
+        }
         return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
     }
     // aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
@@ -29,6 +33,10 @@ OutputVector translate_arange(NodeContext& context) {
         auto end = context.get_input(1);
         auto step = context.get_input(2);
         auto range = context.mark_node(std::make_shared<opset8::Range>(start, end, step, dtype));
+        if (!context.input_is_none(3)) {
+            auto out_tensor = context.get_input(3);
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, out_tensor))};
+        }
         return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
     }
     // aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
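The new out-tensor branches mirror PyTorch's own behaviour, where a result written into out adopts the out tensor's dtype rather than arange's default; hence the ConvertLike to the out tensor instead of to end. A sketch of the behaviour being matched:

    import torch

    out = torch.zeros(1, dtype=torch.float64)
    torch.arange(5, out=out)  # out is resized and filled
    print(out.dtype)          # torch.float64, not the default torch.int64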
diff --git a/src/frontends/pytorch/src/op/full.cpp b/src/frontends/pytorch/src/op/full.cpp
index 995656bda11c40..fb7c36022a9755 100644
--- a/src/frontends/pytorch/src/op/full.cpp
+++ b/src/frontends/pytorch/src/op/full.cpp
@@ -14,7 +14,24 @@ namespace op {
 OutputVector translate_full(NodeContext& context) {
     auto sizes = context.get_input(0);
     auto value = context.get_input(1);
-    return {context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes))};
+    int num_inputs = context.get_input_size();
+
+    auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
+    if (num_inputs < 5) {
+        size_t out_id = num_inputs == 3 ? 2 : 3;
+        if (!context.input_is_none(out_id)) {
+            auto out = context.get_input(out_id);
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
+        }
+    }
+    size_t dtype_id = num_inputs == 5 ? 2 : 3;
+    if (!context.input_is_none(dtype_id)) {
+        auto pt_type = context.const_input<int64_t>(dtype_id);
+        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::full: ", pt_type);
+        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
+        filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
+    }
+    return {filled_tensor};
 };
 
 OutputVector translate_full_like(NodeContext& context) {
@@ -22,6 +39,14 @@ OutputVector translate_full_like(NodeContext& context) {
     auto value = context.get_input(1);
     auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
+    if (context.get_input_size() == 7 && !context.input_is_none(2)) {
+        auto pt_type = context.const_input<int64_t>(2);
+        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::full_like: ", pt_type);
+        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
+        filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
+    } else {
+        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input));
+    }
     return {filled_tensor};
 };
 
@@ -30,13 +55,38 @@ OutputVector translate_new_full(NodeContext& context) {
     auto sizes = context.get_input(1);
     auto value = context.get_input(2);
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
+    if (context.get_input_size() == 7 && !context.input_is_none(3)) {
+        auto pt_type = context.const_input<int64_t>(2);
+        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::new_full: ", pt_type);
+        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
+        return {context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype))};
+    }
     return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input))};
 };
 
 OutputVector translate_zeros(NodeContext& context) {
     auto sizes = context.get_input(0);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
-    return {context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes))};
+    auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
+    int num_inputs = context.get_input_size();
+    if (num_inputs < 5) {
+        size_t out_id = num_inputs == 2 ? 1 : 2;
+        if (!context.input_is_none(out_id)) {
+            auto out = context.get_input(out_id);
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
+        }
+        return {filled_tensor};
+    }
+    size_t dtype_id = num_inputs == 5 ? 1 : 2;
+    if (!context.input_is_none(dtype_id)) {
+        auto pt_type = context.const_input<int64_t>(dtype_id);
+        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::zeros: ", pt_type);
+        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
+        filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
+    }
+    return {filled_tensor};
 };
 
 OutputVector translate_zeros_like(NodeContext& context) {
@@ -44,6 +94,15 @@ OutputVector translate_zeros_like(NodeContext& context) {
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
     auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
+    if (context.get_input_size() == 6 && !context.input_is_none(1)) {
+        auto pt_type = context.const_input<int64_t>(1);
+        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::zeros_like: ", pt_type);
+        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
+        filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
+    } else {
+        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input));
+    }
     return {filled_tensor};
 };
 
@@ -52,13 +111,35 @@ OutputVector translate_new_zeros(NodeContext& context) {
     auto sizes = context.get_input(1);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
+    if (context.get_input_size() == 6 && !context.input_is_none(2)) {
+        auto pt_type = context.const_input<int64_t>(2);
+        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::new_zeros: ", pt_type);
+        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
+        return {context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype))};
+    }
     return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input))};
 };
 
 OutputVector translate_ones(NodeContext& context) {
     auto sizes = context.get_input(0);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
-    return {context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes))};
+    auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
+    int num_inputs = context.get_input_size();
+    if (num_inputs < 5) {
+        size_t out_id = num_inputs == 2 ? 1 : 2;
+        if (!context.input_is_none(out_id)) {
+            auto out = context.get_input(out_id);
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
+        }
+    }
+    size_t dtype_id = num_inputs == 5 ? 1 : 2;
+    if (!context.input_is_none(dtype_id)) {
+        auto pt_type = context.const_input<int64_t>(dtype_id);
+        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::ones: ", pt_type);
+        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
+        filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
+    }
+    return {filled_tensor};
 };
 
 OutputVector translate_ones_like(NodeContext& context) {
@@ -66,6 +147,15 @@ OutputVector translate_ones_like(NodeContext& context) {
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
     auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
+    if (context.get_input_size() == 6 && !context.input_is_none(1)) {
+        auto pt_type = context.const_input<int64_t>(1);
+        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::ones_like: ", pt_type);
+        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
+        filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
+    } else {
+        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input));
+    }
     return {filled_tensor};
 };
 
@@ -74,6 +164,12 @@ OutputVector translate_new_ones(NodeContext& context) {
     auto sizes = context.get_input(1);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
+    if (context.get_input_size() == 6 && !context.input_is_none(2)) {
+        auto pt_type = context.const_input<int64_t>(2);
+        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::new_ones: ", pt_type);
+        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
+        return {context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype))};
+    }
     return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input))};
 };
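As with arange, the trailing inputs (out, or dtype/layout/device/pin_memory) decide which branch runs, and the dtype input arrives as an integer ScalarType constant that TORCH_TO_OV_TYPE maps to an OpenVINO element type. A sketch of the PyTorch-level calls behind these branches (the input counts are inferred from the index arithmetic above):

    import torch

    torch.zeros(3)                           # sizes (+ out)       -> out_id branch
    torch.zeros(3, dtype=torch.int64)        # sizes + dtype/...   -> Convert to i64
    torch.full((2, 2), 1.5)                  # default f32 path
    torch.full((2, 2), 1, dtype=torch.int8)  # explicit dtype      -> Convert to i8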
diff --git a/tests/layer_tests/pytorch_tests/test_arange.py b/tests/layer_tests/pytorch_tests/test_arange.py
index 64e18d924e7052..088dcecd484543 100644
--- a/tests/layer_tests/pytorch_tests/test_arange.py
+++ b/tests/layer_tests/pytorch_tests/test_arange.py
@@ -15,7 +15,7 @@ def _prepare_input(self, end, start=None, step=None, dtype="int64"):
             return (np.array(start).astype(dtype), np.array(end).astype(dtype))
         return (np.array(start).astype(dtype), np.array(end).astype(dtype), np.array(step).astype(dtype))
 
-    def create_model(self, dtype, num_inputs):
+    def create_model(self, dtype, num_inputs, use_out=False):
         import torch
 
         dtype_map = {
@@ -26,46 +26,72 @@ def create_model(self, dtype, num_inputs):
             "uint8": torch.uint8,
             "int8": torch.int8
         }
-        class aten_arange_end(torch.nn.Module):
+
+        class aten_arange_end_dtype(torch.nn.Module):
             def __init__(self, dtype) -> None:
-                super(aten_arange_end, self).__init__()
+                super(aten_arange_end_dtype, self).__init__()
                 self.dtype = dtype
 
             def forward(self, x: int):
                 return torch.arange(x, dtype=self.dtype)
 
-        class aten_arange_start_end(torch.nn.Module):
+        class aten_arange_start_end_dtype(torch.nn.Module):
             def __init__(self, dtype) -> None:
-                super(aten_arange_start_end, self).__init__()
+                super(aten_arange_start_end_dtype, self).__init__()
                 self.dtype = dtype
 
             def forward(self, x: float, y: float):
                 return torch.arange(start=x, end=y, dtype=self.dtype)
 
-        class aten_arange_start_end_step(torch.nn.Module):
+        class aten_arange_start_end_step_dtype(torch.nn.Module):
             def __init__(self, dtype) -> None:
-                super(aten_arange_start_end_step, self).__init__()
+                super(aten_arange_start_end_step_dtype, self).__init__()
                 self.dtype = dtype
 
             def forward(self, x: float, y: float, z: float):
                 return torch.arange(start=x, end=y, step=z, dtype=self.dtype)
+
+        class aten_arange_end_out(torch.nn.Module):
+            def __init__(self, dtype) -> None:
+                super(aten_arange_end_out, self).__init__()
+                self.dtype = dtype
+
+            def forward(self, x: int):
+                return torch.arange(x, out=torch.zeros(1, dtype=self.dtype))
+
+        class aten_arange_start_end_out(torch.nn.Module):
+            def __init__(self, dtype) -> None:
+                super(aten_arange_start_end_out, self).__init__()
+                self.dtype = dtype
+
+            def forward(self, x: float, y: float):
+                return torch.arange(start=x, end=y, out=torch.zeros(1, dtype=self.dtype))
+
+        class aten_arange_start_end_step_out(torch.nn.Module):
+            def __init__(self, dtype) -> None:
+                super(aten_arange_start_end_step_out, self).__init__()
+                self.dtype = dtype
+
+            def forward(self, x: float, y: float, z: float):
+                return torch.arange(start=x, end=y, step=z, out=torch.zeros(1, dtype=self.dtype))
+
         model_classes = {
-            1: aten_arange_end,
-            2: aten_arange_start_end,
-            3: aten_arange_start_end_step
+            1: (aten_arange_end_dtype, aten_arange_end_out),
+            2: (aten_arange_start_end_dtype, aten_arange_start_end_out),
+            3: (aten_arange_start_end_step_dtype, aten_arange_start_end_step_out)
         }
         dtype = dtype_map.get(dtype)
-        model_class = model_classes[num_inputs]
+        model_class = model_classes[num_inputs][0](dtype) if not use_out or dtype is None else model_classes[num_inputs][1](dtype)
 
         ref_net = None
 
-        return model_class(dtype), ref_net, "aten::arange"
+        return model_class, ref_net, "aten::arange"
 
     @pytest.mark.nightly
     @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8", "uint8"])
     @pytest.mark.parametrize("end", [1, 2, 3])
-    def test_arange_end_only(self, dtype, end, ie_device, precision, ir_version):
-        self._test(*self.create_model(dtype, 1), ie_device, precision, ir_version, kwargs_to_prepare_input={"end": end})
+    @pytest.mark.parametrize("use_out", [True, False])
+    def test_arange_end_only(self, dtype, end, use_out, ie_device, precision, ir_version):
+        self._test(*self.create_model(dtype, 1, use_out), ie_device, precision, ir_version, kwargs_to_prepare_input={"end": end})
 
     @pytest.mark.nightly
     @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"])

From 2a5e9c787e73b6870cdbade1935cd91634a9c753 Mon Sep 17 00:00:00 2001
From: eaidova
Date: Thu, 29 Dec 2022 17:01:52 +0400
Subject: [PATCH 3/5] fix as_tensor and full ops

---
 src/frontends/pytorch/src/op/as_tensor.cpp   |  25 +-
 src/frontends/pytorch/src/op/full.cpp        |  15 +-
 src/frontends/pytorch/src/op_table.cpp       |   1 +
 tests/layer_tests/pytorch_tests/test_full.py | 351 ++++++++++++++++++-
 4 files changed, 362 insertions(+), 30 deletions(-)

diff --git a/src/frontends/pytorch/src/op/as_tensor.cpp b/src/frontends/pytorch/src/op/as_tensor.cpp
index fe09362f4d2ae3..01ce5ef9decf1e 100644
--- a/src/frontends/pytorch/src/op/as_tensor.cpp
+++ b/src/frontends/pytorch/src/op/as_tensor.cpp
@@ -13,18 +13,23 @@ namespace pytorch {
 namespace op {
 
 OutputVector translate_as_tensor(NodeContext& context) {
-    auto dtype_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
-    auto dtype_fw_node = std::dynamic_pointer_cast<PtFrameworkNode>(dtype_ext_node);
+    auto dtype = element::f32;
     Output<Node> cast;
-    if (dtype_fw_node && dtype_fw_node->get_op_type() == "prim::dtype") {
-        auto type_input = dtype_fw_node->input_value(0);
-        cast = context.mark_node(std::make_shared<opset8::ConvertLike>(context.get_input(0), type_input));
-    } else if (const auto dtype_const = std::dynamic_pointer_cast<opset8::Constant>(dtype_ext_node)) {
-        auto pt_type = dtype_const->cast_vector<int64_t>()[0];
-        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::as_tensor: ", pt_type);
-        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
-        cast = context.mark_node(std::make_shared<opset8::Convert>(context.get_input(0), dtype));
+    if (!context.input_is_none(1)) {
+        auto dtype_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
+        auto dtype_fw_node = std::dynamic_pointer_cast<PtFrameworkNode>(dtype_ext_node);
+        if (dtype_fw_node && dtype_fw_node->get_op_type() == "prim::dtype") {
+            auto type_input = dtype_fw_node->input_value(0);
+            return {context.mark_node(std::make_shared<opset8::ConvertLike>(context.get_input(0), type_input))};
+        }
+        if (auto dtype_const = std::dynamic_pointer_cast<opset8::Constant>(dtype_ext_node)) {
+            auto pt_type = dtype_const->cast_vector<int64_t>()[0];
+            FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::as_tensor: ", pt_type);
+            dtype = TORCH_TO_OV_TYPE.at(pt_type);
+        }
     }
+    cast = context.mark_node(std::make_shared<opset8::Convert>(context.get_input(0), dtype));
+    // Input with index 2 is device, we skip this input
     return {cast};
 };
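The prim::dtype branch above covers the common scripted pattern where the target type is taken from another tensor at runtime, which maps naturally onto ConvertLike; a constant dtype maps onto a plain Convert. A sketch of both forms (assuming standard TorchScript tracing of aten::as_tensor):

    import torch

    x = torch.randn(3)
    y = torch.zeros(3, dtype=torch.int32)
    torch.as_tensor(x, dtype=torch.int64)  # dtype is a Constant   -> Convert
    torch.as_tensor(x, dtype=y.dtype)      # dtype via prim::dtype -> ConvertLike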
diff --git a/src/frontends/pytorch/src/op/full.cpp b/src/frontends/pytorch/src/op/full.cpp
index fb7c36022a9755..e2c997329d967f 100644
--- a/src/frontends/pytorch/src/op/full.cpp
+++ b/src/frontends/pytorch/src/op/full.cpp
@@ -17,14 +17,14 @@ OutputVector translate_full(NodeContext& context) {
     int num_inputs = context.get_input_size();
 
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
-    if (num_inputs < 5) {
+    if (num_inputs < 6) {
         size_t out_id = num_inputs == 3 ? 2 : 3;
         if (!context.input_is_none(out_id)) {
             auto out = context.get_input(out_id);
             return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
         }
     }
-    size_t dtype_id = num_inputs == 5 ? 2 : 3;
+    size_t dtype_id = num_inputs == 6 ? 2 : 3;
     if (!context.input_is_none(dtype_id)) {
         auto pt_type = context.const_input<int64_t>(dtype_id);
         FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::full: ", pt_type);
@@ -45,7 +45,8 @@ OutputVector translate_full_like(NodeContext& context) {
         auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
         filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
     } else {
-        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input));
+        auto out_dtype = context.input_is_none(3) ? input : context.get_input(3);
+        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
     }
     return {filled_tensor};
 };
@@ -56,7 +57,7 @@ OutputVector translate_new_full(NodeContext& context) {
     auto value = context.get_input(2);
     auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
     if (context.get_input_size() == 7 && !context.input_is_none(3)) {
-        auto pt_type = context.const_input<int64_t>(2);
+        auto pt_type = context.const_input<int64_t>(3);
         FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::new_full: ", pt_type);
         auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
         return {context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype))};
@@ -101,7 +102,8 @@ OutputVector translate_zeros_like(NodeContext& context) {
         filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
     } else {
-        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input));
+        auto out_dtype = context.input_is_none(2) ? input : context.get_input(2);
+        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
     }
     return {filled_tensor};
 };
@@ -154,7 +156,8 @@ OutputVector translate_ones_like(NodeContext& context) {
         filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
     } else {
-        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input));
+        auto out_dtype = context.input_is_none(2) ? input : context.get_input(2);
+        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
     }
     return {filled_tensor};
 };
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 9974c716363352..a70c998fd1d855 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -187,6 +187,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::sub", op::translate_sub},
         {"aten::sum", op::translate_sum},
         {"aten::tanh", op::translate_1to1_match_1_inputs<opset8::Tanh>},
+        {"aten::tensor", op::translate_as_tensor},
         {"aten::type_as", op::translate_1to1_match_2_inputs<opset8::ConvertLike>},  // TODO: overflow semantics is different
         {"aten::to", op::translate_to},
diff --git a/tests/layer_tests/pytorch_tests/test_full.py b/tests/layer_tests/pytorch_tests/test_full.py
index 1109e33c1f82a0..0ce73a235c37a6 100644
--- a/tests/layer_tests/pytorch_tests/test_full.py
+++ b/tests/layer_tests/pytorch_tests/test_full.py
@@ -1,6 +1,5 @@
 # Copyright (C) 2018-2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
-from typing import Tuple
 import numpy as np
 import pytest
 from pytorch_layer_test_class import PytorchLayerTest
@@ -10,8 +9,17 @@ class TestFull(PytorchLayerTest):
     def _prepare_input(self, value):
         return (np.array(value, dtype=np.float32), )
 
-    def create_model(self, shape):
+    def create_model(self, shape, dtype=None, use_dtype=False, use_out=False, with_names=False):
         import torch
+        dtype_map = {
+            "float32": torch.float32,
+            "float64": torch.float64,
+            "int64": torch.int64,
+            "int32": torch.int32,
+            "uint8": torch.uint8,
+            "int8": torch.int8,
+            "bool": torch.bool
+        }
 
         class aten_full(torch.nn.Module):
             def __init__(self, shape):
@@ -21,9 +29,53 @@ def __init__(self, shape):
             def forward(self, x: float):
                 return torch.full(self.shape, x)
 
+        class aten_full_dtype(torch.nn.Module):
+            def __init__(self, shape, dtype):
+                super(aten_full_dtype, self).__init__()
+                self.shape = shape
+                self.dtype = dtype
+
+            def forward(self, x: float):
+                return torch.full(self.shape, x, dtype=self.dtype)
+
+        class aten_full_dtype_with_names(torch.nn.Module):
+            def __init__(self, shape, dtype):
+                super(aten_full_dtype_with_names, self).__init__()
+                self.shape = shape
+                self.dtype = dtype
+
+            def forward(self, x: float):
+                return torch.full(self.shape, x, dtype=self.dtype, names=None)
+
+        class aten_full_out(torch.nn.Module):
+            def __init__(self, shape, dtype):
+                super(aten_full_out, self).__init__()
+                self.shape = shape
+                self.dtype = dtype
+
+            def forward(self, x: float):
+                return torch.full(self.shape, x, out=torch.tensor(1, dtype=self.dtype))
+
+        class aten_full_out_with_names(torch.nn.Module):
+            def __init__(self, shape, dtype):
+                super(aten_full_out_with_names, self).__init__()
+                self.shape = shape
+                self.dtype = dtype
+
+            def forward(self, x: float):
+                return torch.full(self.shape, x, out=torch.tensor(1, dtype=self.dtype), names=None)
+
         ref_net = None
+        model = aten_full(shape)
+        if use_dtype or use_out:
+            dtype = dtype_map.get(dtype, dtype)
+            if not use_out:
+                model = aten_full_dtype(shape, dtype) if not with_names else aten_full_dtype_with_names(shape, dtype)
+            else:
+                model = aten_full_out(shape, dtype) if not with_names else aten_full_out_with_names(shape, dtype)
 
-        return aten_full(shape), ref_net, "aten::full"
+        return model, ref_net, "aten::full"
 
     @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
     @pytest.mark.parametrize("value", [0, 1, -1, 0.5])
@@ -32,22 +84,72 @@ def test_full(self, shape, value, ie_device, precision, ir_version):
         self._test(*self.create_model(shape), ie_device, precision, ir_version,
                    kwargs_to_prepare_input={'value': value})
 
+    @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
+    @pytest.mark.parametrize("value", [0, 1, -1, 0.5])
+    @pytest.mark.parametrize("dtype", ["int8", "int32", "int64", "float32", "float64"])
+    @pytest.mark.parametrize("with_names", [True, False])
+    @pytest.mark.nightly
+    def test_full_dtype(self, shape, value, dtype, with_names, ie_device, precision, ir_version):
+        self._test(*self.create_model(shape, dtype=dtype, use_dtype=True, with_names=with_names), ie_device, precision,
+                   ir_version, kwargs_to_prepare_input={'value': value})
+
+    @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
+    @pytest.mark.parametrize("value", [0, 1, -1, 0.5])
+    @pytest.mark.parametrize("dtype", ["int8", "int32", "int64", "float32", "float64"])
+    @pytest.mark.parametrize("with_names", [True, False])
+    @pytest.mark.nightly
+    def test_full_out(self, shape, value, dtype, with_names, ie_device, precision, ir_version):
+        self._test(*self.create_model(shape, dtype=dtype, use_out=True, with_names=with_names), ie_device, precision,
+                   ir_version, kwargs_to_prepare_input={'value': value})
 
 class TestFullLike(PytorchLayerTest):
     def _prepare_input(self, value, shape):
         return (np.random.randn(*shape).astype(np.float32), np.array(value, dtype=np.float32), )
 
-    def create_model(self):
+    def create_model(self, dtype=None, use_dtype=False, use_out=False):
         import torch
+        dtype_map = {
+            "float32": torch.float32,
+            "float64": torch.float64,
+            "int64": torch.int64,
+            "int32": torch.int32,
+            "uint8": torch.uint8,
+            "int8": torch.int8,
+            "bool": torch.bool
+        }
 
         class aten_full_like(torch.nn.Module):
             def forward(self, input_t: torch.Tensor, x: float):
                 return torch.full_like(input_t, x)
 
+        class aten_full_like_dtype(torch.nn.Module):
+            def __init__(self, dtype):
+                super(aten_full_like_dtype, self).__init__()
+                self.dtype = dtype
+
+            def forward(self, input_t: torch.Tensor, x: float):
+                return torch.full_like(input_t, x, dtype=self.dtype)
+
+        class aten_full_like_out(torch.nn.Module):
+            def __init__(self, dtype):
+                super(aten_full_like_out, self).__init__()
+                self.dtype = dtype
+
+            def forward(self, input_t: torch.Tensor, x: float):
+                return torch.full_like(input_t, x, out=torch.tensor(1, dtype=self.dtype))
+
         ref_net = None
 
-        return aten_full_like(), ref_net, "aten::full_like"
+        model = aten_full_like()
+        if use_dtype or use_out:
+            dtype = dtype_map.get(dtype, dtype)
+            if not use_out:
+                model = aten_full_like_dtype(dtype)
+            else:
+                model = aten_full_like_out(dtype)
+
+        return model, ref_net, "aten::full_like"
 
     @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
     @pytest.mark.parametrize("value", [0, 1, -1, 0.5])
     @pytest.mark.nightly
     def test_full_like(self, shape, value, ie_device, precision, ir_version):
         self._test(*self.create_model(), ie_device, precision, ir_version,
                    kwargs_to_prepare_input={'value': value, 'shape': shape})
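With the converter fix above, a full_like without an explicit dtype now follows the input tensor's type (or the out tensor's, when one is given) instead of always staying f32. The PyTorch behaviour being matched, as a sketch:

    import torch

    x = torch.zeros(3, dtype=torch.int32)
    print(torch.full_like(x, 1.5).dtype)                       # torch.int32
    print(torch.full_like(x, 1.5, dtype=torch.float64).dtype)  # torch.float64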
@pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]]) + @pytest.mark.parametrize("value", [0, 1, -1, 0.5]) + @pytest.mark.parametrize("dtype", ["int8", "int32","int64", "float32", "float64"]) + @pytest.mark.nightly + def test_full_like_dtype(self, shape, value, dtype, ie_device, precision, ir_version): + self._test(*self.create_model(dtype, use_dtype=True), ie_device, precision, ir_version, + kwargs_to_prepare_input={'value': value, 'shape': shape}) + + @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]]) + @pytest.mark.parametrize("value", [0, 1, -1, 0.5]) + @pytest.mark.parametrize("dtype", ["int8", "int32","int64", "float32", "float64"]) + @pytest.mark.nightly + def test_full_like_out(self, shape, value, dtype, ie_device, precision, ir_version): + self._test(*self.create_model(dtype, use_out=True), ie_device, precision, ir_version, + kwargs_to_prepare_input={'value': value, 'shape': shape}) + class TestNewFull(PytorchLayerTest): def _prepare_input(self, value, input_dtype=np.float32): return (np.random.randn(1, 3, 10, 10).astype(input_dtype), np.array(value, dtype=np.float32)) - def create_model(self, shape): + def create_model(self, shape, dtype=None, used_dtype=False): import torch + dtype_map = { + "float32": torch.float32, + "float64": torch.float64, + "int64": torch.int64, + "int32": torch.int32, + "uint8": torch.uint8, + "int8": torch.int8, + "bool": torch.bool + } class aten_full(torch.nn.Module): def __init__(self, shape): @@ -72,9 +199,24 @@ def __init__(self, shape): def forward(self, input_tensor: torch.Tensor, x: float): return input_tensor.new_full(self.shape, x) + class aten_full_with_dtype(torch.nn.Module): + def __init__(self, shape, dtype): + super(aten_full_with_dtype, self).__init__() + self.shape = shape + self.dtype = dtype + + def forward(self, input_tensor: torch.Tensor, x: float): + return input_tensor.new_full(size=self.shape, fill_value=x, dtype=self.dtype) + ref_net = None + model = aten_full(shape) + + if used_dtype: + dtype = dtype_map[dtype] + model = aten_full_with_dtype(shape, dtype) - return aten_full(shape), ref_net, "aten::new_full" + + return model, ref_net, "aten::new_full" @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]]) @pytest.mark.parametrize("value,input_dtype", [(0, np.uint8), (1, np.int32), (-1, np.float32), (0.5, np.float64)]) @@ -83,12 +225,20 @@ def test_new_full(self, shape, value, input_dtype, ie_device, precision, ir_vers self._test(*self.create_model(shape), ie_device, precision, ir_version, kwargs_to_prepare_input={'value': value, 'input_dtype': input_dtype}) + @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]]) + @pytest.mark.parametrize("value,input_dtype", [(0, np.uint8), (1, np.int32), (-1, np.float32), (0.5, np.float64)]) + @pytest.mark.parametrize("dtype", ["int8", "int32","int64", "float32", "float64"]) + @pytest.mark.nightly + def test_new_full_with_dtype(self, value, shape, dtype, input_dtype, ie_device, precision, ir_version): + self._test(*self.create_model(shape, dtype=dtype, used_dtype=True), ie_device, precision, ir_version, + kwargs_to_prepare_input={'value': value, 'input_dtype': input_dtype}) + class TestZerosAndOnes(PytorchLayerTest): def _prepare_input(self, shape): return (np.random.randn(*shape).astype(np.float32),) - def create_model(self, op_type): + def create_model(self, op_type, dtype=None, with_dtype=False, with_out=False, with_names=False): import torch 
 class TestZerosAndOnes(PytorchLayerTest):
     def _prepare_input(self, shape):
         return (np.random.randn(*shape).astype(np.float32),)
 
-    def create_model(self, op_type):
+    def create_model(self, op_type, dtype=None, with_dtype=False, with_out=False, with_names=False):
         import torch
 
         ops = {
             "aten::zeros": torch.zeros,
             "aten::ones": torch.ones,
             "aten::zeros_like": torch.zeros_like,
             "aten::ones_like": torch.ones_like
         }
+        dtype_map = {
+            "float32": torch.float32,
+            "float64": torch.float64,
+            "int64": torch.int64,
+            "int32": torch.int32,
+            "uint8": torch.uint8,
+            "int8": torch.int8,
+            "bool": torch.bool
+        }
 
         class aten_op(torch.nn.Module):
             def __init__(self, op):
                 super(aten_op, self).__init__()
                 self.op = op
 
             def forward(self, x):
                 shape = x.shape
                 return self.op(shape)
 
         class aten_op_like(torch.nn.Module):
             def __init__(self, op):
                 super(aten_op_like, self).__init__()
                 self.op = op
 
             def forward(self, x):
                 return self.op(x)
 
-        model_cls = aten_op_like if op_type.endswith('_like') else aten_op
+        class aten_op_dtype(torch.nn.Module):
+            def __init__(self, op, dtype):
+                super(aten_op_dtype, self).__init__()
+                self.op = op
+                self.dtype = dtype
+
+            def forward(self, x):
+                shape = x.shape
+                return self.op(shape, dtype=self.dtype)
+
+        class aten_op_dtype_with_names(aten_op_dtype):
+            def forward(self, x):
+                shape = x.shape
+                return self.op(shape, dtype=self.dtype, names=None)
+
+        class aten_op_like_dtype(torch.nn.Module):
+            def __init__(self, op, dtype):
+                super(aten_op_like_dtype, self).__init__()
+                self.op = op
+                self.dtype = dtype
+
+            def forward(self, x):
+                return self.op(x, dtype=self.dtype)
+
+        class aten_op_out(torch.nn.Module):
+            def __init__(self, op, dtype):
+                super(aten_op_out, self).__init__()
+                self.op = op
+                self.dtype = dtype
+
+            def forward(self, x):
+                shape = x.shape
+                return self.op(shape, out=torch.tensor(0, dtype=self.dtype))
+
+        class aten_op_out_with_names(torch.nn.Module):
+            def __init__(self, op, dtype):
+                super(aten_op_out_with_names, self).__init__()
+                self.op = op
+                self.dtype = dtype
+
+            def forward(self, x):
+                shape = x.shape
+                return self.op(shape, out=torch.tensor(0, dtype=self.dtype), names=None)
+
+        class aten_op_like_out(torch.nn.Module):
+            def __init__(self, op, dtype):
+                super(aten_op_like_out, self).__init__()
+                self.op = op
+                self.dtype = dtype
+
+            def forward(self, x):
+                return self.op(x, out=torch.tensor(0, dtype=self.dtype))
+
+        like = op_type.endswith('_like')
         op = ops[op_type]
+        if not like:
+            model_cls = aten_op(op)
+            if with_dtype or with_out:
+                dtype = dtype_map[dtype]
+                if with_dtype:
+                    model_cls = aten_op_dtype(op, dtype) if not with_names else aten_op_dtype_with_names(op, dtype)
+                if with_out:
+                    model_cls = aten_op_out(op, dtype) if not with_names else aten_op_out_with_names(op, dtype)
+        else:
+            model_cls = aten_op_like(op)
+            if with_dtype or with_out:
+                dtype = dtype_map[dtype]
+                model_cls = aten_op_like_dtype(op, dtype) if not with_out else aten_op_like_out(op, dtype)
 
         ref_net = None
 
-        return model_cls(op), ref_net, op_type
+        return model_cls, ref_net, op_type
 
     @pytest.mark.parametrize("shape", [(1, 1), (1, 2), (1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5, 6)])
     @pytest.mark.parametrize("op_type", ["aten::zeros", "aten::ones", "aten::zeros_like", "aten::ones_like"])
     @pytest.mark.nightly
     def test_fill(self, op_type, shape, ie_device, precision, ir_version):
         self._test(*self.create_model(op_type), ie_device, precision, ir_version,
                    kwargs_to_prepare_input={'shape': shape})
 
+    @pytest.mark.parametrize("shape", [(1, 1), (1, 2), (1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5, 6)])
+    @pytest.mark.parametrize("op_type", ["aten::zeros", "aten::ones"])
+    @pytest.mark.parametrize("dtype", ["int8", "int32", "int64", "float32", "float64"])
+    @pytest.mark.parametrize("with_names", [True, False])
+    @pytest.mark.nightly
+    def test_fill_with_dtype(self, op_type, shape, dtype, with_names, ie_device, precision, ir_version):
+        self._test(*self.create_model(op_type, dtype=dtype, with_dtype=True, with_names=with_names), ie_device, precision,
+                   ir_version, kwargs_to_prepare_input={'shape': shape})
+
+    @pytest.mark.parametrize("shape", [(1, 1), (1, 2), (1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5, 6)])
+    @pytest.mark.parametrize("op_type", ["aten::zeros", "aten::ones"])
+    @pytest.mark.parametrize("dtype", ["int8", "int32", "int64", "float32", "float64"])
+    @pytest.mark.parametrize("with_names", [True, False])
+    @pytest.mark.nightly
+    def test_fill_with_out(self, op_type, shape, dtype, with_names, ie_device, precision, ir_version):
+        self._test(*self.create_model(op_type, dtype=dtype, with_out=True, with_names=with_names), ie_device, precision,
+                   ir_version, kwargs_to_prepare_input={'shape': shape})
+
+    @pytest.mark.parametrize("shape", [(1, 1), (1, 2), (1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5, 6)])
+    @pytest.mark.parametrize("op_type", ["aten::zeros_like", "aten::ones_like"])
+    @pytest.mark.parametrize("dtype", ["int8", "int32", "int64", "float32", "float64"])
+    @pytest.mark.nightly
+    def test_fill_like_with_dtype(self, op_type, shape, dtype, ie_device, precision, ir_version):
+        self._test(*self.create_model(op_type, dtype=dtype, with_dtype=True), ie_device, precision,
+                   ir_version, kwargs_to_prepare_input={'shape': shape})
+
+    @pytest.mark.parametrize("shape", [(1, 1), (1, 2), (1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5, 6)])
+    @pytest.mark.parametrize("op_type", ["aten::zeros_like", "aten::ones_like"])
+    @pytest.mark.parametrize("dtype", ["int8", "int32", "int64", "float32", "float64"])
+    @pytest.mark.nightly
+    def test_fill_like_with_out(self, op_type, shape, dtype, ie_device, precision, ir_version):
+        self._test(*self.create_model(op_type, dtype=dtype, with_out=True), ie_device, precision,
+                   ir_version, kwargs_to_prepare_input={'shape': shape})
 
 class TestNewZeros(PytorchLayerTest):
     def _prepare_input(self, input_dtype=np.float32):
         return (np.random.randn(1, 3, 10, 10).astype(input_dtype), )
 
-    def create_model(self, shape):
+    def create_model(self, shape, dtype=None, used_dtype=False):
         import torch
+        dtype_map = {
+            "float32": torch.float32,
+            "float64": torch.float64,
+            "int64": torch.int64,
+            "int32": torch.int32,
+            "uint8": torch.uint8,
+            "int8": torch.int8,
+            "bool": torch.bool
+        }
 
         class aten_full(torch.nn.Module):
             def __init__(self, shape):
@@ -144,9 +412,24 @@ def __init__(self, shape):
             def forward(self, input_tensor: torch.Tensor):
                 return input_tensor.new_zeros(self.shape)
 
+        class aten_full_with_dtype(torch.nn.Module):
+            def __init__(self, shape, dtype):
+                super(aten_full_with_dtype, self).__init__()
+                self.shape = shape
+                self.dtype = dtype
+
+            def forward(self, input_tensor: torch.Tensor):
+                return input_tensor.new_zeros(self.shape, dtype=self.dtype)
+
         ref_net = None
+        model = aten_full(shape)
+
+        if used_dtype:
+            dtype = dtype_map[dtype]
+            model = aten_full_with_dtype(shape, dtype)
 
-        return aten_full(shape), ref_net, "aten::new_zeros"
+        return model, ref_net, "aten::new_zeros"
 
     @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
     @pytest.mark.parametrize("input_dtype", [np.uint8, np.int8, np.int32, np.int64, np.float32, np.float64])
     @pytest.mark.nightly
     def test_new_zeros(self, shape, input_dtype, ie_device, precision, ir_version):
         self._test(*self.create_model(shape), ie_device, precision, ir_version,
                    kwargs_to_prepare_input={'input_dtype': input_dtype})
 
+    @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
+    @pytest.mark.parametrize("input_dtype", [bool, np.uint8, np.int8, np.int32, np.int64, np.float32, np.float64])
+    @pytest.mark.parametrize("dtype", ["bool", "uint8", "int8", "int32", "int64", "float32", "float64"])
+    @pytest.mark.nightly
+    def test_new_zeros_with_dtype(self, shape, dtype, input_dtype, ie_device, precision, ir_version):
+        self._test(*self.create_model(shape, dtype=dtype, used_dtype=True), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={'input_dtype': input_dtype})
 
 class TestNewOnes(PytorchLayerTest):
     def _prepare_input(self, input_dtype=np.float32):
         return (np.random.randn(1, 3, 10, 10).astype(input_dtype), )
 
-    def create_model(self, shape):
+    def create_model(self, shape, dtype=None, used_dtype=False):
         import torch
+        dtype_map = {
+            "float32": torch.float32,
+            "float64": torch.float64,
+            "int64": torch.int64,
+            "int32": torch.int32,
+            "uint8": torch.uint8,
+            "int8": torch.int8,
+            "bool": torch.bool
+        }
 
         class aten_full(torch.nn.Module):
             def __init__(self, shape):
@@ -171,9 +471,24 @@ def __init__(self, shape):
             def forward(self, input_tensor: torch.Tensor):
                 return input_tensor.new_ones(self.shape)
 
+        class aten_full_with_dtype(torch.nn.Module):
+            def __init__(self, shape, dtype):
+                super(aten_full_with_dtype, self).__init__()
+                self.shape = shape
+                self.dtype = dtype
+
+            def forward(self, input_tensor: torch.Tensor):
+                return input_tensor.new_ones(self.shape, dtype=self.dtype)
+
         ref_net = None
+        model = aten_full(shape)
+
+        if used_dtype:
+            dtype = dtype_map[dtype]
+            model = aten_full_with_dtype(shape, dtype)
 
-        return aten_full(shape), ref_net, "aten::new_ones"
+        return model, ref_net, "aten::new_ones"
 
     @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
     @pytest.mark.parametrize("input_dtype", [np.uint8, np.int8, np.int32, np.int64, np.float32, np.float64])
     @pytest.mark.nightly
     def test_new_ones(self, shape, input_dtype, ie_device, precision, ir_version):
         self._test(*self.create_model(shape), ie_device, precision, ir_version,
                    kwargs_to_prepare_input={'input_dtype': input_dtype})
+
+    @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
+    @pytest.mark.parametrize("input_dtype", [bool, np.uint8, np.int8, np.int32, np.int64, np.float32, np.float64])
+    @pytest.mark.parametrize("dtype", ["bool", "uint8", "int8", "int32", "int64", "float32", "float64"])
+    @pytest.mark.nightly
+    def test_new_ones_with_dtype(self, shape, dtype, input_dtype, ie_device, precision, ir_version):
+        self._test(*self.create_model(shape, dtype=dtype, used_dtype=True), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={'input_dtype': input_dtype})
From 900e7002a8929c23ff45ee4ff8f40870f77bf05f Mon Sep 17 00:00:00 2001
From: eaidova
Date: Mon, 2 Jan 2023 18:45:00 +0400
Subject: [PATCH 4/5] reduce code duplication

---
 src/frontends/pytorch/src/op/arange.cpp | 94 ++++++++++---------------
 1 file changed, 38 insertions(+), 56 deletions(-)

diff --git a/src/frontends/pytorch/src/op/arange.cpp b/src/frontends/pytorch/src/op/arange.cpp
index 5203a49fc96264..5e6b1dd3b023d9 100644
--- a/src/frontends/pytorch/src/op/arange.cpp
+++ b/src/frontends/pytorch/src/op/arange.cpp
@@ -10,92 +10,74 @@ namespace ov {
 namespace frontend {
 namespace pytorch {
 namespace op {
-
+namespace {
+ov::element::Type get_output_dtype(NodeContext& context, size_t input_id) {
+    auto pt_type = context.const_input<int64_t>(input_id);
+    FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::arange: ", pt_type);
+    return TORCH_TO_OV_TYPE.at(pt_type);
+}
+}  // namespace
 OutputVector translate_arange(NodeContext& context) {
     auto zero = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {0}));
     auto one = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {1}));
     auto dtype = element::f32;
     bool dtype_applied = false;
     int num_inputs = context.get_input_size();
+    ov::Output<Node> end;
+    ov::Output<Node> out_tensor;
+    ov::Output<Node> start = zero;
+    ov::Output<Node> step = one;
+
     // aten::arange(Scalar end, Tensor out)
     if (num_inputs == 2) {
-        auto end = context.get_input(0);
-        auto range = context.mark_node(std::make_shared<opset8::Range>(zero, end, one, dtype));
-        if (!context.input_is_none(1)) {
-            auto out_tensor = context.get_input(1);
-            return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, out_tensor))};
-        }
-        return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
+        end = context.get_input(0);
+        out_tensor = context.input_is_none(1) ? end : context.get_input(1);
     }
     // aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
     if (num_inputs == 4) {
-        auto start = context.get_input(0);
-        auto end = context.get_input(1);
-        auto step = context.get_input(2);
-        auto range = context.mark_node(std::make_shared<opset8::Range>(start, end, step, dtype));
-        if (!context.input_is_none(3)) {
-            auto out_tensor = context.get_input(3);
-            return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, out_tensor))};
-        }
-        return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
+        start = context.get_input(0);
+        end = context.get_input(1);
+        step = context.get_input(2);
+        out_tensor = context.input_is_none(3) ? end : context.get_input(3);
     }
     // aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
     if (num_inputs == 5) {
-        auto end = context.get_input(0);
+        end = context.get_input(0);
+        out_tensor = end;
         if (!context.input_is_none(1)) {
-            auto pt_type = context.const_input<int64_t>(1);
-            FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::arange: ", pt_type);
-            dtype = TORCH_TO_OV_TYPE.at(pt_type);
-            end = context.mark_node(std::make_shared<opset8::Convert>(end, dtype));
-            zero = context.mark_node(std::make_shared<opset8::Convert>(zero, dtype));
-            one = context.mark_node(std::make_shared<opset8::Convert>(one, dtype));
+            dtype = get_output_dtype(context, 1);
             dtype_applied = true;
         }
-        auto range = context.mark_node(std::make_shared<opset8::Range>(zero, end, one, dtype));
-        if (!dtype_applied) {
-            return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
-        }
-        return {range};
     }
     // aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
     if (num_inputs == 6) {
-        auto start = context.get_input(0);
-        auto end = context.get_input(1);
+        start = context.get_input(0);
+        end = context.get_input(1);
+        out_tensor = end;
         if (!context.input_is_none(2)) {
-            auto pt_type = context.const_input<int64_t>(2);
-            FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::arange: ", pt_type);
-            dtype = TORCH_TO_OV_TYPE.at(pt_type);
+            dtype = get_output_dtype(context, 2);
             dtype_applied = true;
-            end = context.mark_node(std::make_shared<opset8::Convert>(end, dtype));
-            start = context.mark_node(std::make_shared<opset8::Convert>(start, dtype));
-            one = context.mark_node(std::make_shared<opset8::Convert>(one, dtype));
-        }
-        auto range = context.mark_node(std::make_shared<opset8::Range>(start, end, one, dtype));
-        if (!dtype_applied) {
-            return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
         }
-        return {range};
     }
     // aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
     if (num_inputs == 7) {
-        auto start = context.get_input(0);
-        auto end = context.get_input(1);
-        auto step = context.get_input(2);
+        start = context.get_input(0);
+        end = context.get_input(1);
+        step = context.get_input(2);
+        out_tensor = end;
         if (!context.input_is_none(3)) {
-            auto pt_type = context.const_input<int64_t>(3);
-            FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::arange: ", pt_type);
-            dtype = TORCH_TO_OV_TYPE.at(pt_type);
-            end = context.mark_node(std::make_shared<opset8::Convert>(end, dtype));
-            start = context.mark_node(std::make_shared<opset8::Convert>(start, dtype));
-            step = context.mark_node(std::make_shared<opset8::Convert>(step, dtype));
+            dtype = get_output_dtype(context, 3);
             dtype_applied = true;
         }
-        auto range = context.mark_node(std::make_shared<opset8::Range>(start, end, step, dtype));
-        if (!dtype_applied) {
-            return {context.mark_node(std::make_shared<opset8::ConvertLike>(range, end))};
-        }
-        return {range};
     }
+    auto r_end = context.mark_node(std::make_shared<opset8::Convert>(end, dtype));
+    auto r_start = context.mark_node(std::make_shared<opset8::Convert>(start, dtype));
+    auto r_step = context.mark_node(std::make_shared<opset8::Convert>(step, dtype));
+    auto range = context.mark_node(std::make_shared<opset8::Range>(r_start, r_end, r_step, dtype));
+    if (!dtype_applied) {
+        range = context.mark_node(std::make_shared<opset8::ConvertLike>(range, out_tensor));
+    }
+    return {range};
 };
 
 } // namespace op

From 438f077f98068a32e309470b3007a55f99346256 Mon Sep 17 00:00:00 2001
From: eaidova
Date: Mon, 2 Jan 2023 19:26:19 +0400
Subject: [PATCH 5/5] refactor

---
 src/frontends/pytorch/src/op/arange.cpp |  13 +--
 src/frontends/pytorch/src/op/full.cpp   | 142 ++++++++++--------------
 src/frontends/pytorch/src/utils.cpp     |   6 +
 src/frontends/pytorch/src/utils.hpp     |   2 +
 4 files changed, 68 insertions(+), 95 deletions(-)

diff --git a/src/frontends/pytorch/src/op/arange.cpp b/src/frontends/pytorch/src/op/arange.cpp
index 5ea5b448112699..3fafb2f71da0b3 100644
--- a/src/frontends/pytorch/src/op/arange.cpp
+++ b/src/frontends/pytorch/src/op/arange.cpp
@@ -10,13 +10,6 @@ namespace ov {
 namespace frontend {
 namespace pytorch {
 namespace op {
-namespace {
-ov::element::Type get_output_dtype(NodeContext& context, size_t input_id) {
-    auto pt_type = context.const_input<int64_t>(input_id);
-    FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::arange: ", pt_type);
-    return TORCH_TO_OV_TYPE.at(pt_type);
-}
-}  // namespace
 OutputVector translate_arange(NodeContext& context) {
     auto zero = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {0}));
     auto one = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {1}));
@@ -45,7 +38,7 @@ OutputVector translate_arange(NodeContext& context) {
         end = context.get_input(0);
         out_tensor = end;
         if (!context.input_is_none(1)) {
-            dtype = get_output_dtype(context, 1);
+            dtype = convert_dtype(context, 1);
             dtype_applied = true;
         }
     }
@@ -55,7 +48,7 @@ OutputVector translate_arange(NodeContext& context) {
         end = context.get_input(1);
         out_tensor = end;
         if (!context.input_is_none(2)) {
-            dtype = get_output_dtype(context, 2);
+            dtype = convert_dtype(context, 2);
             dtype_applied = true;
         }
     }
@@ -66,7 +59,7 @@ OutputVector translate_arange(NodeContext& context) {
         step = context.get_input(2);
         out_tensor = end;
         if (!context.input_is_none(3)) {
-            dtype = get_output_dtype(context, 3);
+            dtype = convert_dtype(context, 3);
             dtype_applied = true;
         }
     }
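convert_dtype (moved into utils by this patch, per the diffstat) reads the ScalarType input as an integer and maps it through TORCH_TO_OV_TYPE. For orientation, PyTorch encodes scalar types as small integers in its c10::ScalarType enum; a Python sketch of the mapping (the exact codes are listed here as assumptions, not taken from this patch):

    # Assumed c10::ScalarType codes -> OpenVINO element types
    TORCH_TO_OV_TYPE = {
        0: "u8",        # torch.uint8
        1: "i8",        # torch.int8
        3: "i32",       # torch.int32
        4: "i64",       # torch.int64
        6: "f32",       # torch.float32
        7: "f64",       # torch.float64
        11: "boolean",  # torch.bool
    }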
{ + return context.mark_node(std::make_shared(value, sizes)); +} + +ov::Output base_translate_full_with_convert(NodeContext& context, + ov::Output sizes, + ov::Output value, + size_t dtype_id) { + auto filled_tensor = base_translate_full(context, sizes, value); + if (!context.input_is_none(dtype_id)) { + auto dtype = convert_dtype(context, dtype_id); + filled_tensor = context.mark_node(std::make_shared(filled_tensor, dtype)); + } + return filled_tensor; +} + +ov::Output base_translate_full_with_convertlike(NodeContext& context, + ov::Output sizes, + ov::Output value, + ov::Output out) { + auto filled_tensor = base_translate_full(context, sizes, value); + return context.mark_node(std::make_shared(filled_tensor, out)); +} + OutputVector translate_full(NodeContext& context) { auto sizes = context.get_input(0); auto value = context.get_input(1); int num_inputs = context.get_input_size(); - - auto filled_tensor = context.mark_node(std::make_shared(value, sizes)); if (num_inputs < 6) { size_t out_id = num_inputs == 3 ? 2 : 3; if (!context.input_is_none(out_id)) { auto out = context.get_input(out_id); - return {context.mark_node(std::make_shared(filled_tensor, out))}; + return {base_translate_full_with_convertlike(context, sizes, value, out)}; } + return {base_translate_full(context, sizes, value)}; } size_t dtype_id = num_inputs == 6 ? 2 : 3; - if (!context.input_is_none(dtype_id)) { - auto pt_type = context.const_input(dtype_id); - FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::full: ", pt_type); - auto dtype = TORCH_TO_OV_TYPE.at(pt_type); - filled_tensor = context.mark_node(std::make_shared(filled_tensor, dtype)); - } - return {filled_tensor}; + return {base_translate_full_with_convert(context, sizes, value, dtype_id)}; }; OutputVector translate_full_like(NodeContext& context) { auto input = context.get_input(0); auto value = context.get_input(1); - auto input_shape = context.mark_node(std::make_shared(input)); - auto filled_tensor = context.mark_node(std::make_shared(value, input_shape)); - if (context.get_input_size() == 7 && !context.input_is_none(2)) { - auto pt_type = context.const_input(2); - FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::full_like: ", pt_type); - auto dtype = TORCH_TO_OV_TYPE.at(pt_type); - filled_tensor = context.mark_node(std::make_shared(filled_tensor, dtype)); - } else { - auto out_dtype = context.input_is_none(3) ? input : context.get_input(3); - filled_tensor = context.mark_node(std::make_shared(filled_tensor, out_dtype)); + auto sizes = context.mark_node(std::make_shared(input)); + if (context.get_input_size() == 7) { + return {base_translate_full_with_convert(context, sizes, value, 2)}; } - return {filled_tensor}; + auto out = context.input_is_none(3) ? 
 OutputVector translate_full_like(NodeContext& context) {
     auto input = context.get_input(0);
     auto value = context.get_input(1);
-    auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
-    auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
-    if (context.get_input_size() == 7 && !context.input_is_none(2)) {
-        auto pt_type = context.const_input<int64_t>(2);
-        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::full_like: ", pt_type);
-        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
-        filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
-    } else {
-        auto out_dtype = context.input_is_none(3) ? input : context.get_input(3);
-        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
+    auto sizes = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
+    if (context.get_input_size() == 7) {
+        return {base_translate_full_with_convert(context, sizes, value, 2)};
     }
-    return {filled_tensor};
+    auto out = context.input_is_none(3) ? input : context.get_input(3);
+    return {base_translate_full_with_convertlike(context, sizes, value, out)};
 };
 
 OutputVector translate_new_full(NodeContext& context) {
     auto input = context.get_input(0);
     auto sizes = context.get_input(1);
     auto value = context.get_input(2);
-    auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
-    if (context.get_input_size() == 7 && !context.input_is_none(3)) {
-        auto pt_type = context.const_input<int64_t>(3);
-        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::new_full: ", pt_type);
-        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
-        return {context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype))};
+    if (context.get_input_size() == 7) {
+        return {base_translate_full_with_convert(context, sizes, value, 3)};
     }
-    return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input))};
+    return {base_translate_full_with_convertlike(context, sizes, value, input)};
 };
 
 OutputVector translate_zeros(NodeContext& context) {
     auto sizes = context.get_input(0);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
-    auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
     int num_inputs = context.get_input_size();
     if (num_inputs < 5) {
         size_t out_id = num_inputs == 2 ? 1 : 2;
         if (!context.input_is_none(out_id)) {
             auto out = context.get_input(out_id);
-            return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
+            return {base_translate_full_with_convertlike(context, sizes, value, out)};
         }
-        return {filled_tensor};
+        return {base_translate_full(context, sizes, value)};
     }
     size_t dtype_id = num_inputs == 5 ? 1 : 2;
-    if (!context.input_is_none(dtype_id)) {
-        std::cout << dtype_id << std::endl;
-        auto pt_type = context.const_input<int64_t>(dtype_id);
-        std::cout << pt_type << std::endl;
-        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::zeros: ", pt_type);
-        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
-        filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
-    }
-    return {filled_tensor};
+    return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
 };
 
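// Hedged aside (editorial): the num_inputs checks in translate_zeros above
// are assumed to correspond to TorchScript schemas along these lines
// (paraphrased; the exact optional-argument lists are an assumption):
//
//     aten::zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
//     aten::zeros(int[] size, *, ScalarType? dtype, Layout? layout,
//                 Device? device, bool? pin_memory) -> Tensor
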
 OutputVector translate_zeros_like(NodeContext& context) {
     auto input = context.get_input(0);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
-    auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
-    auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
-    if (context.get_input_size() == 6 && !context.input_is_none(1)) {
-        auto pt_type = context.const_input<int64_t>(1);
-        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::zeros_like: ", pt_type);
-        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
-        filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
-    } else {
-        auto out_dtype = context.input_is_none(2) ? input : context.get_input(2);
-        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
+    auto sizes = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
+    if (context.get_input_size() == 6) {
+        return {base_translate_full_with_convert(context, sizes, value, 1)};
     }
-    return {filled_tensor};
+    auto out = context.input_is_none(2) ? input : context.get_input(2);
+    return {base_translate_full_with_convertlike(context, sizes, value, out)};
 };
 
 OutputVector translate_new_zeros(NodeContext& context) {
     auto input = context.get_input(0);
     auto sizes = context.get_input(1);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
-    auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
-    if (context.get_input_size() == 6 && !context.input_is_none(2)) {
-        auto pt_type = context.const_input<int64_t>(2);
-        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::new_zeros: ", pt_type);
-        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
-        return {context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype))};
+    if (context.get_input_size() == 6) {
+        return {base_translate_full_with_convert(context, sizes, value, 2)};
     }
-    return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input))};
+    return {base_translate_full_with_convertlike(context, sizes, value, input)};
 };
 
 OutputVector translate_ones(NodeContext& context) {
     auto sizes = context.get_input(0);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
-    auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
     int num_inputs = context.get_input_size();
     if (num_inputs < 5) {
         size_t out_id = num_inputs == 2 ? 1 : 2;
         if (!context.input_is_none(out_id)) {
             auto out = context.get_input(out_id);
-            return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
+            return {base_translate_full_with_convertlike(context, sizes, value, out)};
         }
+        return {base_translate_full(context, sizes, value)};
     }
     size_t dtype_id = num_inputs == 5 ? 1 : 2;
-    if (!context.input_is_none(dtype_id)) {
-        auto pt_type = context.const_input<int64_t>(dtype_id);
-        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::ones: ", pt_type);
-        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
-        filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
-    }
-    return {filled_tensor};
+    return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
 };
 
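// Hedged aside (editorial): these translators are presumably registered in
// op_table.cpp in the frontend's usual style, e.g.:
//
//     {"aten::ones", op::translate_ones},
//     {"aten::ones_like", op::translate_ones_like},
//     {"aten::new_ones", op::translate_new_ones},
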
 OutputVector translate_ones_like(NodeContext& context) {
     auto input = context.get_input(0);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
-    auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
-    auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
-    if (context.get_input_size() == 6 && !context.input_is_none(1)) {
-        auto pt_type = context.const_input<int64_t>(1);
-        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::ones_like: ", pt_type);
-        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
-        filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
-    } else {
-        auto out_dtype = context.input_is_none(2) ? input : context.get_input(2);
-        filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
+    auto sizes = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
+    if (context.get_input_size() == 6) {
+        return {base_translate_full_with_convert(context, sizes, value, 1)};
     }
-    return {filled_tensor};
+    auto out = context.input_is_none(2) ? input : context.get_input(2);
+    return {base_translate_full_with_convertlike(context, sizes, value, out)};
 };
 
 OutputVector translate_new_ones(NodeContext& context) {
     auto input = context.get_input(0);
     auto sizes = context.get_input(1);
     auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
-    auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
-    if (context.get_input_size() == 6 && !context.input_is_none(2)) {
-        auto pt_type = context.const_input<int64_t>(2);
-        FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::new_zeros: ", pt_type);
-        auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
-        return {context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype))};
+    if (context.get_input_size() == 6) {
+        return {base_translate_full_with_convert(context, sizes, value, 2)};
     }
-    return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input))};
+    return {base_translate_full_with_convertlike(context, sizes, value, input)};
 };
 
 }  // namespace op
diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp
index 8234867ff4c577..16f8b075479e5f 100644
--- a/src/frontends/pytorch/src/utils.cpp
+++ b/src/frontends/pytorch/src/utils.cpp
@@ -118,6 +118,12 @@ std::shared_ptr<Node> numel(NodeContext& context, size_t input_id) {
     return context.mark_node(std::make_shared<opset8::ReduceProd>(input_shape, axes, false));
 };
 
+ov::element::Type convert_dtype(NodeContext& context, size_t input_id) {
+    auto pt_type = context.const_input<int64_t>(input_id);
+    FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type: ", pt_type);
+    return TORCH_TO_OV_TYPE.at(pt_type);
+};
+
 OutputVector make_framework_node(NodeContext* context) {
     auto schema = context->get_schema();
     // TODO: properly process schema to get the actual position of mutable input
diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp
index fb3da70f20e323..3a9a3797ada9ce 100644
--- a/src/frontends/pytorch/src/utils.hpp
+++ b/src/frontends/pytorch/src/utils.hpp
@@ -50,6 +50,8 @@ std::shared_ptr<Node> get_axes_range(NodeContext& context, size_t input_id);
 
 std::shared_ptr<Node> numel(NodeContext& context, size_t input_id);
 
+ov::element::Type convert_dtype(NodeContext& context, size_t input_id);
+
 std::shared_ptr<Model> convert_pytorch_model(std::shared_ptr<Decoder> pytorch_model,
                                              const TensorMap& external_tensor_map = {});
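
A closing hedged note: convert_dtype relies on TORCH_TO_OV_TYPE, which is
assumed to map PyTorch ScalarType ids to OpenVINO element types roughly as
below. This excerpt is illustrative; the authoritative map lives in utils.cpp
and may differ.

    // Hypothetical excerpt, keyed by at::ScalarType integer values.
    const std::unordered_map<int64_t, ov::element::Type> TORCH_TO_OV_TYPE{
        {0, ov::element::u8},        // torch.uint8
        {1, ov::element::i8},        // torch.int8
        {2, ov::element::i16},       // torch.int16
        {3, ov::element::i32},       // torch.int32
        {4, ov::element::i64},       // torch.int64
        {5, ov::element::f16},       // torch.float16
        {6, ov::element::f32},       // torch.float32
        {7, ov::element::f64},       // torch.float64
        {11, ov::element::boolean},  // torch.bool
    };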