Skip to content

Commit

Permalink
[CPU] Added Mish activation (openvinotoolkit#1555)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova authored Aug 10, 2020
1 parent 1eac9e3 commit 50e003c
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 8 deletions.
12 changes: 7 additions & 5 deletions inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,8 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndActivation(MKLDNNGraph &graph) {
return activationNode &&
(activationNode->getAlgorithm() == eltwise_relu ||
(conv->getCnnLayer()->precision == Precision::FP32 &&
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp, eltwise_swish})));
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp,
eltwise_swish, eltwise_mish})));
};

for (int i = 0; i < graphNodes.size(); i++) {
Expand Down Expand Up @@ -1187,7 +1188,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndSimpleOperation(MKLDNNGraph &graph)
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();

return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu,
eltwise_clamp, eltwise_swish});
eltwise_clamp, eltwise_swish, eltwise_mish});
}

return false;
Expand Down Expand Up @@ -1431,7 +1432,8 @@ void MKLDNNGraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(MKLDNNG
return activationNode &&
(activationNode->getAlgorithm() == eltwise_relu ||
(conv->getCnnLayer()->precision == Precision::FP32 &&
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp, eltwise_swish})));
isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp,
eltwise_swish, eltwise_mish})));
#else
return false;
#endif
Expand Down Expand Up @@ -1781,7 +1783,7 @@ void MKLDNNGraphOptimizer::FuseNormalizeAndSimpleOperation(MKLDNNGraph &graph) {
if (activationNode == nullptr)
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_gelu, eltwise_elu, eltwise_logistic,
eltwise_bounded_relu, eltwise_clamp, eltwise_tanh, eltwise_swish, eltwise_linear, eltwise_abs,
eltwise_bounded_relu, eltwise_clamp, eltwise_tanh, eltwise_swish, eltwise_mish, eltwise_linear, eltwise_abs,
eltwise_square, eltwise_sqrt});
}
return false;
Expand Down Expand Up @@ -1893,7 +1895,7 @@ void MKLDNNGraphOptimizer::FuseEltwiseAndSimple(MKLDNNGraph &graph) {
if (activationNode == nullptr)
THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu,
eltwise_clamp, eltwise_swish});
eltwise_clamp, eltwise_swish, eltwise_mish});
}

return false;
Expand Down
1 change: 1 addition & 0 deletions inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ static const InferenceEngine::details::caseless_unordered_map<std::string, Type>
{ "Activation", Activation },
{ "Clamp", Activation },
{ "Swish", Activation },
{ "Mish", Activation },
{ "ScaleShift", Depthwise },
{ "PReLU", Depthwise },
{ "Norm", Lrn },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ caseless_map<std::string, std::function<void(GenericLayer*, mkldnn::algorithm&,
beta = 0.0f;
algorithm = eltwise_swish;
}},
{"mish", [](GenericLayer* activationLayer, mkldnn::algorithm& algorithm, float& alpha, float& beta) {
alpha = 0.0f;
beta = 0.0f;
algorithm = eltwise_mish;
}},
};

MKLDNNActivationNode::MKLDNNActivationNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ std::vector<std::string> disabledTestPatterns() {
R"(.*ActivationLayerTest.*Ceiling.*)",
// TODO: Issue: 32032
R"(.*ActivationParamLayerTest.*)",
// TODO: Issue: 32959
R"(.*ActivationLayerTest.*Mish.*)",
// TODO: Issue: 30999 (Implement Interpolate reference in NGraph)
R"(.*InterpolateLayerTest.*)"
};
Expand Down

0 comments on commit 50e003c

Please sign in to comment.