Skip to content

Commit

Permalink
[CPU] Fixed MVN
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed May 15, 2021
1 parent fded2b3 commit 15ffe41
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 32 deletions.
61 changes: 30 additions & 31 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -691,37 +691,6 @@ MKLDNNMVNNode::MKLDNNMVNNode(const std::shared_ptr<ngraph::Node>& op, const mkld
epsMode_ = INSIDE_SQRT;
acrossChannels_ = mvnOp->get_across_channels();
}

transformTo5DCase(inDataShape);
}

void MKLDNNMVNNode::transformTo5DCase(const ngraph::Shape& shape) {
switch (shape.size()) {
// for 1 and 2 rank, if acrossChannels_ is true, adjust shape to fully vectorize under unified 5d procedure.
// otherwise there are not enough data in spatial dimension to process in one kernel.
case 1 : // C
if (acrossChannels_) {
shape5D = std::make_tuple(1, 1, 1, 1, shape[0]);
acrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(1, shape[0], 1, 1, 1);
break;
}
case 2 : // NC
if (acrossChannels_) {
shape5D = std::make_tuple(1, shape[0], 1, shape[1], 1);
acrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(shape[0], shape[1], 1, 1, 1);
break;
}
case 3 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], 1); break; }
case 4 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], shape[3]); break; }
case 5 : { shape5D = std::make_tuple(shape[0], shape[1], shape[2], shape[3], shape[4]); break; }
default : { IE_THROW() << "MVN layer with name '" << getName() << "' doesn't support planar layout with rank: " << shape.size(); }
}
}

void MKLDNNMVNNode::getSupportedDescriptors() {
Expand Down Expand Up @@ -850,6 +819,7 @@ void MKLDNNMVNNode::createPrimitive() {
jcp.normalize_variance = normalizeVariance_;
jcp.across_channels = acrossChannels_;
SizeVector in_dims = getParentEdgeAt(0)->getDims().ToSizeVector();
transformTo5DCase(in_dims);
int N = 0;
std::tie(N, jcp.C, jcp.D, jcp.H, jcp.W) = shape5D;

Expand Down Expand Up @@ -892,6 +862,35 @@ void MKLDNNMVNNode::createPrimitive() {
mvn_variance_kernel->create_ker();
}

void MKLDNNMVNNode::transformTo5DCase(const SizeVector& shape) {
switch (shape.size()) {
// for 1 and 2 rank, if acrossChannels_ is true, adjust shape to fully vectorize under unified 5d procedure.
// otherwise there are not enough data in spatial dimension to process in one kernel.
case 1 : // C
if (acrossChannels_) {
shape5D = std::make_tuple(1, 1, 1, 1, shape[0]);
acrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(1, shape[0], 1, 1, 1);
break;
}
case 2 : // NC
if (acrossChannels_) {
shape5D = std::make_tuple(1, shape[0], 1, shape[1], 1);
acrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(shape[0], shape[1], 1, 1, 1);
break;
}
case 3 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], 1); break; }
case 4 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], shape[3]); break; }
case 5 : { shape5D = std::make_tuple(shape[0], shape[1], shape[2], shape[3], shape[4]); break; }
default : { IE_THROW() << "MVN layer with name '" << getName() << "' doesn't support planar layout with rank: " << shape.size(); }
}
}

void MKLDNNMVNNode::setPostOps(mkldnn::primitive_attr &attr, bool initWeights) {
mkldnn::post_ops ops;
for (auto &node : fusedWith) {
Expand Down
2 changes: 1 addition & 1 deletion inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class MKLDNNMVNNode : public MKLDNNNode {

void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false);

void transformTo5DCase(const ngraph::Shape& shape);
void transformTo5DCase(const InferenceEngine::SizeVector& shape);

std::tuple<size_t, size_t, size_t, size_t, size_t> shape5D;

Expand Down

0 comments on commit 15ffe41

Please sign in to comment.