From 10ed57c07a438f98de561d1c1af1f3e7c569ff54 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 11 Jan 2023 17:11:41 +0400 Subject: [PATCH] convert constants to input type --- src/frontends/pytorch/src/op/selu.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/frontends/pytorch/src/op/selu.cpp b/src/frontends/pytorch/src/op/selu.cpp index 5cd89328a42b55..cb7faaaef99efe 100644 --- a/src/frontends/pytorch/src/op/selu.cpp +++ b/src/frontends/pytorch/src/op/selu.cpp @@ -13,11 +13,13 @@ namespace op { OutputVector translate_selu(NodeContext& context) { auto x = context.get_input(0); - auto alpha = context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717})); - auto lambda = context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.0507009873554804934193349852946})); - auto x_f64 = context.mark_node(std::make_shared(x, element::f64)); - auto res = context.mark_node(std::make_shared(x, alpha, lambda)); - return {context.mark_node(std::make_shared(res, x)}; + auto alpha = + context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717})); + auto lambda = + context.mark_node(opset8::Constant::create(element::f64, Shape{}, {1.0507009873554804934193349852946})); + alpha = context.mark_node(std::make_shared(alpha, x)); + lambda = context.mark_node(std::make_shared(lambda, x)); + return {context.mark_node(std::make_shared(x, alpha, lambda))}; }; } // namespace op