diff --git a/src/frontends/pytorch/src/op/selu.cpp b/src/frontends/pytorch/src/op/selu.cpp new file mode 100644 index 00000000000000..cb7faaaef99efe --- /dev/null +++ b/src/frontends/pytorch/src/op/selu.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/opsets/opset8.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_selu(NodeContext& context) { + auto x = context.get_input(0); + auto alpha = + context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717})); + auto lambda = + context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.0507009873554804934193349852946})); + alpha = context.mark_node(std::make_shared(alpha, x)); + lambda = context.mark_node(std::make_shared(lambda, x)); + return {context.mark_node(std::make_shared(x, alpha, lambda))}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 5060b8a2c84fa3..37384c72bb54bf 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -75,6 +75,7 @@ OP_CONVERTER(translate_rsub); OP_CONVERTER(translate_roll); OP_CONVERTER(translate_rsqrt); OP_CONVERTER(translate_select); +OP_CONVERTER(translate_selu); OP_CONVERTER(translate_size); OP_CONVERTER(translate_slice); OP_CONVERTER(translate_softmax); @@ -220,6 +221,8 @@ const std::map get_supported_ops() { {"aten::roll", op::translate_roll}, {"aten::rsqrt", op::translate_rsqrt}, {"aten::select", op::translate_select}, + {"aten::selu", op::translate_selu}, + {"aten::selu_", op::inplace_op}, {"aten::sigmoid", op::translate_1to1_match_1_inputs}, {"aten::silu", op::translate_1to1_match_1_inputs}, {"aten::silu_", op::inplace_op>}, diff --git a/tests/layer_tests/pytorch_tests/test_selu.py b/tests/layer_tests/pytorch_tests/test_selu.py new file mode 100644 index 00000000000000..854db09a40deaf --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_selu.py @@ -0,0 +1,33 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from pytorch_layer_test_class import PytorchLayerTest + + +class TestSilu(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(1, 3, 224, 224).astype(np.float32),) + + def create_model(self, inplace=False): + + import torch + import torch.nn.functional as F + + class aten_selu(torch.nn.Module): + def __init__(self, inplace): + super(aten_selu, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x, F.selu(x, inplace=self.inplace) + + ref_net = None + + return aten_selu(inplace), ref_net, "aten::selu" if not inplace else "aten::selu_" + + @pytest.mark.nightly + @pytest.mark.parametrize("inplace", [True, False]) + def test_silu(self, inplace, ie_device, precision, ir_version): + self._test(*self.create_model(inplace), ie_device, precision, ir_version) \ No newline at end of file