From a0312a3fa98d32bd5366653aefbab03bc6d21912 Mon Sep 17 00:00:00 2001 From: chenhuwa Date: Wed, 3 Feb 2021 19:48:08 +0800 Subject: [PATCH] model validation bug fix --- .../mkldnn_plugin/nodes/mkldnn_mvn_node.cpp | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp index 80819f15aff91b..8e163d79486d77 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp @@ -605,8 +605,6 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator MKLDNNMVNNode::MKLDNNMVNNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(layer, eng, cache) {} -void (MKLDNNMVNNode::*mvn_executor)(const uint8_t *, uint8_t *, const InferenceEngine::SizeVector &) = nullptr; - void MKLDNNMVNNode::getSupportedDescriptors() { if (!descs.empty()) return; @@ -798,15 +796,6 @@ void MKLDNNMVNNode::createPrimitive() { if (mvn_variance_kernel) mvn_variance_kernel->create_ker(); - - if (mayiuse(cpu::x64::sse41)) { - if (jcp.planar_layout) - mvn_executor = &MKLDNNMVNNode::mvn_pln; - else - mvn_executor = &MKLDNNMVNNode::mvn_blk; - } else { - mvn_executor = &MKLDNNMVNNode::mvn_ref; - } } void MKLDNNMVNNode::setPostOps(mkldnn::primitive_attr &attr, bool initWeights) { @@ -835,7 +824,13 @@ void MKLDNNMVNNode::execute(mkldnn::stream strm) { uint8_t *dst_data = reinterpret_cast(dstMemPtr->GetPtr()); uint8_t *src_data = reinterpret_cast(srcMemPtr->GetPtr()); - (this->*mvn_executor)(src_data, dst_data, getParentEdgeAt(0)->getDesc().getDims()); + auto dim = getParentEdgeAt(0)->getDesc().getDims(); + Layout layout = getParentEdgeAt(0)->getDesc().getLayout(); + if (layout == C || layout == NC || layout == CHW || layout == NCHW || layout == NCDHW) { + mvn_pln(src_data, dst_data, dim); + } else { + mvn_blk(src_data, dst_data, dim); + } } void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data, const SizeVector& dims) {