diff --git a/src/frontends/pytorch/src/op/linspace.cpp b/src/frontends/pytorch/src/op/linspace.cpp new file mode 100644 index 00000000000000..c2233bee15ee24 --- /dev/null +++ b/src/frontends/pytorch/src/op/linspace.cpp @@ -0,0 +1,79 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/equal.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/subtract.hpp" +#include "pt_framework_node.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_linspace(const NodeContext& context) { + num_inputs_check(context, 3, 7); + // "aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? + // device=None, bool? pin_memory=None) -> Tensor" + + // "aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)" + auto start = context.mark_node(std::make_shared(context.get_input(0), element::f32)); + auto end = context.mark_node(std::make_shared(context.get_input(1), element::f32)); + auto steps = context.mark_node(std::make_shared(context.get_input(2), element::f32)); + auto out_tensor = context.get_input(1); + auto apply_dtype = true; + auto dtype = element::f32; + if (!context.input_is_none(3) && context.get_input_size() == 7) { + // Case where dtype is provided directly in dtype input. + if (std::dynamic_pointer_cast(context.get_input_from_visible_context(3).get_node_shared_ptr())) { + dtype = convert_dtype(context.const_input(3)); + apply_dtype = true; + } else if (const auto& fw_node = cast_fw_node(context.get_input(3).get_node_shared_ptr(), "prim::dtype")) { + out_tensor = fw_node->input_value(0); + apply_dtype = false; + } else { + FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input"); + } + } else if (!context.input_is_none(3) && context.get_input_size() == 4) { + // Case where dtype is inherited from out tensor. + out_tensor = context.get_input(3); + apply_dtype = false; + } + + auto const_0 = v0::Constant::create(element::f32, Shape{}, {0}); + auto const_1 = v0::Constant::create(element::f32, Shape{}, {1}); + auto step_range = context.mark_node(std::make_shared(const_0, steps, const_1, element::f32)); + + auto sub_end_start = context.mark_node(std::make_shared(end, start)); + auto sub_steps_1 = context.mark_node(std::make_shared(steps, const_1)); + auto step_multiplier = context.mark_node(std::make_shared(sub_end_start, sub_steps_1)); + auto is_single_step = context.mark_node(std::make_shared(steps, const_1)); + auto select_multiplier = context.mark_node(std::make_shared(is_single_step, const_0, step_multiplier)); + auto step_values = context.mark_node(std::make_shared(step_range, select_multiplier)); + + auto linspace = context.mark_node(std::make_shared(step_values, start)); + if (apply_dtype) { + linspace = context.mark_node(std::make_shared(linspace, dtype)); + } else { + linspace = context.mark_node(std::make_shared(linspace, out_tensor)); + } + + return {linspace}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index ac0e66de932129..c422f6beae6ec2 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -84,6 +84,7 @@ OP_CONVERTER(translate_linalg_norm); OP_CONVERTER(translate_linalg_matrix_norm); OP_CONVERTER(translate_linalg_vector_norm); OP_CONVERTER(translate_linear); +OP_CONVERTER(translate_linspace); OP_CONVERTER(translate_list_construct); OP_CONVERTER(translate_list_unpack); OP_CONVERTER(translate_log); @@ -328,6 +329,7 @@ const std::map get_supported_ops_ts() { {"aten::linalg_matrix_norm", op::translate_linalg_matrix_norm}, {"aten::linalg_vector_norm", op::translate_linalg_vector_norm}, {"aten::linear", op::translate_linear}, + {"aten::linspace", op::translate_linspace}, {"aten::log", op::translate_log}, {"aten::log_", op::inplace_op}, {"aten::log_softmax", op::translate_log_softmax}, diff --git a/tests/layer_tests/pytorch_tests/test_linspace.py b/tests/layer_tests/pytorch_tests/test_linspace.py new file mode 100644 index 00000000000000..aa6f70d3d71c89 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_linspace.py @@ -0,0 +1,89 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch +from pytorch_layer_test_class import PytorchLayerTest + + +class TestLinspace(PytorchLayerTest): + def _prepare_input(self, start, end, steps, dtype=None, ref_dtype=None): + inputs = [np.array(start).astype(dtype), np.array(end).astype(dtype), np.array(steps).astype("int32")] + if ref_dtype: + inputs.append(np.zeros(1).astype(ref_dtype)) + return inputs + + def create_model(self, dtype=None, use_out=False, ref_dtype=False): + dtype_map = { + "float32": torch.float32, + "float64": torch.float64, + "int64": torch.int64, + "int32": torch.int32, + "uint8": torch.uint8, + "int8": torch.int8, + } + + class aten_linspace_dtype(torch.nn.Module): + def __init__(self, dtype) -> None: + super().__init__() + self.dtype = dtype + + def forward(self, start, end, steps): + return torch.linspace(start=start, end=end, steps=steps, dtype=self.dtype) + + class aten_linspace_out(torch.nn.Module): + def __init__(self, out) -> None: + super().__init__() + # Size of empty tensor needs to be of equal or larger size than linspace steps + self.out = torch.empty(25, dtype=out) + + def forward(self, start, end, steps): + return torch.linspace(start=start, end=end, steps=steps, out=self.out) + + class aten_linspace_prim_dtype(torch.nn.Module): + def forward(self, start, end, steps, d): + return torch.linspace(start=start, end=end, steps=steps, dtype=d.dtype) + + dtype = dtype_map.get(dtype) + if ref_dtype: + model_class = aten_linspace_prim_dtype() + elif not use_out: + model_class = aten_linspace_dtype(dtype) + else: + model_class = aten_linspace_out(dtype) + + ref_net = None + + return model_class, ref_net, "aten::linspace" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8"]) + @pytest.mark.parametrize( + "start,end,steps", [(0, 1, 5), (-2, 1, 5), (1, -5, 7), (1, 10, 2), (-1, -5, 2), (-1, -5, 1), (1.25, -5.5, 5)] + ) + def test_linspace_with_prim_dtype(self, dtype, end, start, steps, ie_device, precision, ir_version): + self._test( + *self.create_model(dtype, ref_dtype=True), + ie_device, + precision, + ir_version, + kwargs_to_prepare_input={"end": end, "start": start, "steps": steps, "ref_dtype": dtype} + ) + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8", "uin8"]) + @pytest.mark.parametrize( + "start,end,steps", [(0, 1, 5), (-2, 1, 5), (1, -5, 7), (1, 10, 2), (-1, -5, 2), (-1, -5, 1), (1.25, -5.5, 5)] + ) + @pytest.mark.parametrize("use_out", [False, True]) + def test_linspace_with_out(self, dtype, use_out, end, start, steps, ie_device, precision, ir_version): + self._test( + *self.create_model(dtype=dtype, use_out=use_out), + ie_device, + precision, + ir_version, + kwargs_to_prepare_input={"end": end, "start": start, "steps": steps} + )