
Changed InferencePlugin, ICore to return internal interfaces
ilya-lavrenov committed May 15, 2021
1 parent 9810c9f commit db5b669
Showing 31 changed files with 248 additions and 657 deletions.
5 changes: 4 additions & 1 deletion inference-engine/include/cpp/ie_executable_network.hpp
@@ -29,7 +29,10 @@ class IExecutableNetwork;
 */
class INFERENCE_ENGINE_API_CLASS(ExecutableNetwork) : protected details::SOPointer<IExecutableNetworkInternal> {
    using details::SOPointer<IExecutableNetworkInternal>::SOPointer;
-   friend class InferencePlugin;
+   // TODO: remove?
+   ExecutableNetwork(const details::SOPointer<IExecutableNetworkInternal> & obj) :
+       details::SOPointer<IExecutableNetworkInternal>::SOPointer(obj) { }
+   friend class Core;

public:
    /**
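The net effect of this hunk: the public ExecutableNetwork wrapper derives (protectedly) from details::SOPointer<IExecutableNetworkInternal>, and Core gains a constructor that builds it straight from the internal pointer. As a mental model only — a minimal sketch of the idea, not the actual OpenVINO implementation — an SOPointer-style handle pairs the plugin's shared library with an object created by it, so the library stays loaded while the object lives:

#include <memory>

// SoHandleSketch is a hypothetical stand-in for details::SOPointer<T>.
template <typename T>
class SoHandleSketch {
public:
    SoHandleSketch() = default;
    SoHandleSketch(std::shared_ptr<void> library, std::shared_ptr<T> object)
        : _so(std::move(library)), _ptr(std::move(object)) {}

    // Pointer-like access is why call sites in this commit change from '.' to '->'.
    T* operator->() const { return _ptr.get(); }
    explicit operator bool() const { return _ptr != nullptr; }

private:
    // Declaration order matters: members are destroyed in reverse order,
    // so _ptr (the plugin object) is released before _so (the library handle).
    std::shared_ptr<void> _so;
    std::shared_ptr<T> _ptr;
};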
23 changes: 12 additions & 11 deletions inference-engine/src/auto_plugin/auto_exec_network.cpp
@@ -13,11 +13,12 @@
#include "auto_infer_request.hpp"

namespace AutoPlugin {
using namespace InferenceEngine;

- AutoExecutableNetwork::AutoExecutableNetwork(const ExecutableNetwork& network,
-                                              const DeviceInformation& deviceInfo,
-                                              const bool needPerfCounters) :
+ AutoExecutableNetwork::AutoExecutableNetwork(const SoExecutableNetworkInternal& network,
+                                              const DeviceInformation& deviceInfo,
+                                              const bool needPerfCounters) :
    _deviceInfo(deviceInfo),
    _network(network),
    _config(deviceInfo.config.begin(), deviceInfo.config.end()),
@@ -28,32 +29,32 @@ AutoExecutableNetwork::~AutoExecutableNetwork() = default;

IInferRequestInternal::Ptr AutoExecutableNetwork::CreateInferRequestImpl(InputsDataMap networkInputs,
                                                                         OutputsDataMap networkOutputs) {
-   auto inferRequest = _network.CreateInferRequest();
+   auto inferRequest = _network->CreateInferRequest();
    return std::make_shared<AutoInferRequest>(networkInputs, networkOutputs, inferRequest);
}

void AutoExecutableNetwork::Export(std::ostream& networkModel) {
-   _network.Export(networkModel);
+   _network->Export(networkModel);
}

RemoteContext::Ptr AutoExecutableNetwork::GetContext() const {
-   return _network.GetContext();
+   return _network->GetContext();
}

InferenceEngine::CNNNetwork AutoExecutableNetwork::GetExecGraphInfo() {
-   return _network.GetExecGraphInfo();
+   return _network->GetExecGraphInfo();
}

Parameter AutoExecutableNetwork::GetMetric(const std::string &name) const {
-   return _network.GetMetric(name);
+   return _network->GetMetric(name);
}

void AutoExecutableNetwork::SetConfig(const std::map<std::string, Parameter>& config) {
-   _network.SetConfig(config);
+   _network->SetConfig(config);
}

Parameter AutoExecutableNetwork::GetConfig(const std::string& name) const {
-   return _network.GetConfig(name);
+   return _network->GetConfig(name);
}

} // namespace AutoPlugin
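All the edits in this file are the same mechanical change: _network used to be the value-semantics InferenceEngine::ExecutableNetwork wrapper, and is now a pointer-like handle to the internal interface, so every delegating call switches from '.' to '->'. A self-contained sketch of the pattern, with Mini* names that are illustrative rather than the real Inference Engine types:

#include <iostream>
#include <memory>
#include <ostream>

// Stand-in for IExecutableNetworkInternal: the internal interface.
struct IMiniExecNetwork {
    virtual ~IMiniExecNetwork() = default;
    virtual void Export(std::ostream& os) = 0;
};

// Stand-in for a device plugin's implementation.
struct MiniCpuExecNetwork : IMiniExecNetwork {
    void Export(std::ostream& os) override { os << "serialized blob\n"; }
};

// Stand-in for AutoExecutableNetwork: owns a pointer-like handle and forwards.
class MiniAutoExecNetwork {
public:
    explicit MiniAutoExecNetwork(std::shared_ptr<IMiniExecNetwork> network)
        : _network(std::move(network)) {}
    void Export(std::ostream& os) { _network->Export(os); }  // '->', not '.'
private:
    std::shared_ptr<IMiniExecNetwork> _network;
};

int main() {
    MiniAutoExecNetwork network(std::make_shared<MiniCpuExecNetwork>());
    network.Export(std::cout);  // prints "serialized blob"
}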
8 changes: 4 additions & 4 deletions inference-engine/src/auto_plugin/auto_exec_network.hpp
@@ -28,9 +28,9 @@ class AutoExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSaf
public:
    using Ptr = std::shared_ptr<AutoExecutableNetwork>;

-   AutoExecutableNetwork(const InferenceEngine::ExecutableNetwork& network,
-                         const DeviceInformation& deviceInfo,
-                         const bool needPerfCounters = false);
+   AutoExecutableNetwork(const InferenceEngine::SoExecutableNetworkInternal& network,
+                         const DeviceInformation& deviceInfo,
+                         const bool needPerfCounters = false);

    void Export(std::ostream& networkModel) override;
    InferenceEngine::RemoteContext::Ptr GetContext() const override;
@@ -43,7 +43,7 @@ class AutoExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSaf
    ~AutoExecutableNetwork() override;

    DeviceInformation _deviceInfo;
-   InferenceEngine::ExecutableNetwork _network;
+   InferenceEngine::SoExecutableNetworkInternal _network;
    std::unordered_map<std::string, InferenceEngine::Parameter> _config;
    bool _needPerfCounters = false;
};
16 changes: 8 additions & 8 deletions inference-engine/src/auto_plugin/auto_infer_request.cpp
@@ -9,31 +9,31 @@
namespace AutoPlugin {
using namespace InferenceEngine;

- AutoInferRequest::AutoInferRequest(const InputsDataMap& networkInputs,
-                                    const OutputsDataMap& networkOutputs,
-                                    const InferRequest& inferRequest)
+ AutoInferRequest::AutoInferRequest(const InputsDataMap& networkInputs,
+                                    const OutputsDataMap& networkOutputs,
+                                    const IInferRequestInternal::Ptr& inferRequest)
    : IInferRequestInternal(networkInputs, networkOutputs)
    , _inferRequest(inferRequest) {
}

std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> AutoInferRequest::GetPerformanceCounts() const {
-   return _inferRequest.GetPerformanceCounts();
+   return _inferRequest->GetPerformanceCounts();
}

void AutoInferRequest::InferImpl() {
-   _inferRequest.Infer();
+   _inferRequest->Infer();
}

void AutoInferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& data) {
-   _inferRequest.SetBlob(name, data);
+   _inferRequest->SetBlob(name, data);
}

Blob::Ptr AutoInferRequest::GetBlob(const std::string& name) {
-   return _inferRequest.GetBlob(name);
+   return _inferRequest->GetBlob(name);
}

void AutoInferRequest::Cancel() {
-   _inferRequest.Cancel();
+   _inferRequest->Cancel();
}

} // namespace AutoPlugin
8 changes: 4 additions & 4 deletions inference-engine/src/auto_plugin/auto_infer_request.hpp
@@ -24,17 +24,17 @@ namespace AutoPlugin {
class AutoInferRequest : public InferenceEngine::IInferRequestInternal {
public:
    using Ptr = std::shared_ptr<AutoInferRequest>;
-   explicit AutoInferRequest(const InferenceEngine::InputsDataMap& networkInputs,
-                             const InferenceEngine::OutputsDataMap& networkOutputs,
-                             const InferenceEngine::InferRequest& inferRequest);
+   explicit AutoInferRequest(const InferenceEngine::InputsDataMap& networkInputs,
+                             const InferenceEngine::OutputsDataMap& networkOutputs,
+                             const InferenceEngine::IInferRequestInternal::Ptr& inferRequest);
    std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
    void InferImpl() override;
    void SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& data) override;
    InferenceEngine::Blob::Ptr GetBlob(const std::string& name) override;
    void Cancel() override;

private:
-   InferenceEngine::InferRequest _inferRequest;
+   InferenceEngine::IInferRequestInternal::Ptr _inferRequest;
};

} // namespace AutoPlugin
8 changes: 4 additions & 4 deletions inference-engine/src/auto_plugin/auto_plugin.cpp
@@ -49,7 +49,7 @@ IE::IExecutableNetworkInternal::Ptr AutoInferencePlugin::LoadNetwork(const std::

    // FIXME: always select CPU device now
    DeviceInformation selectedDevice = SelectDevice(metaDevices);
-   IE::ExecutableNetwork executableNetwork;
+   IE::SoExecutableNetworkInternal executableNetwork;
    try {
        executableNetwork = GetCore()->LoadNetwork(fileName, selectedDevice.deviceName, selectedDevice.config);
    } catch(const IE::Exception &iie) {
@@ -60,7 +60,7 @@
    bool enablePerfCounters = false;
    try {
        enablePerfCounters =
-           executableNetwork.GetConfig(IE::PluginConfigParams::KEY_PERF_COUNT).as<std::string>() ==
+           executableNetwork->GetConfig(IE::PluginConfigParams::KEY_PERF_COUNT).as<std::string>() ==
            IE::PluginConfigParams::YES;
    } catch (...) {
    }
@@ -85,7 +85,7 @@ IE::ExecutableNetworkInternal::Ptr AutoInferencePlugin::LoadExeNetworkImpl(const

    // FIXME: always select CPU device now
    DeviceInformation selectedDevice = SelectDevice(metaDevices);
-   IE::ExecutableNetwork executableNetwork;
+   IE::SoExecutableNetworkInternal executableNetwork;
    try {
        executableNetwork = GetCore()->LoadNetwork(network, selectedDevice.deviceName, selectedDevice.config);
    } catch(const IE::Exception &iie) {
@@ -96,7 +96,7 @@ IE::ExecutableNetworkInternal::Ptr AutoInferencePlugin::LoadExeNetworkImpl(const
    bool enablePerfCounters = false;
    try {
        enablePerfCounters =
-           executableNetwork.GetConfig(IE::PluginConfigParams::KEY_PERF_COUNT).as<std::string>() ==
+           executableNetwork->GetConfig(IE::PluginConfigParams::KEY_PERF_COUNT).as<std::string>() ==
            IE::PluginConfigParams::YES;
    } catch (...) {
    }
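Both load paths use the same idiom around KEY_PERF_COUNT: query the config key, and treat a thrown exception as "not supported, so no performance counters". A minimal sketch of that idiom, assuming a simplified GetConfig that throws on unknown keys (the Mini* names are illustrative, not Inference Engine types):

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Simplified stand-in for an executable network's config interface.
struct MiniNetwork {
    std::map<std::string, std::string> config;
    std::string GetConfig(const std::string& key) const {
        auto it = config.find(key);
        if (it == config.end())
            throw std::out_of_range("unsupported config key: " + key);
        return it->second;
    }
};

bool PerfCountersEnabled(const MiniNetwork& net) {
    bool enabled = false;
    try {
        enabled = net.GetConfig("PERF_COUNT") == "YES";
    } catch (...) {
        // An unsupported key simply means perf counters stay disabled.
    }
    return enabled;
}

int main() {
    MiniNetwork withCounters{{{"PERF_COUNT", "YES"}}};
    MiniNetwork withoutCounters{};
    std::cout << PerfCountersEnabled(withCounters) << "\n";     // 1
    std::cout << PerfCountersEnabled(withoutCounters) << "\n";  // 0
}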
28 changes: 12 additions & 16 deletions inference-engine/src/hetero_plugin/hetero_async_infer_request.cpp
@@ -13,35 +13,31 @@ HeteroAsyncInferRequest::HeteroAsyncInferRequest(const IInferRequestInternal::Pt
                                                 const ITaskExecutor::Ptr& taskExecutor,
                                                 const ITaskExecutor::Ptr& callbackExecutor) :
    AsyncInferRequestThreadSafeDefault(request, taskExecutor, callbackExecutor),
-   _heteroInferRequest(std::static_pointer_cast<HeteroInferRequest>(request)),
-   _statusCodes{_heteroInferRequest->_inferRequests.size(), StatusCode::OK} {
+   _heteroInferRequest(std::static_pointer_cast<HeteroInferRequest>(request)) {
    _pipeline.clear();
    for (std::size_t requestId = 0; requestId < _heteroInferRequest->_inferRequests.size(); ++requestId) {
        struct RequestExecutor : ITaskExecutor {
-           explicit RequestExecutor(InferRequest & inferRequest) : _inferRequest(inferRequest) {
-               _inferRequest.SetCompletionCallback<std::function<void(InferRequest, StatusCode)>>(
-                   [this] (InferRequest, StatusCode sts) mutable {
-                       _status = sts;
+           explicit RequestExecutor(IInferRequestInternal::Ptr & inferRequest) : _inferRequest(inferRequest) {
+               _inferRequest->SetCallback(
+                   [this] (std::exception_ptr exceptionPtr) mutable {
+                       _exceptionPtr = exceptionPtr;
                        auto capturedTask = std::move(_task);
                        capturedTask();
                    });
            }
            void run(Task task) override {
                _task = std::move(task);
-               _inferRequest.StartAsync();
+               _inferRequest->StartAsync();
            };
-           InferRequest & _inferRequest;
-           StatusCode _status = StatusCode::OK;
-           Task _task;
+           IInferRequestInternal::Ptr & _inferRequest;
+           std::exception_ptr _exceptionPtr;
+           Task _task;
        };

        auto requestExecutor = std::make_shared<RequestExecutor>(_heteroInferRequest->_inferRequests[requestId]._request);
        _pipeline.emplace_back(requestExecutor, [requestExecutor] {
-           if (StatusCode::OK != requestExecutor->_status) {
-               IE_EXCEPTION_SWITCH(requestExecutor->_status, ExceptionType,
-                   InferenceEngine::details::ThrowNow<ExceptionType>{}
-                       <<= std::stringstream{} << IE_LOCATION
-                       << InferenceEngine::details::ExceptionTraits<ExceptionType>::string());
+           if (nullptr != requestExecutor->_exceptionPtr) {
+               std::rethrow_exception(requestExecutor->_exceptionPtr);
            }
        });
    }
@@ -58,7 +54,7 @@ StatusCode HeteroAsyncInferRequest::Wait(int64_t millis_timeout) {
        waitStatus = AsyncInferRequestThreadSafeDefault::Wait(millis_timeout);
    } catch(...) {
        for (auto&& requestDesc : _heteroInferRequest->_inferRequests) {
-           requestDesc._request.Wait(InferRequest::RESULT_READY);
+           requestDesc._request->Wait(InferRequest::RESULT_READY);
        }
        throw;
    }
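The substantive change here is error propagation: instead of recording a StatusCode and reconstructing an exception via IE_EXCEPTION_SWITCH, each pipeline stage now captures a std::exception_ptr in its completion callback and rethrows it on the waiting side, preserving the original exception type. A minimal, self-contained sketch of that handoff (names are illustrative):

#include <exception>
#include <iostream>
#include <stdexcept>

struct MiniStage {
    std::exception_ptr _exceptionPtr;

    // Plays the role of the completion callback: capture, don't translate.
    void OnCompletion() {
        try {
            throw std::runtime_error("device failure");  // pretend the device errored
        } catch (...) {
            _exceptionPtr = std::current_exception();
        }
    }

    // Plays the role of the pipeline check: rethrow with the type intact.
    void CheckResult() const {
        if (nullptr != _exceptionPtr) {
            std::rethrow_exception(_exceptionPtr);
        }
    }
};

int main() {
    MiniStage stage;
    stage.OnCompletion();
    try {
        stage.CheckResult();
    } catch (const std::runtime_error& e) {
        std::cout << "caught: " << e.what() << "\n";  // original type preserved
    }
}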
1 change: 0 additions & 1 deletion inference-engine/src/hetero_plugin/hetero_async_infer_request.hpp
@@ -23,7 +23,6 @@ class HeteroAsyncInferRequest : public InferenceEngine::AsyncInferRequestThreadS

private:
    HeteroInferRequest::Ptr _heteroInferRequest;
-   std::vector<InferenceEngine::StatusCode> _statusCodes;
};

} // namespace HeteroPlugin