Skip to content

Commit

Permalink
Propagate DEVICE_ID for functions working with RemoteContext (#3109)
Browse files Browse the repository at this point in the history
* Propagate DEVICE_ID for functions working with RemoteContext

* More fixes for RemoteContext

* Fixed tests compilation with VariableState
  • Loading branch information
ilya-lavrenov authored Nov 13, 2020
1 parent d87fdbe commit fec3bc0
Show file tree
Hide file tree
Showing 14 changed files with 62 additions and 38 deletions.
27 changes: 22 additions & 5 deletions inference-engine/include/gpu/gpu_ocl_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,34 @@
/**
 * @brief Definitions required by Khronos headers
 */

// Enable C++ exceptions in the Khronos OpenCL C++ wrapper unless the
// user has already configured this macro themselves.
#ifndef CL_HPP_ENABLE_EXCEPTIONS
# define CL_HPP_ENABLE_EXCEPTIONS
#endif

// This header requires OpenCL 1.2 (encoded as 120) or newer. Respect a
// user-provided value, but reject anything strictly older than 120.
// NOTE: the comparison must be `< 120`, not `<= 120` — otherwise the
// exactly-supported value 120 (which is also our own fallback default
// below) would trip the #error, contradicting the message "must be >= 120".
#ifdef CL_HPP_MINIMUM_OPENCL_VERSION
# if CL_HPP_MINIMUM_OPENCL_VERSION < 120
#  error "CL_HPP_MINIMUM_OPENCL_VERSION must be >= 120"
# endif
#else
# define CL_HPP_MINIMUM_OPENCL_VERSION 120
#endif

#ifdef CL_HPP_TARGET_OPENCL_VERSION
# if CL_HPP_TARGET_OPENCL_VERSION < 120
#  error "CL_HPP_TARGET_OPENCL_VERSION must be >= 120"
# endif
#else
# define CL_HPP_TARGET_OPENCL_VERSION 120
#endif

// Suppress warnings originating inside the Khronos header on GCC/Clang:
// `system_header` marks the rest of this include context as a system
// header, and the push/pop pair confines that effect to cl2.hpp.
#ifdef __GNUC__
# pragma GCC diagnostic push
# pragma GCC system_header
#endif

#include <CL/cl2.hpp>

#ifdef __GNUC__
# pragma GCC diagnostic pop
#endif
4 changes: 2 additions & 2 deletions inference-engine/src/cldnn_engine/cldnn_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,9 @@ RemoteContext::Ptr clDNNEngine::CreateContext(const ParamMap& params) {
}
}

RemoteContext::Ptr clDNNEngine::GetDefaultContext() {
RemoteContext::Ptr clDNNEngine::GetDefaultContext(const ParamMap& params) {
if (nullptr == m_defaultContext) {
m_defaultContext.reset(new CLDNNRemoteCLContext(shared_from_this(), ParamMap(), _impl->m_config));
m_defaultContext.reset(new CLDNNRemoteCLContext(shared_from_this(), params, _impl->m_config));
}
return std::dynamic_pointer_cast<RemoteContext>(m_defaultContext);
}
Expand Down
2 changes: 1 addition & 1 deletion inference-engine/src/cldnn_engine/cldnn_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class clDNNEngine : public InferenceEngine::InferencePluginInternal,
const std::map<std::string, std::string>& config) const override;

InferenceEngine::RemoteContext::Ptr CreateContext(const InferenceEngine::ParamMap& params) override;
InferenceEngine::RemoteContext::Ptr GetDefaultContext() override;
InferenceEngine::RemoteContext::Ptr GetDefaultContext(const ParamMap& params) override;
};

}; // namespace CLDNNPlugin
2 changes: 1 addition & 1 deletion inference-engine/src/gna_plugin/gna_plugin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
InferenceEngine::Parameter GetMetric(const std::string& name,
const std::map<std::string, InferenceEngine::Parameter> & options) const override;
InferenceEngine::RemoteContext::Ptr CreateContext(const InferenceEngine::ParamMap& params) override { THROW_GNA_EXCEPTION << "Not implemented"; }
InferenceEngine::RemoteContext::Ptr GetDefaultContext() override { THROW_GNA_EXCEPTION << "Not implemented"; }
InferenceEngine::RemoteContext::Ptr GetDefaultContext(const InferenceEngine::ParamMap&) override { THROW_GNA_EXCEPTION << "Not implemented"; }

void Wait(uint32_t sync, InferenceEngine::Blob &result) { THROW_GNA_EXCEPTION << "Not implemented"; }

