[CPU] Added MVN fusion for case with constants inside (#5644)
a-sidorova authored May 21, 2021
1 parent 5e4c3c4 commit 57d49f3
Showing 11 changed files with 482 additions and 38 deletions.
63 changes: 31 additions & 32 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
@@ -691,37 +691,6 @@ MKLDNNMVNNode::MKLDNNMVNNode(const std::shared_ptr<ngraph::Node>& op, const mkld
epsMode_ = INSIDE_SQRT;
acrossChannels_ = mvnOp->get_across_channels();
}

transformTo5DCase(inDataShape);
}

void MKLDNNMVNNode::transformTo5DCase(const ngraph::Shape& shape) {
switch (shape.size()) {
// For rank 1 and 2, if acrossChannels_ is true, adjust the shape to fully vectorize under the unified 5D procedure;
// otherwise there is not enough data in the spatial dimension to process in one kernel.
case 1 : // C
if (acrossChannels_) {
shape5D = std::make_tuple(1, 1, 1, 1, shape[0]);
acrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(1, shape[0], 1, 1, 1);
break;
}
case 2 : // NC
if (acrossChannels_) {
shape5D = std::make_tuple(1, shape[0], 1, shape[1], 1);
acrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(shape[0], shape[1], 1, 1, 1);
break;
}
case 3 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], 1); break; }
case 4 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], shape[3]); break; }
case 5 : { shape5D = std::make_tuple(shape[0], shape[1], shape[2], shape[3], shape[4]); break; }
default : { IE_THROW() << "MVN layer with name '" << getName() << "' doesn't support planar layout with rank: " << shape.size(); }
}
}

void MKLDNNMVNNode::getSupportedDescriptors() {
@@ -840,6 +809,8 @@ void MKLDNNMVNNode::createPrimitive() {
if (getSelectedPrimitiveDescriptor() == nullptr)
IE_THROW() << "Preferable primitive descriptor is not set.";

const SizeVector in_dims = getParentEdgeAt(0)->getDims().ToSizeVector();
transformTo5DCase(in_dims);
auto selectedPD = getSelectedPrimitiveDescriptor();
auto jcp = jit_mvn_config_params();
jcp.src_prc = selectedPD->getConfig().inConfs[0].desc.getPrecision();
@@ -849,7 +820,6 @@ void MKLDNNMVNNode::createPrimitive() {
jcp.planar_layout = MKLDNNMemory::GetPlainLayout(getChildEdgeAt(0)->getDims()) == selectedPD->getConfig().inConfs[0].desc.getLayout();
jcp.normalize_variance = normalizeVariance_;
jcp.across_channels = acrossChannels_;
SizeVector in_dims = getParentEdgeAt(0)->getDims().ToSizeVector();
int N = 0;
std::tie(N, jcp.C, jcp.D, jcp.H, jcp.W) = shape5D;

@@ -892,6 +862,35 @@ void MKLDNNMVNNode::createPrimitive() {
mvn_variance_kernel->create_ker();
}

void MKLDNNMVNNode::transformTo5DCase(const SizeVector& shape) {
switch (shape.size()) {
// For rank 1 and 2, if acrossChannels_ is true, adjust the shape to fully vectorize under the unified 5D procedure;
// otherwise there is not enough data in the spatial dimension to process in one kernel.
case 1 : // C
if (acrossChannels_) {
shape5D = std::make_tuple(1, 1, 1, 1, shape[0]);
acrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(1, shape[0], 1, 1, 1);
break;
}
case 2 : // NC
if (acrossChannels_) {
shape5D = std::make_tuple(1, shape[0], 1, shape[1], 1);
acrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(shape[0], shape[1], 1, 1, 1);
break;
}
case 3 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], 1); break; }
case 4 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], shape[3]); break; }
case 5 : { shape5D = std::make_tuple(shape[0], shape[1], shape[2], shape[3], shape[4]); break; }
default : { IE_THROW() << "MVN layer with name '" << getName() << "' doesn't support planar layout with rank: " << shape.size(); }
}
}
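// Worked examples of the mapping above, derived directly from the switch (shapes taken from the tests below):
//   {8},           acrossChannels_=true  -> shape5D = (1, 1, 1, 1, 8), and acrossChannels_ is reset to false
//   {8},           acrossChannels_=false -> shape5D = (1, 8, 1, 1, 1)
//   {3, 19},       acrossChannels_=true  -> shape5D = (1, 3, 1, 19, 1), and acrossChannels_ is reset to false
//   {3, 19},       acrossChannels_=false -> shape5D = (3, 19, 1, 1, 1)
//   {1, 32, 17}                          -> shape5D = (1, 32, 1, 17, 1)
//   {1, 16, 5, 8}                        -> shape5D = (1, 16, 1, 5, 8)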

void MKLDNNMVNNode::setPostOps(mkldnn::primitive_attr &attr, bool initWeights) {
mkldnn::post_ops ops;
for (auto &node : fusedWith) {
2 changes: 1 addition & 1 deletion inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.h
@@ -103,7 +103,7 @@ class MKLDNNMVNNode : public MKLDNNNode {

void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false);

void transformTo5DCase(const ngraph::Shape& shape);
void transformTo5DCase(const InferenceEngine::SizeVector& shape);

std::tuple<size_t, size_t, size_t, size_t, size_t> shape5D;

@@ -16,7 +16,9 @@
namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API MVNFusion;
class TRANSFORMATIONS_API MVNFusionWithoutConstants;
class TRANSFORMATIONS_API MVNFusionWithConstantsInside;

} // namespace pass
} // namespace ngraph
@@ -26,8 +28,32 @@ namespace pass {
* @brief MVNFusionWithoutConstants transformation replaces the group of
* operations: (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps) with an MVN op.
*/
class ngraph::pass::MVNFusion : public ngraph::pass::MatcherPass {
class ngraph::pass::MVNFusionWithoutConstants : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
MVNFusion();
MVNFusionWithoutConstants();
};

/**
* @ingroup ie_transformation_common_api
* @brief MVNFusionWithConstantsInside transformation replaces the group of
* operations: gamma * (x - ReduceMean(x, axes)) / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) + beta with MVN, Multiply and Add ops.
*/
class ngraph::pass::MVNFusionWithConstantsInside : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
MVNFusionWithConstantsInside();
};

/**
* @ingroup ie_transformation_common_api
* @brief MVNFusion transformation replaces various sub-graphs with an MVN op.
*/
class ngraph::pass::MVNFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
MVNFusion() {
add_matcher<ngraph::pass::MVNFusionWithoutConstants>();
add_matcher<ngraph::pass::MVNFusionWithConstantsInside>();
}
};
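Restated in LaTeX (directly from the two @brief formulas above, with \(\mu = \mathrm{ReduceMean}(x, axes)\) and \(\sigma^2 = \mathrm{ReduceMean}((x - \mu)^2, axes)\)):

\[
\frac{x - \mu}{\sqrt{\sigma^2} + \epsilon} \;\to\; \mathrm{MVN}(x),
\qquad
\gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \;\to\; \gamma \cdot \mathrm{MVN}(x) + \beta.
\]

MVNFusionWithoutConstants handles the bare pattern; MVNFusionWithConstantsInside additionally folds the constant gamma and beta, emitting MVN followed by Multiply and Add.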
@@ -27,8 +27,10 @@ std::function<bool(ngraph::Output<ngraph::Node>)> value_is_equal_to(const std::v
};
}

ngraph::pass::MVNFusion::MVNFusion() {
MATCHER_SCOPE(MVNFusion);
NGRAPH_RTTI_DEFINITION(ngraph::pass::MVNFusionWithoutConstants, "MVNFusionWithoutConstants", 0);

ngraph::pass::MVNFusionWithoutConstants::MVNFusionWithoutConstants() {
MATCHER_SCOPE(MVNFusionWithoutConstants);
// Detect MVN decomposition pattern:
// (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
auto x = pattern::any_input();
@@ -188,3 +190,112 @@ ngraph::pass::MVNFusion::MVNFusion() {
auto m = std::make_shared<ngraph::pattern::Matcher>(powerMulOrDiv, matcher_name);
register_matcher(m, matcher_pass_callback);
}

NGRAPH_RTTI_DEFINITION(ngraph::pass::MVNFusionWithConstantsInside, "MVNFusionWithConstantsInside", 0);

ngraph::pass::MVNFusionWithConstantsInside::MVNFusionWithConstantsInside() {
MATCHER_SCOPE(MVNFusionWithConstantsInside);
// Detect MVN decomposition pattern:
// (x - ReduceMean(x, axes)) * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) + beta
auto x = pattern::any_input();

// (x - ReduceMean(x, axes))^2
// `------mean1-------'
auto mean1_axes = pattern::wrap_type<opset6::Constant>();
auto mean1 = pattern::wrap_type<opset6::ReduceMean>({ x, mean1_axes });

// (x - ReduceMean(x, axes))^2
// `-squared_difference------'
auto squared_difference = pattern::wrap_type<opset6::SquaredDifference>({ x, mean1 });

// 1 / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `---mean2--------------------------------'
auto mean2_axes = pattern::wrap_type<opset6::Constant>();
auto mean2 = pattern::wrap_type<opset6::ReduceMean>({ squared_difference, mean2_axes });

// 1 / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `------------------------------------------add--'
auto eps = pattern::wrap_type<opset6::Constant>();
auto add_eps = pattern::wrap_type<opset6::Add>({ mean2, eps });

// 1 / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `-power-------------------------------------------------'
auto const_0_5 = pattern::wrap_type<opset6::Constant>(value_is_equal_to<float>({-0.5}));
auto power = pattern::wrap_type<opset6::Power>({ add_eps, const_0_5 });

// gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `---mul1----------------------------------------------------'
auto gamma = pattern::wrap_type<opset6::Constant>();
auto mul1 = pattern::wrap_type<opset6::Multiply>({ power, gamma });

// x * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `---mul2--------------------------------------------------------'
auto mul2 = pattern::wrap_type<opset6::Multiply>({ x, mul1 });

// ReduceMean(x, axes) * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `-------------------mul3----------------------------------------------------------'
auto mul3 = pattern::wrap_type<opset6::Multiply>({ mul1, mean1 });

// beta - ReduceMean(x, axes) * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
// `---sub-----------------------------------------------------------------------------------'
auto beta = pattern::wrap_type<opset6::Constant>();
auto sub = pattern::wrap_type<opset6::Subtract>({ beta, mul3 });

// Final Add
// x * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) +
// beta - ReduceMean(x, axes) * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) =
// gamma * (x - ReduceMean(x, axes)) / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) + beta
auto add = pattern::wrap_type<opset6::Add>({ mul2, sub });

ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto x_output = pattern_to_output.at(x);

auto const_0_5_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(const_0_5).get_node_shared_ptr());
auto const_gamma_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(gamma).get_node_shared_ptr());
auto const_beta_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(beta).get_node_shared_ptr());
auto const_eps_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(eps).get_node_shared_ptr());
if (!const_0_5_node || !const_beta_node || !const_gamma_node || !const_eps_node) {
return false;
}

float eps_value;
bool valid_constant_values = op::util::has_constant_value<float>(const_0_5_node, -0.5) && op::util::get_single_value(const_eps_node, eps_value);
if (!valid_constant_values) {
return false;
}

auto axes_1_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean1_axes).get_node_shared_ptr());
auto axes_2_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean2_axes).get_node_shared_ptr());
if (!axes_1_node || !axes_2_node) {
return false;
}

auto axes_1_value = axes_1_node->cast_vector<int64_t>();
auto axes_2_value = axes_2_node->cast_vector<int64_t>();
if (axes_1_value != axes_2_value) {
return false;
}

auto mvn = std::make_shared<ngraph::opset6::MVN>(x_output, axes_1_node, true, eps_value, op::MVNEpsMode::INSIDE_SQRT);
auto mul_gamma = std::make_shared<ngraph::opset6::Multiply>(mvn, const_gamma_node);
auto add_beta = std::make_shared<ngraph::opset6::Add>(mul_gamma, const_beta_node);

ngraph::copy_runtime_info({ pattern_to_output.at(mean1).get_node_shared_ptr(),
pattern_to_output.at(squared_difference).get_node_shared_ptr(),
pattern_to_output.at(add_eps).get_node_shared_ptr(),
pattern_to_output.at(power).get_node_shared_ptr(),
pattern_to_output.at(mul1).get_node_shared_ptr(),
pattern_to_output.at(mul2).get_node_shared_ptr(),
pattern_to_output.at(mul3).get_node_shared_ptr(),
pattern_to_output.at(sub).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr() },
{ mvn, const_gamma_node, mul_gamma, const_beta_node, add_beta });
add_beta->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::replace_node(m.get_match_root(), add_beta);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(add, matcher_name);
register_matcher(m, matcher_pass_callback);
}
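Written out, the regrouping this matcher relies on (the same identity as the "Final Add" comment above, with \(\mu = \mathrm{ReduceMean}(x, axes)\), \(\sigma^2 = \mathrm{ReduceMean}((x - \mu)^2, axes)\), and \(t = \gamma\,(\sigma^2 + \epsilon)^{-1/2}\)):

\[
\underbrace{x \cdot t}_{\texttt{mul2}} + \underbrace{(\beta - \mu \cdot t)}_{\texttt{sub}}
= t \cdot (x - \mu) + \beta
= \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
= \gamma \cdot \mathrm{MVN}(x) + \beta,
\]

which is why the callback can replace the matched root with MVN(x) -> Multiply(gamma) -> Add(beta), using INSIDE_SQRT eps mode.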
@@ -419,3 +419,49 @@ TEST(TransformationTests, MVNFusionTestAltDivInsideSqrt) {
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

TEST(TransformationTests, MVNFusionTestWithParametersInside) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 1 }, { 2 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes, true);
auto squared_difference = std::make_shared<ngraph::opset6::SquaredDifference>(input, mean1);
auto mean2_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 1 }, { 2 });
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(squared_difference, mean2_axes, true);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(mean2, eps);
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -0.5 });
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(add_eps, const_0_5);
auto gamma = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1 });
auto mul_gamma = std::make_shared<ngraph::opset6::Multiply>(power_sqrt, gamma);
auto mul1 = std::make_shared<ngraph::opset6::Multiply>(input, mul_gamma);
auto mul2 = std::make_shared<ngraph::opset6::Multiply>(mul_gamma, mean1);
auto beta = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto sub = std::make_shared<ngraph::opset6::Subtract>(beta, mul2);
auto add = std::make_shared<ngraph::opset6::Add>(mul1, sub);

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input });

ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 1 }, { 2 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
auto gamma = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1 });
auto mul_gamma = std::make_shared<ngraph::opset6::Multiply>(mvn, gamma);
auto beta = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto add = std::make_shared<ngraph::opset6::Add>(mul_gamma, beta);

f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input });
}

auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
@@ -10,6 +10,9 @@
using namespace LayerTestsDefinitions;

const std::vector<std::vector<size_t>> inputShapes = {
{8},
{1, 16},
{3, 19},
{1, 32, 17},
{1, 37, 9},
{1, 16, 5, 8},