From 2942a54b5794e52c6e0237754070009a4a3f48e1 Mon Sep 17 00:00:00 2001 From: Maksim Kutakov Date: Tue, 25 May 2021 18:56:24 +0300 Subject: [PATCH] GetBlob performance fix --- inference-engine/src/mkldnn_plugin/mkldnn_graph.h | 8 ++++++++ .../src/mkldnn_plugin/mkldnn_infer_request.cpp | 13 ++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph.h b/inference-engine/src/mkldnn_plugin/mkldnn_graph.h index 3811c8f8b70d2e..822cdeb387b6a6 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_graph.h +++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph.h @@ -79,6 +79,14 @@ class MKLDNNGraph { return outputNodesMap; } + bool hasInputWithName(const std::string& name) const { + return inputNodesMap.count(name); + } + + bool hasOutputWithName(const std::string& name) const { + return outputNodesMap.count(name); + } + mkldnn::engine getEngine() const { return eng; } diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp index 392226d06d4eba..2496ea27fb6913 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_infer_request.cpp @@ -212,10 +212,9 @@ InferenceEngine::Blob::Ptr MKLDNNPlugin::MKLDNNInferRequest::GetBlob(const std:: InferenceEngine::Blob::Ptr data; - InferenceEngine::BlobMap blobs; - graph->getInputBlobs(blobs); - - if (blobs.find(name) != blobs.end()) { + if (graph->hasInputWithName(name)) { + InferenceEngine::BlobMap blobs; + graph->getInputBlobs(blobs); // ROI blob is returned only if it was set previously. auto it = _preProcData.find(name); if (it != _preProcData.end()) { @@ -245,9 +244,9 @@ InferenceEngine::Blob::Ptr MKLDNNPlugin::MKLDNNInferRequest::GetBlob(const std:: checkBlob(data, name, true); } - blobs.clear(); - graph->getOutputBlobs(blobs); - if (blobs.find(name) != blobs.end()) { + if (graph->hasOutputWithName(name)) { + InferenceEngine::BlobMap blobs; + graph->getOutputBlobs(blobs); if (_outputs.find(name) == _outputs.end()) { if (!data) { InferenceEngine::TensorDesc desc = _networkOutputs[name]->getTensorDesc();