diff --git a/inference-engine/src/vpu/graph_transformer/src/middleend/passes/decompose_swish.cpp b/inference-engine/src/vpu/graph_transformer/src/middleend/passes/decompose_swish.cpp index f5c84d06277c41..95bfe5cbbc43a1 100644 --- a/inference-engine/src/vpu/graph_transformer/src/middleend/passes/decompose_swish.cpp +++ b/inference-engine/src/vpu/graph_transformer/src/middleend/passes/decompose_swish.cpp @@ -3,6 +3,7 @@ // #include +#include namespace vpu { @@ -31,16 +32,33 @@ void PassImpl::run(const Model& model) { const auto outputData = swish->output(0); const auto name = swish->name(); const auto& layer = swish->origLayer(); + const auto beta = swish->attrs().get("beta"); model->removeStage(swish); + auto sigmoidInput = inputData; - const auto sigmoidOutput = model->addNewData(inputData->name() + "@sigmoid", inputData->desc()); + if (beta != 1.0f) { + const auto betaDesc = DataDesc(inputData->desc()); + const auto betaConst = model->addConstData(inputData->name() + "@beta", betaDesc, + replicateContent(beta, betaDesc.totalDimSize(), betaDesc)); + const auto prodOutput = model->addNewData(inputData->name() + "@prod-x-beta", inputData->desc()); + _stageBuilder->addProdStage( + model, + name + "@prod-x-beta", + layer, + inputData, + betaConst, + prodOutput); + sigmoidInput = prodOutput; + } + const auto sigmoidDesc = inputData->desc(); + const auto sigmoidOutput = model->addNewData(inputData->name() + "@sigmoid", sigmoidDesc); _stageBuilder->addSigmoidStage( model, name + "@sigmoid", layer, - {inputData}, + {sigmoidInput}, {sigmoidOutput}); _stageBuilder->addProdStage( model, diff --git a/inference-engine/tests/functional/plugin/myriad/shared_tests_instances/single_layer_tests/activation.cpp b/inference-engine/tests/functional/plugin/myriad/shared_tests_instances/single_layer_tests/activation.cpp index 899458f1105189..2b52438e2614ec 100644 --- a/inference-engine/tests/functional/plugin/myriad/shared_tests_instances/single_layer_tests/activation.cpp +++ b/inference-engine/tests/functional/plugin/myriad/shared_tests_instances/single_layer_tests/activation.cpp @@ -24,7 +24,7 @@ const std::map>> activationTypes {Gelu, {}}, {Mish, {}}, {SoftPlus, {}}, - {Swish, {{1.0f}}} // {{0.05f}, {0.8f}, {1.0f}, {15.0f}}} #38489 + {Swish, {{0.05f}, {0.8f}, {1.0f}, {15.0f}}} }; std::map, std::vector>> basic = {