Skip to content

Commit

Permalink
Refactoring code
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Feb 15, 2021
1 parent 18c861d commit 40acc52
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ void MKLDNNStridedSliceNode::getSupportedDescriptors() {
if (stridedSliceLayer->insData[DATA_ID].lock() == nullptr)
THROW_ERROR << "has nullable input data";

SizeVector srcDims = stridedSliceLayer->insData[DATA_ID].lock()->getTensorDesc().getDims();
SizeVector dstDims = stridedSliceLayer->outData[0]->getTensorDesc().getDims();
const SizeVector srcDims = stridedSliceLayer->insData[DATA_ID].lock()->getTensorDesc().getDims();
const SizeVector dstDims = stridedSliceLayer->outData[0]->getTensorDesc().getDims();

if (getParentEdges().size() < 3)
THROW_ERROR << "has incorrect number of input edges";
Expand Down Expand Up @@ -95,12 +95,12 @@ void MKLDNNStridedSliceNode::getSupportedDescriptors() {
int newAxis = 0;
for (auto na : newAxisMask)
newAxis += na;
params.maxDims = srcDims.size() + newAxis;
size_t maxDims = srcDims.size() + newAxis;

int shrinkAxis = 0;
for (auto sa : shrinkAxisMask)
shrinkAxis += sa;
params.equalDims = srcDims.size() == params.maxDims && shrinkAxis == 0;
params.equalDims = srcDims.size() == maxDims && shrinkAxis == 0;

auto fillingInParameters = [&](std::vector<int>& parameter, const size_t type, const size_t size, const int bit) {
auto blob = dynamic_cast<TBlob<int>*>(getParentEdgesAtPort(type)[0]->getParent()->getCnnLayer()->blobs["custom"].get());
Expand Down Expand Up @@ -232,9 +232,9 @@ void MKLDNNStridedSliceNode::createPrimitive() {
params.dstStrides = dstBlockingDesc.getStrides();
params.dataSize = getSelectedPrimitiveDescriptor()->getConfig().inConfs[DATA_ID].desc.getPrecision().size();

const bool isBlockedLayout = getParentEdgeAt(DATA_ID)->getMemory().GetDesc().isBlockedCFormat();
const bool isPerChannelLayout = getParentEdgeAt(DATA_ID)->getMemory().GetDesc().isTailCFormat();
params.maxDims += static_cast<size_t>(isBlockedLayout);
const bool isBlockedLayout = getParentEdgeAt(DATA_ID)->getMemory().GetDesc().isBlockedCFormat();
size_t realNDims = params.dstDims.size();

if (isBlockedLayout) {
const size_t blk = params.srcDims.back();
Expand Down Expand Up @@ -390,7 +390,6 @@ void MKLDNNStridedSliceNode::createPrimitive() {
params.srcStrides[indexes[idx - 1]] /= params.srcDims[jdx];

begin[indexes[idx - 1]] *= params.dstDims[jdx];
end[indexes[idx - 1]] *= params.dstDims[jdx];
}
const size_t beginShift = indexes[idx - 1] + 1;
const size_t endShift = indexes[idx] + 1;
Expand All @@ -401,7 +400,6 @@ void MKLDNNStridedSliceNode::createPrimitive() {
params.srcStrides.erase(params.srcStrides.begin() + beginShift, params.srcStrides.begin() + endShift);

begin.erase(begin.begin() + beginShift, begin.begin() + endShift);
end.erase(end.begin() + beginShift, end.begin() + endShift);
stride.erase(stride.begin() + beginShift, stride.begin() + endShift);
}
}
Expand All @@ -410,17 +408,15 @@ void MKLDNNStridedSliceNode::createPrimitive() {
params.lastDstDim = nGluingLastDims * params.dataSize;
params.nDimsForWork = params.dstDims.size() - static_cast<size_t>(vLastDim);

if (params.nDimsForWork == 1 && params.maxDims > 2) {
if (params.nDimsForWork == 1 && realNDims > 2) {
const size_t realSrcDim = newSrcDims[secondDim.first];
const size_t realDstDim = newDstDims[secondDim.first];

params.dstStrides.insert(params.dstStrides.begin() + 1, params.dstStrides[0] / realDstDim);
params.srcStrides.insert(params.srcStrides.begin() + 1, params.srcStrides[0] / realSrcDim);

for (size_t idx = secondDim.first + 1; idx < secondDim.second; idx++) {
for (size_t idx = secondDim.first + 1; idx < secondDim.second; idx++)
begin[1] /= newDstDims[idx];
end[1] /= newDstDims[idx];
}

const size_t maxThreads = dnnl_get_max_threads();
if (params.dstDims[0] < maxThreads) {
Expand All @@ -429,10 +425,10 @@ void MKLDNNStridedSliceNode::createPrimitive() {
params.dstDims.insert(params.dstDims.begin() + 1, realDstDim);
params.srcDims.insert(params.srcDims.begin() + 1, realSrcDim);
}
}

if (params.nDimsForWork == 1 && params.dstDims.size() > 2)
params.lastDstDim /= newDstDims[secondDim.first];
if (params.dstDims.size() > 2)
params.lastDstDim /= newDstDims[secondDim.first];
}
}

void MKLDNNStridedSliceNode::execute(mkldnn::stream strm) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class MKLDNNStridedSliceNode : public MKLDNNNode {
InferenceEngine::SizeVector dstDims;
InferenceEngine::SizeVector srcStrides;
InferenceEngine::SizeVector dstStrides;
size_t maxDims;
size_t nDimsForWork;
size_t workAmount;
size_t lastDstDim;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,4 @@ INSTANTIATE_TEST_CASE_P(
::testing::Values(std::map<std::string, std::string>())),
StridedSliceLayerTest::getTestCaseName);

} // namespace
} // namespace

0 comments on commit 40acc52

Please sign in to comment.