Skip to content

Commit

Permalink
Merge pull request #6 from vzinovie/es/lpt/lpt_to_ngraph_integration
Browse files Browse the repository at this point in the history
[LPT] GPU Plugin set config fix
  • Loading branch information
dmitry-gorokhov authored Sep 28, 2020
2 parents 53a672f + f4cf3a5 commit 07d1b56
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
13 changes: 6 additions & 7 deletions inference-engine/src/cldnn_engine/cldnn_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ cldnn::device_info clDNNEngine::GetDeviceInfo(const std::map<std::string, std::s
return device_info;
}

InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneAndTransformNetwork(const InferenceEngine::ICNNNetwork& network) const {
InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneAndTransformNetwork(const InferenceEngine::ICNNNetwork& network, CLDNNPlugin::Config config) const {
std::shared_ptr<ICNNNetwork> clonedNetwork = cloneNetwork(network);
bool baselineIsFP16 = false;

Expand Down Expand Up @@ -139,7 +139,6 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneAndTransformNetwork(const In
// Disable shape inference (WA for generic operations)
::ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);

CLDNNPlugin::Config config = _impl->m_config;
const bool enableInt8 = config.enableInt8 && (config.lptVersion == Config::LptVersion::nGraph);

{
Expand Down Expand Up @@ -285,7 +284,7 @@ ExecutableNetworkInternal::Ptr clDNNEngine::LoadExeNetworkImpl(const InferenceEn

CLDNNPlugin::Config conf = _impl->m_config;
auto device_info = GetDeviceInfo(config);
conf.enableInt8 = true; // device_info.supports_imad || device_info.supports_immad;
conf.enableInt8 = device_info.supports_imad || device_info.supports_immad;
conf.UpdateFromMap(config);

if (conf.enableDynamicBatch) {
Expand Down Expand Up @@ -322,7 +321,7 @@ ExecutableNetworkInternal::Ptr clDNNEngine::LoadExeNetworkImpl(const InferenceEn

context = m_defaultContext;

return std::make_shared<CLDNNExecNetwork>(*CloneAndTransformNetwork(network), context, conf);
return std::make_shared<CLDNNExecNetwork>(*CloneAndTransformNetwork(network, conf), context, conf);
}

ExecutableNetworkInternal::Ptr clDNNEngine::LoadExeNetworkImpl(const InferenceEngine::ICNNNetwork &network,
Expand All @@ -339,14 +338,14 @@ ExecutableNetworkInternal::Ptr clDNNEngine::LoadExeNetworkImpl(const InferenceEn

CLDNNPlugin::Config conf = getContextImpl(casted)->GetConfig();
auto device_info = GetDeviceInfo(config);
conf.enableInt8 = true; // device_info.supports_imad || device_info.supports_immad;
conf.enableInt8 = device_info.supports_imad || device_info.supports_immad;
conf.UpdateFromMap(config);

if (conf.enableDynamicBatch) {
conf.max_dynamic_batch = static_cast<int>(network.getBatchSize());
}

return std::make_shared<CLDNNExecNetwork>(*CloneAndTransformNetwork(network), casted, conf);
return std::make_shared<CLDNNExecNetwork>(*CloneAndTransformNetwork(network, conf), casted, conf);
}

RemoteContext::Ptr clDNNEngine::CreateContext(const ParamMap& params) {
Expand Down Expand Up @@ -390,7 +389,7 @@ void clDNNEngine::QueryNetwork(const ICNNNetwork& network,
for (auto&& node : function->get_ops()) {
originalOps.emplace(node->get_friendly_name());
}
auto clonedNetwork = CloneAndTransformNetwork(network);
auto clonedNetwork = CloneAndTransformNetwork(network, _impl->m_config);
std::unordered_set<std::string> supported;
std::unordered_set<std::string> unsupported;

Expand Down
3 changes: 2 additions & 1 deletion inference-engine/src/cldnn_engine/cldnn_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class clDNNEngine : public InferenceEngine::InferencePluginInternal,
CLDNNRemoteCLContext::Ptr m_defaultContext;

cldnn::device_info GetDeviceInfo(const std::map<std::string, std::string> &config) const;
InferenceEngine::ICNNNetwork::Ptr CloneAndTransformNetwork(const InferenceEngine::ICNNNetwork& network) const;
InferenceEngine::ICNNNetwork::Ptr CloneAndTransformNetwork(const InferenceEngine::ICNNNetwork& network,
CLDNNPlugin::Config config) const;
public:
clDNNEngine();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,11 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
// otherwise we insert Convert operation.
for (auto &node : f->get_ordered_ops()) {
m_transformation_callback(node);

// Recursively run for TensorIterator body function
if (auto ti = std::dynamic_pointer_cast<opset4::TensorIterator>(node)) {
convert_function_precision(ti->get_body());
// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) {
convert_function_precision(sub_graph);
}
}
convert_node_input_precision(node);
}
Expand Down

0 comments on commit 07d1b56

Please sign in to comment.