Expand Down
40 changes: 16 additions & 24 deletions inference-engine/src/inference_engine/ie_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,45 +598,37 @@ void Core::AddExtension(const IExtensionPtr& extension) {
ExecutableNetwork Core::LoadNetwork(const CNNNetwork& network, RemoteContext::Ptr context,
const std::map<std::string, std::string>& config) {
OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::LoadNetwork");
std::map<std::string, std::string> config_ = config;

if (context == nullptr) {
THROW_IE_EXCEPTION << "Remote context is null";
}

std::string deviceName_ = context->getDeviceName();
DeviceIDParser device(deviceName_);
std::string deviceName = device.getDeviceName();

return _impl->GetCPPPluginByName(deviceName).LoadNetwork(network, config_, context);
auto parsed = parseDeviceNameIntoConfig(context->getDeviceName(), config);
return _impl->GetCPPPluginByName(parsed._deviceName).LoadNetwork(network, parsed._config, context);
}

RemoteContext::Ptr Core::CreateContext(const std::string& deviceName_, const ParamMap& params) {
if (deviceName_.find("HETERO") == 0) {
THROW_IE_EXCEPTION << "HETERO device does not support remote contexts";
RemoteContext::Ptr Core::CreateContext(const std::string& deviceName, const ParamMap& params) {
if (deviceName.find("HETERO") == 0) {
THROW_IE_EXCEPTION << "HETERO device does not support remote context";
}
if (deviceName_.find("MULTI") == 0) {
THROW_IE_EXCEPTION << "MULTI device does not support remote contexts";
if (deviceName.find("MULTI") == 0) {
THROW_IE_EXCEPTION << "MULTI device does not support remote context";
}

DeviceIDParser device(deviceName_);
std::string deviceName = device.getDeviceName();

return _impl->GetCPPPluginByName(deviceName).CreateContext(params);
auto parsed = parseDeviceNameIntoConfig(deviceName, params);
return _impl->GetCPPPluginByName(parsed._deviceName).CreateContext(parsed._config);
}

RemoteContext::Ptr Core::GetDefaultContext(const std::string& deviceName_) {
if (deviceName_.find("HETERO") == 0) {
THROW_IE_EXCEPTION << "HETERO device does not support remote contexts";
RemoteContext::Ptr Core::GetDefaultContext(const std::string& deviceName) {
if (deviceName.find("HETERO") == 0) {
THROW_IE_EXCEPTION << "HETERO device does not support remote context";
}
if (deviceName_.find("MULTI") == 0) {
THROW_IE_EXCEPTION << "MULTI device does not support remote contexts";
if (deviceName.find("MULTI") == 0) {
THROW_IE_EXCEPTION << "MULTI device does not support remote context";
}

DeviceIDParser device(deviceName_);
std::string deviceName = device.getDeviceName();

return _impl->GetCPPPluginByName(deviceName).GetDefaultContext();
auto parsed = parseDeviceNameIntoConfig(deviceName, ParamMap());
return _impl->GetCPPPluginByName(parsed._deviceName).GetDefaultContext(parsed._config);
}

void Core::AddExtension(IExtensionPtr extension, const std::string& deviceName_) {
Expand Down
4 changes: 2 additions & 2 deletions inference-engine/src/inference_engine/ie_plugin_cpp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ class InferencePlugin {
CALL_STATEMENT(return actual->CreateContext(params));
}

RemoteContext::Ptr GetDefaultContext() {
CALL_STATEMENT(return actual->GetDefaultContext());
RemoteContext::Ptr GetDefaultContext(const ParamMap& params) {
CALL_STATEMENT(return actual->GetDefaultContext(params));
}

ExecutableNetwork ImportNetwork(std::istream& networkModel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class InferencePluginInternal : public IInferencePlugin {
THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
}

RemoteContext::Ptr GetDefaultContext() override {
RemoteContext::Ptr GetDefaultContext(const ParamMap& /*params*/) override {
THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,10 @@ class IInferencePlugin : public details::IRelease,

/**
* @brief Provides a default remote context instance if supported by a plugin
* @param[in] params The map of parameters
* @return The default context.
*/
virtual RemoteContext::Ptr GetDefaultContext() = 0;
virtual RemoteContext::Ptr GetDefaultContext(const ParamMap& params) = 0;

/**
* @deprecated Use ImportNetwork(std::istream& networkModel, const std::map<std::string, std::string>& config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ TEST_P(Basic_LSTM_S, CompareWithRefImpl_LowLatencyTransformation) {
manager.register_pass<ngraph::pass::LowLatency>(); // LowLatency enables UnrollTI
manager.run_passes(function);
LoadNetwork();
IE_SUPPRESS_DEPRECATED_START
auto states = executableNetwork.QueryState();
for (auto& state : states) {
auto name = state.GetName();
Expand All @@ -215,6 +216,7 @@ TEST_P(Basic_LSTM_S, CompareWithRefImpl_LowLatencyTransformation) {
GTEST_FAIL() << "unknown memory state";
}
}
IE_SUPPRESS_DEPRECATED_END
// Run and compare
Infer();
const auto& actualOutputs = GetOutputs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ namespace SubgraphTestsDefinitions {
void MemoryLSTMCellTest::Run() {
SKIP_IF_CURRENT_TEST_IS_DISABLED()

IE_SUPPRESS_DEPRECATED_START
LoadNetwork();
auto states = executableNetwork.QueryState();
for (auto& state : states) {
Expand All @@ -276,6 +277,7 @@ namespace SubgraphTestsDefinitions {
GTEST_FAIL() << "unknown memory state";
}
}
IE_SUPPRESS_DEPRECATED_END
Infer();
switchToNgraphFriendlyModel();
Validate();
Expand All @@ -297,6 +299,7 @@ namespace SubgraphTestsDefinitions {
manager.run_passes(function);
LoadNetwork();
}
IE_SUPPRESS_DEPRECATED_START
auto states = executableNetwork.QueryState();
for (auto& state : states) {
auto name = state.GetName();
Expand All @@ -312,6 +315,7 @@ namespace SubgraphTestsDefinitions {
GTEST_FAIL() << "unknown memory state";
}
}
IE_SUPPRESS_DEPRECATED_END
Infer();

CreatePureTensorIteratorModel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,12 @@ void MemoryEltwiseReshapeConcatTest::Run() {
InferenceEngine::SizeVector({1, inputSize * concatSize}),
InferenceEngine::Layout::NC);

IE_SUPPRESS_DEPRECATED_START
auto states = executableNetwork.QueryState();
auto state_values_blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description,
memory_init.data(), memory_init.size());
states[0].SetState(state_values_blob);
IE_SUPPRESS_DEPRECATED_END
Infer();
initNgraphFriendlyModel();
Validate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ void MultipleLSTMCellTest::Run() {
InferenceEngine::SizeVector({1, hiddenSize}),
InferenceEngine::Layout::NC);
LoadNetwork();
IE_SUPPRESS_DEPRECATED_START
auto states = executableNetwork.QueryState();
for (auto& state : states) {
auto name = state.GetName();
Expand All @@ -425,6 +426,7 @@ void MultipleLSTMCellTest::Run() {
GTEST_FAIL() << "unknown memory state";
}
}
IE_SUPPRESS_DEPRECATED_END
Infer();
switchToNgraphFriendlyModel();
Validate();
Expand All @@ -450,6 +452,7 @@ void MultipleLSTMCellTest::RunLowLatency(bool regular_api) {
manager.run_passes(function);
LoadNetwork();
}
IE_SUPPRESS_DEPRECATED_START
auto states = executableNetwork.QueryState();
for (auto& state : states) {
auto name = state.GetName();
Expand All @@ -473,6 +476,7 @@ void MultipleLSTMCellTest::RunLowLatency(bool regular_api) {
GTEST_FAIL() << "unknown memory state";
}
}
IE_SUPPRESS_DEPRECATED_END
Infer();

// Calculate ref values for Unrolled TI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ namespace LayerTestsDefinitions {
SKIP_IF_CURRENT_TEST_IS_DISABLED()

LoadNetwork();
IE_SUPPRESS_DEPRECATED_START
auto states = executableNetwork.QueryState();
for (auto& state : states) {
auto name = state.GetName();
Expand All @@ -90,6 +91,7 @@ namespace LayerTestsDefinitions {
GTEST_FAIL() << "unknown memory state";
}
}
IE_SUPPRESS_DEPRECATED_END
Infer();
switchToNgraphFriendlyModel();
Validate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class MockIInferencePlugin : public InferenceEngine::IInferencePlugin {
const std::string&, const std::map<std::string, InferenceEngine::Parameter>&));
MOCK_METHOD1(CreateContext,
InferenceEngine::RemoteContext::Ptr(const InferenceEngine::ParamMap&));
MOCK_METHOD0(GetDefaultContext, InferenceEngine::RemoteContext::Ptr(void));
MOCK_METHOD1(GetDefaultContext, InferenceEngine::RemoteContext::Ptr(const InferenceEngine::ParamMap&));
MOCK_METHOD3(LoadNetwork, InferenceEngine::ExecutableNetwork(
const InferenceEngine::ICNNNetwork&, const std::map<std::string, std::string>&,
InferenceEngine::RemoteContext::Ptr));
Expand Down

0 comments on commit fec3bc0

Please sign in to comment.