Skip to content

Commit

Permalink
[CPU] Added separate inference for pc, splitted nodes (openvinotoolki…
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova authored and akuporos committed Sep 29, 2021
1 parent 3b87604 commit 70f727d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 16 deletions.
45 changes: 30 additions & 15 deletions inference-engine/src/mkldnn_plugin/mkldnn_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ void MKLDNNGraph::InitGraph() {
graphNode->cleanup();
}
#endif
ExtractConstantNodes();

ExecuteConstantNodesOnly();
}

Expand Down Expand Up @@ -390,6 +392,16 @@ void MKLDNNGraph::InitOptimalPrimitiveDescriptors() {
}
}

void MKLDNNGraph::ExtractConstantNodes() {
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::MKLDNN_LT, "MKLDNNGraph::ExtractConstantNodes");
for (auto& graphNode : graphNodes) {
if (graphNode->isConstant())
constantGraphNodes.emplace_back(graphNode);
else
mutableGraphNodes.emplace_back(graphNode);
}
}

void MKLDNNGraph::ExecuteConstantNodesOnly() {
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::MKLDNN_LT, "MKLDNNGraph::ExecuteConstantNodesOnly");
mkldnn::stream stream(eng);
Expand Down Expand Up @@ -418,10 +430,7 @@ void MKLDNNGraph::ExecuteConstantNodesOnly() {
return std::make_tuple(hasExternalInvalidEdges, hasLocalAllocatedEdges, outputs);
};

for (auto &graphNode : graphNodes) {
if (!graphNode->isConstant())
continue;

for (auto &graphNode : constantGraphNodes) {
if (weightsCache) {
auto sharedOutputs = acquireSharedOutputs(graphNode);

Expand Down Expand Up @@ -810,24 +819,30 @@ void MKLDNNGraph::Infer(MKLDNNInferRequest* request, int batch) {

ENABLE_CPU_DEBUG_CAP(NodeDumper nd(config.debugCaps, infer_count));

for (int i = 0; i < graphNodes.size(); i++) {
if (request != nullptr) {
#ifdef CPU_DEBUG_CAPS
for (const auto& node : constantGraphNodes) {
if (request != nullptr)
request->ThrowIfCanceled();
}

PERF(graphNodes[i]);
ENABLE_CPU_DEBUG_CAP(nd.dumpInputBlobs(node));
ENABLE_CPU_DEBUG_CAP(nd.dumpOutputBlobs(node));
}
#endif

for (const auto& node : mutableGraphNodes) {
PERF(config.collectPerfCounters, node);
if (request != nullptr)
request->ThrowIfCanceled();

if (batch > 0)
graphNodes[i]->setDynamicBatchLim(batch);
node->setDynamicBatchLim(batch);

ENABLE_CPU_DEBUG_CAP(nd.dumpInputBlobs(graphNodes[i]));
ENABLE_CPU_DEBUG_CAP(nd.dumpInputBlobs(node));

if (!graphNodes[i]->isConstant()) {
OV_ITT_SCOPED_TASK(itt::domains::MKLDNNPlugin, graphNodes[i]->profiling.execute);
graphNodes[i]->execute(stream);
}
OV_ITT_SCOPED_TASK(itt::domains::MKLDNNPlugin, node->profiling.execute);
node->execute(stream);

ENABLE_CPU_DEBUG_CAP(nd.dumpOutputBlobs(graphNodes[i]));
ENABLE_CPU_DEBUG_CAP(nd.dumpOutputBlobs(node));
}

if (infer_count != -1) infer_count++;
Expand Down
6 changes: 6 additions & 0 deletions inference-engine/src/mkldnn_plugin/mkldnn_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,19 @@ class MKLDNNGraph {
void Allocate();
void AllocateWithReuse();
void CreatePrimitives();
void ExtractConstantNodes();
void ExecuteConstantNodesOnly();

friend class MKLDNNInferRequest;
friend class MKLDNNGraphlessInferRequest;
friend InferenceEngine::CNNNetwork dump_graph_as_ie_ngraph_net(const MKLDNNGraph &graph);

private:
// these node pointers (from graphNodes) are to avoid regular checking for
// constant node in ExecuteConstantNodesOnly and Infer methods
std::vector<MKLDNNNodePtr> constantGraphNodes;
std::vector<MKLDNNNodePtr> mutableGraphNodes;

void EnforceBF16();
};

Expand Down
3 changes: 2 additions & 1 deletion inference-engine/src/mkldnn_plugin/perf_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ class PerfHelper {

} // namespace MKLDNNPlugin

#define PERF(_counter) PerfHelper __helper##__counter (_counter->PerfCounter());
#define GET_PERF(_counter) std::unique_ptr<PerfHelper>(new PerfHelper(_counter->PerfCounter()))
#define PERF(_need, _counter) auto pc = _need ? GET_PERF(_counter) : nullptr;

0 comments on commit 70f727d

Please sign in to comment.