From bbb5e91b5e014a7d3b9d51ed54b4ef3db6ac882d Mon Sep 17 00:00:00 2001
From: eaidova <ekaterina.aidova@intel.com>
Date: Tue, 10 Jan 2023 09:32:56 +0400
Subject: [PATCH 1/3] aten::selu

---
 src/frontends/pytorch/src/op/selu.cpp        | 24 ++++++++++++++
 src/frontends/pytorch/src/op_table.cpp       |  3 ++
 tests/layer_tests/pytorch_tests/test_selu.py | 33 ++++++++++++++++++++
 3 files changed, 60 insertions(+)
 create mode 100644 src/frontends/pytorch/src/op/selu.cpp
 create mode 100644 tests/layer_tests/pytorch_tests/test_selu.py

diff --git a/src/frontends/pytorch/src/op/selu.cpp b/src/frontends/pytorch/src/op/selu.cpp
new file mode 100644
index 00000000000000..7a158db38690b2
--- /dev/null
+++ b/src/frontends/pytorch/src/op/selu.cpp
@@ -0,0 +1,24 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/opsets/opset8.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_selu(NodeContext& context) {
+    auto x = context.get_input(0);
+    auto alpha = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1.6732632423543772848170429916717}));
+    auto lambda = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1.0507009873554804934193349852946}));
+    return {context.mark_node(std::make_shared<opset8::Selu>(x, alpha, lambda))};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 7e7456749b0fde..fab30c725a0a99 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -70,6 +70,7 @@ OP_CONVERTER(translate_rsub);
 OP_CONVERTER(translate_roll);
 OP_CONVERTER(translate_rsqrt);
 OP_CONVERTER(translate_select);
+OP_CONVERTER(translate_selu);
 OP_CONVERTER(translate_size);
 OP_CONVERTER(translate_slice);
 OP_CONVERTER(translate_softmax);
@@ -209,6 +210,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::roll", op::translate_roll},
         {"aten::rsqrt", op::translate_rsqrt},
         {"aten::select", op::translate_select},
+        {"aten::selu", op::translate_selu},
+        {"aten::selu_", op::inplace_op<op::translate_selu>},
         {"aten::sigmoid", op::translate_1to1_match_1_inputs<opset8::Sigmoid>},
         {"aten::silu", op::translate_1to1_match_1_inputs<opset8::Swish>},
         {"aten::silu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Swish>>},
diff --git a/tests/layer_tests/pytorch_tests/test_selu.py b/tests/layer_tests/pytorch_tests/test_selu.py
new file mode 100644
index 00000000000000..854db09a40deaf
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_selu.py
@@ -0,0 +1,33 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestSelu(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
+
+    def create_model(self, inplace=False):
+
+        import torch
+        import torch.nn.functional as F
+
+        class aten_selu(torch.nn.Module):
+            def __init__(self, inplace):
+                super(aten_selu, self).__init__()
+                self.inplace = inplace
+
+            def forward(self, x):
+                return x, F.selu(x, inplace=self.inplace)
+
+        ref_net = None
+
+        return aten_selu(inplace), ref_net, "aten::selu" if not inplace else "aten::selu_"
+
+    @pytest.mark.nightly
+    @pytest.mark.parametrize("inplace", [True, False])
+    def test_selu(self, inplace, ie_device, precision, ir_version):
+        self._test(*self.create_model(inplace), ie_device, precision, ir_version)
\ No newline at end of file

From d6068b9a53e0b5e535d6728009b199ed4fea3f8e Mon Sep 17 00:00:00 2001
From: eaidova <ekaterina.aidova@intel.com>
Date: Tue, 10 Jan 2023 20:32:53 +0400
Subject: [PATCH 2/3] execute in fp64 and convert back to initial precision

---
 src/frontends/pytorch/src/op/selu.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/frontends/pytorch/src/op/selu.cpp b/src/frontends/pytorch/src/op/selu.cpp
index 7a158db38690b2..5cd89328a42b55 100644
--- a/src/frontends/pytorch/src/op/selu.cpp
+++ b/src/frontends/pytorch/src/op/selu.cpp
@@ -13,9 +13,11 @@ namespace op {
 
 OutputVector translate_selu(NodeContext& context) {
     auto x = context.get_input(0);
-    auto alpha = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1.6732632423543772848170429916717}));
-    auto lambda = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1.0507009873554804934193349852946}));
-    return {context.mark_node(std::make_shared<opset8::Selu>(x, alpha, lambda))};
+    auto alpha = context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717}));
+    auto lambda = context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.0507009873554804934193349852946}));
+    auto x_f64 = context.mark_node(std::make_shared<opset8::Convert>(x, element::f64));
+    auto res = context.mark_node(std::make_shared<opset8::Selu>(x_f64, alpha, lambda));
+    return {context.mark_node(std::make_shared<opset8::ConvertLike>(res, x))};
 };
 
 }  // namespace op

From 10ed57c07a438f98de561d1c1af1f3e7c569ff54 Mon Sep 17 00:00:00 2001
From: eaidova <ekaterina.aidova@intel.com>
Date: Wed, 11 Jan 2023 17:11:41 +0400
Subject: [PATCH 3/3] convert constants to input type

---
 src/frontends/pytorch/src/op/selu.cpp | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/frontends/pytorch/src/op/selu.cpp b/src/frontends/pytorch/src/op/selu.cpp
index 5cd89328a42b55..cb7faaaef99efe 100644
--- a/src/frontends/pytorch/src/op/selu.cpp
+++ b/src/frontends/pytorch/src/op/selu.cpp
@@ -13,11 +13,13 @@ namespace op {
 
 OutputVector translate_selu(NodeContext& context) {
     auto x = context.get_input(0);
-    auto alpha = context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717}));
-    auto lambda = context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.0507009873554804934193349852946}));
-    auto x_f64 = context.mark_node(std::make_shared<opset8::Convert>(x, element::f64));
-    auto res = context.mark_node(std::make_shared<opset8::Selu>(x_f64, alpha, lambda));
-    return {context.mark_node(std::make_shared<opset8::ConvertLike>(res, x))};
+    auto alpha =
+        context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717}));
+    auto lambda =
+        context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.0507009873554804934193349852946}));
+    alpha = context.mark_node(std::make_shared<opset8::ConvertLike>(alpha, x));
+    lambda = context.mark_node(std::make_shared<opset8::ConvertLike>(lambda, x));
+    return {context.mark_node(std::make_shared<opset8::Selu>(x, alpha, lambda))};
 };
 
 }  // namespace op