diff --git a/inference-engine/src/gna_plugin/gna_infer_request.hpp b/inference-engine/src/gna_plugin/gna_infer_request.hpp
index 81a5acc0c8dfc0..737a25ae697da3 100644
--- a/inference-engine/src/gna_plugin/gna_infer_request.hpp
+++ b/inference-engine/src/gna_plugin/gna_infer_request.hpp
@@ -48,7 +48,16 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
     void InferImpl() override {
         // execute input pre-processing.
         execDataPreprocessing(_inputs);
-        plg->Infer(_inputs, _outputs);
+        // result returned from sync infer wait method
+        auto result = plg->Infer(_inputs, _outputs);
+
+        // if result is false we are dealing with QoS feature
+        // if result is ok, next call to wait() will return Ok, if request not in gna_queue
+        if (!result) {
+            inferRequestIdx = -1;
+        } else {
+            inferRequestIdx = -2;
+        }
     }
 
     /**
@@ -92,7 +101,13 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
             qosOK = plg->WaitFor(inferRequestIdx, millis_timeout);
         }
 
-        return qosOK ? InferenceEngine::OK : InferenceEngine::INFER_NOT_STARTED;
+        if (qosOK) {
+            return InferenceEngine::OK;
+        } else {
+            // need to preserve invalid state here to avoid next Wait() from clearing it
+            inferRequestIdx = -1;
+            return InferenceEngine::INFER_NOT_STARTED;
+        }
     }
 };
 }  // namespace GNAPluginNS
diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp
index 5bb7f9f5d24948..e005236dedef15 100644
--- a/inference-engine/src/gna_plugin/gna_plugin.cpp
+++ b/inference-engine/src/gna_plugin/gna_plugin.cpp
@@ -1069,7 +1069,7 @@ void GNAPlugin::Reset() {
     graphCompiler.Reset();
 }
 
-void GNAPlugin::Infer(const InferenceEngine::Blob &input, InferenceEngine::Blob &output) {
+bool GNAPlugin::Infer(const InferenceEngine::Blob &input, InferenceEngine::Blob &output) {
     BlobMap bmInput;
     BlobMap bmOutput;
     if (inputsDataMap.size() != 1) {
@@ -1080,11 +1080,11 @@ void GNAPlugin::Infer(const InferenceEngine::Blob &input, InferenceEngine::Blob
     bmInput[inputsDataMap.begin()->first] = std::shared_ptr<Blob>(const_cast<Blob*>(&input), [](Blob*){});
     IE_ASSERT(!outputsDataMap.empty());
     bmOutput[outputsDataMap.begin()->first] = std::shared_ptr<Blob>(&output, [](Blob*){});
-    Infer(bmInput, bmOutput);
+    return Infer(bmInput, bmOutput);
 }
 
-void GNAPlugin::Infer(const InferenceEngine::BlobMap &input, InferenceEngine::BlobMap &result) {
-    Wait(QueueInference(input, result));
+bool GNAPlugin::Infer(const InferenceEngine::BlobMap &input, InferenceEngine::BlobMap &result) {
+    return Wait(QueueInference(input, result));
 }
 
 Blob::Ptr GNAPlugin::GetOutputBlob(const std::string& name, InferenceEngine::Precision precision) {
diff --git a/inference-engine/src/gna_plugin/gna_plugin.hpp b/inference-engine/src/gna_plugin/gna_plugin.hpp
index 332472baa604d1..709eccf6028a83 100644
--- a/inference-engine/src/gna_plugin/gna_plugin.hpp
+++ b/inference-engine/src/gna_plugin/gna_plugin.hpp
@@ -96,7 +96,7 @@ class GNAPlugin : public InferenceEngine::IInferencePluginInternal, public std::
 
     void LoadNetwork(InferenceEngine::ICNNNetwork &network);
 
-    void Infer(const InferenceEngine::BlobMap &input, InferenceEngine::BlobMap &result);
+    bool Infer(const InferenceEngine::BlobMap &input, InferenceEngine::BlobMap &result);
     void GetPerformanceCounts(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &perfMap);
     void AddExtension(InferenceEngine::IExtensionPtr extension) override;
 
@@ -107,7 +107,7 @@ class GNAPlugin : public InferenceEngine::IInferencePluginInternal, public std::
     InferenceEngine::ExecutableNetwork LoadNetwork(const InferenceEngine::ICNNNetwork &network,
                                                    const std::map<std::string, std::string> &config_map,
                                                    InferenceEngine::RemoteContext::Ptr context) override { THROW_GNA_EXCEPTION << "Not implemented"; }
-    void Infer(const InferenceEngine::Blob &input, InferenceEngine::Blob &result);
+    bool Infer(const InferenceEngine::Blob &input, InferenceEngine::Blob &result);
     void SetCore(InferenceEngine::ICore*) noexcept override {}
     InferenceEngine::ICore* GetCore() const noexcept override {return nullptr;}
     void Reset();
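For context, a minimal caller-side sketch (not part of the patch) of how the new QoS signalling surfaces through the public Inference Engine API: a request that the GNA QoS mechanism pre-empts now reports INFER_NOT_STARTED from Wait() instead of appearing to succeed. The `exec` and `RunOnce` names below are illustrative assumptions, not code from this change, and assume an ExecutableNetwork already loaded on the GNA device.

```cpp
// Minimal sketch, assuming `exec` is an InferenceEngine::ExecutableNetwork
// already loaded on the "GNA" device; names here are illustrative only.
#include <inference_engine.hpp>

InferenceEngine::StatusCode RunOnce(InferenceEngine::ExecutableNetwork &exec) {
    InferenceEngine::InferRequest request = exec.CreateInferRequest();
    request.StartAsync();

    // With this patch, a request aborted by the GNA QoS mechanism is reported
    // as INFER_NOT_STARTED from Wait() rather than as a silent success.
    auto status = request.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
    if (status == InferenceEngine::StatusCode::INFER_NOT_STARTED) {
        // Inference was pre-empted; the application may retry or fall back.
    }
    return status;
}
```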