Skip to content

Commit

Permalink
convert constants to input type
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Jan 11, 2023
1 parent d6068b9 commit 10ed57c
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/frontends/pytorch/src/op/selu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ namespace op {

OutputVector translate_selu(NodeContext& context) {
auto x = context.get_input(0);
auto alpha = context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717}));
auto lambda = context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.0507009873554804934193349852946}));
auto x_f64 = context.mark_node(std::make_shared<opset8::Convert>(x, element::f64));
auto res = context.mark_node(std::make_shared<opset8::Selu>(x, alpha, lambda));
return {context.mark_node(std::make_shared<opset8::ConvertLike>(res, x)};
auto alpha =
context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717}));
auto lambda =
context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.0507009873554804934193349852946}));
alpha = context.mark_node(std::make_shared<opset8::ConvertLike>(alpha, x));
lambda = context.mark_node(std::make_shared<opset8::ConvertLike>(lambda, x));
return {context.mark_node(std::make_shared<opset8::Selu>(x, alpha, lambda))};
};

} // namespace op
Expand Down

0 comments on commit 10ed57c

Please sign in to comment.