Skip to content

Commit

Permalink
Merge pull request openvinotoolkit#85 from eaidova/ea/selu
Browse files Browse the repository at this point in the history
aten::selu
  • Loading branch information
slyalin authored Jan 11, 2023
2 parents 636be1a + 10ed57c commit 5f4d382
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/frontends/pytorch/src/op/selu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset8.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_selu(NodeContext& context) {
auto x = context.get_input(0);
auto alpha =
context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717}));
auto lambda =
context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.0507009873554804934193349852946}));
alpha = context.mark_node(std::make_shared<opset8::ConvertLike>(alpha, x));
lambda = context.mark_node(std::make_shared<opset8::ConvertLike>(lambda, x));
return {context.mark_node(std::make_shared<opset8::Selu>(x, alpha, lambda))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
3 changes: 3 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ OP_CONVERTER(translate_rsub);
OP_CONVERTER(translate_roll);
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_select);
OP_CONVERTER(translate_selu);
OP_CONVERTER(translate_size);
OP_CONVERTER(translate_slice);
OP_CONVERTER(translate_softmax);
Expand Down Expand Up @@ -220,6 +221,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::roll", op::translate_roll},
{"aten::rsqrt", op::translate_rsqrt},
{"aten::select", op::translate_select},
{"aten::selu", op::translate_selu},
{"aten::selu_", op::inplace_op<op::translate_selu>},
{"aten::sigmoid", op::translate_1to1_match_1_inputs<opset8::Sigmoid>},
{"aten::silu", op::translate_1to1_match_1_inputs<opset8::Swish>},
{"aten::silu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Swish>>},
Expand Down
33 changes: 33 additions & 0 deletions tests/layer_tests/pytorch_tests/test_selu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
from pytorch_layer_test_class import PytorchLayerTest


class TestSilu(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(1, 3, 224, 224).astype(np.float32),)

def create_model(self, inplace=False):

import torch
import torch.nn.functional as F

class aten_selu(torch.nn.Module):
def __init__(self, inplace):
super(aten_selu, self).__init__()
self.inplace = inplace

def forward(self, x):
return x, F.selu(x, inplace=self.inplace)

ref_net = None

return aten_selu(inplace), ref_net, "aten::selu" if not inplace else "aten::selu_"

@pytest.mark.nightly
@pytest.mark.parametrize("inplace", [True, False])
def test_silu(self, inplace, ie_device, precision, ir_version):
self._test(*self.create_model(inplace), ie_device, precision, ir_version)

0 comments on commit 5f4d382

Please sign in to comment.