Skip to content

Commit

Permalink
[CPU] Dynamic shape support for MKLDNNMVNNode::isSupportedOperation (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Evgenya Stepyreva authored Jun 4, 2021
1 parent 0db9d3e commit 62c37a8
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,9 +604,13 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator

bool MKLDNNMVNNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
const auto& inDataShapeSize = op->input_value(0).get_shape().size();
if (inDataShapeSize < 1 || inDataShapeSize > 5) {
errorMessage = "First input accepts ranks from 1 to 5. Actual: " + std::to_string(inDataShapeSize);
if (op->get_output_partial_shape(0).rank().is_dynamic()) {
errorMessage = "Unsupported dynamic input rank.";
return false;
}
const auto& inDataRank = op->get_output_partial_shape(0).rank().get_length();
if (inDataRank < 1 || inDataRank > 5) {
errorMessage = "First input accepts ranks from 1 to 5. Actual: " + std::to_string(inDataRank);
return false;
}

Expand All @@ -632,21 +636,20 @@ bool MKLDNNMVNNode::isSupportedOperation(const std::shared_ptr<const ngraph::Nod
// 4D: axes: [1,2,3], [2,3]
// 5D: axes: [1,2,3,4], [2,3,4]
auto axesVal = axesOp->cast_vector<int>();
auto& mvnShape = mvnOp->get_output_shape(0);
for (int& axe : axesVal)
axe = axe < 0 ? axe + mvnShape.size() : axe;
axe = axe < 0 ? axe + inDataRank : axe;
std::sort(axesVal.begin(), axesVal.end());
if (mvnShape.size() == 1) {
if (inDataRank == 1) {
if (axesVal.size() != 1 || axesVal[0] != 0) {
errorMessage = "Unsupported axes.";
return false;
}
} else {
if (mvnShape.size() > 5 || (mvnShape.size() != axesVal.size() + 1 && mvnShape.size() != axesVal.size() + 2)) {
if (inDataRank > 5 || (inDataRank != axesVal.size() + 1 && inDataRank != axesVal.size() + 2)) {
errorMessage = "Unsupported axes.";
return false;
}
int value = mvnShape.size() - 1;
int value = inDataRank - 1;
for (int i = axesVal.size() - 1; i >= 0; i--, value--) {
if (axesVal[i] != value) {
errorMessage = "Unsupported axes.";
Expand Down

0 comments on commit 62c37a8

Please sign in to comment.