Skip to content

Commit

Permalink
[CPU] Scatter nodes migration on nGraph (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxim Andronov committed May 3, 2021
1 parent 6cfba2b commit 7f18a71
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 54 deletions.
2 changes: 1 addition & 1 deletion inference-engine/src/mkldnn_plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ set(LAYERS
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_tile_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_mvn_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_normalize_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_scatter_update_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_scatter_update_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_interpolate_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_reduce_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_reference_node.cpp
Expand Down
8 changes: 4 additions & 4 deletions inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ static const InferenceEngine::details::caseless_unordered_map<std::string, Type>
// { "Memory", MemoryOutput }, // for construction from layer ctor
// { "Convert", Convert },
{ "MVN", MVN},
{ "NormalizeL2", NormalizeL2},
// { "ScatterUpdate", ScatterUpdate},
// { "ScatterElementsUpdate", ScatterElementsUpdate},
// { "ScatterNDUpdate", ScatterNDUpdate},
{ "NormalizeL2", NormalizeL2},
{ "ScatterUpdate", ScatterUpdate},
{ "ScatterElementsUpdate", ScatterElementsUpdate},
{ "ScatterNDUpdate", ScatterNDUpdate},
// { "Interpolate", Interpolate},
// { "ReduceAnd", ReduceAnd},
// { "ReduceL1", ReduceL1},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,59 @@
//

#include "mkldnn_scatter_update_node.h"
#include <legacy/ie_layers.h>
#include <mkldnn.hpp>
#include <string>
#include <vector>
#include <mkldnn_types.h>
#include <mkldnn_extension_utils.h>
#include <legacy/ie_layers_internal.hpp>
#include "ie_parallel.hpp"
#include <algorithm>
#include "common/cpu_memcpy.h"

#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>

using namespace mkldnn;
using namespace MKLDNNPlugin;
using namespace InferenceEngine;

MKLDNNScatterUpdateNode::MKLDNNScatterUpdateNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache)
: MKLDNNNode(layer, eng, cache), dataSize(0lu), indicesSize(0lu), axisSize(0lu),
dataPrec(Precision::UNSPECIFIED), indicesPrec(Precision::UNSPECIFIED), axisPrec(Precision::UNSPECIFIED) {}
bool MKLDNNScatterUpdateNode::isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
const auto scatterElemUpd = std::dynamic_pointer_cast<const ngraph::opset3::ScatterElementsUpdate>(op);
const auto scatterUpd = std::dynamic_pointer_cast<const ngraph::opset3::ScatterUpdate>(op);
const auto scatterNdUpd = std::dynamic_pointer_cast<const ngraph::opset4::ScatterNDUpdate>(op);
if (scatterElemUpd == nullptr && scatterUpd == nullptr && scatterNdUpd == nullptr) {
const std::string opType = op->get_type_name();
errorMessage = "Only opset" + opType == "ScatterNDUpdate" ? "4 " : "3 " + opType + " operation is supported";
return false;
}
} catch (...) {
return false;
}
return true;
}

void MKLDNNScatterUpdateNode::getSupportedDescriptors() {
if (!descs.empty())
return;
MKLDNNScatterUpdateNode::MKLDNNScatterUpdateNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache)
: MKLDNNNode(op, eng, cache), dataSize(0lu), indicesSize(0lu), axisSize(0lu), dataPrec(Precision::UNSPECIFIED), indicesPrec(Precision::UNSPECIFIED),
axisPrec(Precision::UNSPECIFIED) {
std::string errorMessage;
if (isSupportedOperation(op, errorMessage)) {
errorPrefix = std::string(op->get_type_name()) + " node with name '" + getName() + "'";
} else {
IE_THROW(NotImplemented) << errorMessage;
}
}

void MKLDNNScatterUpdateNode::getSupportedDescriptors() {
if ((getParentEdges().size() != 3) && (getParentEdges().size() != 4))
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' has incorrect number of input edges";
IE_THROW() << errorPrefix << " has incorrect number of input edges";
if (getChildEdges().empty())
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' has incorrect number of output edges";
IE_THROW() << errorPrefix << " has incorrect number of output edges";

if (getParentEdgeAt(DATA_ID)->getDims().ndims() < 1 ||
getParentEdgeAt(INDICES_ID)->getDims().ndims() < 1 ||
getParentEdgeAt(UPDATE_ID)->getDims().ndims() < 1) {
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not support scalar input";
IE_THROW() << errorPrefix << " do not support scalar input";
}

Type scatterUpdateType = getType();
Expand All @@ -51,8 +69,7 @@ void MKLDNNScatterUpdateNode::getSupportedDescriptors() {
scatterUpdateMode = ScatterUpdateMode::ScatterNDUpdate;
axisRelaxed = false;
} else {
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' is not supported";
IE_THROW() << errorPrefix << " is not supported";
}
}

Expand All @@ -72,31 +89,28 @@ void MKLDNNScatterUpdateNode::initSupportedPrimitiveDescriptors() {

// common check
if (srcRank != dstRank) {
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' should have same rank for input and outpt tensor";
IE_THROW() << errorPrefix << " should have same rank for input and output tensor";
} else {
for (size_t r = 0; r < srcRank; r++) {
if (srcDataDim[r] != dstDataDim[r]) {
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' should have same shape for input and outpt tensor." << " The input shape is "
<< srcDataDim[r] << ", while output shape is " << dstDataDim[r] << "for" << r << "th dimension";
IE_THROW() << errorPrefix << " should have same shape for input and output tensor. The input shape is "
<< srcDataDim[r] << ", while output shape is " << dstDataDim[r] << " for " << r << "th dimension";
}
}
}
// specific check
switch (scatterUpdateMode) {
case ScatterUpdateMode::ScatterUpdate: {
if (updateRank != (srcRank + indicesRank - 1)) {
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have matched tensor rank relationship for input, indices and update";
IE_THROW() << errorPrefix << " do not have matched tensor rank relationship for input, indices and update";
}
break;
}
case ScatterUpdateMode::ScatterNDUpdate: {
size_t k = indicesDim[indicesRank - 1];
if (k > srcRank) {
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have an correct indices' last dimension value, which should be smaller than or equal to input tensor rank";
IE_THROW() << errorPrefix << "' do not have an correct indices' last dimension value, "
<< "which should be smaller than or equal to input tensor rank";
}

SizeVector expectUpdateShape = {};
Expand All @@ -108,37 +122,32 @@ void MKLDNNScatterUpdateNode::initSupportedPrimitiveDescriptors() {
expectUpdateShape.push_back(srcDataDim[rd]);
}
if (expectUpdateShape.size() != updateRank) {
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have matched tensor rank relationship for input, indices and update";
IE_THROW() << errorPrefix << " do not have matched tensor rank relationship for input, indices and update";
}
for (size_t ru = 0; ru < updateRank; ru++) {
if (updateDim[ru] != expectUpdateShape[ru]) {
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have matched tensor shape relationship for input, indices and update";
IE_THROW() << errorPrefix << " do not have matched tensor shape relationship for input, indices and update";
}
}
break;
}
case ScatterUpdateMode::ScatterElementsUpdate: {
if (srcRank != indicesRank || srcRank != updateRank) {
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have the same tensor rank for input, indices and update";
IE_THROW() << errorPrefix << " do not have the same tensor rank for input, indices and update";
}
for (size_t ri = 0; ri < indicesRank; ri++) {
if (indicesDim[ri] != updateDim[ri]) {
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have the same tensor shape for indices and update";
IE_THROW() << errorPrefix << " do not have the same tensor shape for indices and update";
}
}
break;
}
default: {
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' is not supported";
IE_THROW() << errorPrefix << " is not supported";
}
}

indicesPrec = getCnnLayer()->insData[INDICES_ID].lock()->getPrecision();
indicesPrec = getOriginalInputPrecisions()[INDICES_ID];
auto indicesType = MKLDNNExtensionUtils::IEPrecisionToDataType(indicesPrec);
indicesSize = MKLDNNExtensionUtils::sizeOfDataType(indicesType);
if (indicesSize >= 8) {
Expand All @@ -151,7 +160,7 @@ void MKLDNNScatterUpdateNode::initSupportedPrimitiveDescriptors() {
indicesType = MKLDNNExtensionUtils::IEPrecisionToDataType(indicesPrec);

if (axisRelaxed) {
axisPrec = getCnnLayer()->insData[AXIS_ID].lock()->getPrecision();
axisPrec = getOriginalInputPrecisions()[AXIS_ID];
auto axisType = MKLDNNExtensionUtils::IEPrecisionToDataType(axisPrec);
axisSize = MKLDNNExtensionUtils::sizeOfDataType(axisType);
if (axisSize >= 8) {
Expand All @@ -163,7 +172,7 @@ void MKLDNNScatterUpdateNode::initSupportedPrimitiveDescriptors() {
}
}

dataPrec = getCnnLayer()->insData[DATA_ID].lock()->getPrecision();
dataPrec = getOriginalInputPrecisions()[DATA_ID];
auto dataType = MKLDNNExtensionUtils::IEPrecisionToDataType(dataPrec);
dataSize = MKLDNNExtensionUtils::sizeOfDataType(dataType);

Expand Down Expand Up @@ -215,20 +224,15 @@ void MKLDNNScatterUpdateNode::createPrimitive() {
auto &updateMemPtr = getParentEdgeAt(UPDATE_ID)->getMemoryPtr();

if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' did not allocate destination memory";
IE_THROW() << errorPrefix << " did not allocate destination memory";
if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' did not allocate input memory";
IE_THROW() << errorPrefix << " did not allocate input memory";
if (!indicesMemPtr || !indicesMemPtr->GetPrimitivePtr())
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' did not allocate indices memory";
IE_THROW() << errorPrefix << " did not allocate indices memory";
if (!updateMemPtr || !updateMemPtr->GetPrimitivePtr())
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' did not allocate update memory";
IE_THROW() << errorPrefix << " did not allocate update memory";
if (getSelectedPrimitiveDescriptor() == nullptr)
IE_THROW() << "'" << getType() << "'" << " layer with name '" << getName()
<< "' did not set preferable primitive descriptor";
IE_THROW() << errorPrefix << " did not set preferable primitive descriptor";
}

int64_t MKLDNNScatterUpdateNode::getIndicesValue(uint8_t *indices, size_t offset) {
Expand Down Expand Up @@ -272,7 +276,6 @@ void MKLDNNScatterUpdateNode::execute(mkldnn::stream strm) {
SizeVector indicesDim = getParentEdgeAt(INDICES_ID)->getDesc().getDims();
size_t srcRank = srcDataDim.size();
int axis = 0;
std::string errorPrefix = std::string("'") + getTypeStr() + "'" + " layer with name '" + getName() + "'";
if (axisRelaxed) {
auto &axisMemPtr = getParentEdgeAt(AXIS_ID)->getMemoryPtr();
uint8_t *axisPtr = reinterpret_cast<uint8_t*>(axisMemPtr->GetData()) +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class MKLDNNScatterUpdateNode : public MKLDNNNode {
return false;
}

static bool isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept;

private:
void scatterUpdate(uint8_t *indicesPtr, uint8_t *updatePtr, int axis, uint8_t *dstDataPtr);
void scatterNDUpdate(uint8_t *indicesPtr, uint8_t *updatePtr, uint8_t *dstDataPtr);
Expand All @@ -48,6 +50,8 @@ class MKLDNNScatterUpdateNode : public MKLDNNNode {
bool axisRelaxed = false;
size_t dataSize, indicesSize, axisSize;
InferenceEngine::Precision dataPrec, indicesPrec, axisPrec;

std::string errorPrefix;
};

} // namespace MKLDNNPlugin

0 comments on commit 7f18a71

Please sign in to comment.