-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The convert layer was fixed after review.
- Loading branch information
Showing
4 changed files
with
157 additions
and
35 deletions.
There are no files selected for viewing
59 changes: 59 additions & 0 deletions
59
inference-engine/src/mkldnn_plugin/nodes/common/tensor_desc_creator.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
// Copyright (C) 2020 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "tensor_desc_creator.h" | ||
#include <numeric> | ||
|
||
using namespace InferenceEngine; | ||
using namespace MKLDNNPlugin; | ||
|
||
namespace { | ||
constexpr size_t channelsPos = 1lu; | ||
} | ||
|
||
InferenceEngine::TensorDesc PlainFormatCreator::createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const { | ||
SizeVector order(srcDims.size()); | ||
std::iota(order.begin(), order.end(), 0); | ||
return TensorDesc(precision, srcDims, {srcDims, order}); | ||
} | ||
|
||
InferenceEngine::TensorDesc PerChannelCreator::createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const { | ||
SizeVector order(srcDims.size()); | ||
std::iota(order.begin(), order.end(), 0); | ||
SizeVector blkDims = srcDims; | ||
if (srcDims.size() > 2) { | ||
auto moveElementBack = [](SizeVector& vector, size_t indx) { | ||
auto itr = vector.begin() + indx; | ||
std::rotate(itr, itr + 1, vector.end()); | ||
}; | ||
|
||
moveElementBack(order, channelsPos); | ||
moveElementBack(blkDims, channelsPos); | ||
} | ||
|
||
return TensorDesc(precision, srcDims, {blkDims, order}); | ||
} | ||
|
||
InferenceEngine::TensorDesc ChannelBlockedCreator::createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const { | ||
if (srcDims.size() < 2) { | ||
THROW_IE_EXCEPTION << "Can't create blocked tensor descriptor!"; | ||
} | ||
|
||
SizeVector order(srcDims.size()); | ||
std::iota(order.begin(), order.end(), 0); | ||
order.push_back(channelsPos); | ||
|
||
SizeVector blkDims = srcDims; | ||
blkDims[channelsPos] = blkDims[channelsPos] / _blockSize + (blkDims[channelsPos] % _blockSize ? 1 : 0); | ||
blkDims.push_back(_blockSize); | ||
|
||
return TensorDesc(precision, srcDims, {blkDims, order}); | ||
} | ||
|
||
std::map<TensorDescCreatorTypes, TensorDescCreator::CreatorConstPtr> TensorDescCreator::getCommonCreators() { | ||
return { { TensorDescCreatorTypes::plain, CreatorConstPtr(new PlainFormatCreator) }, | ||
{ TensorDescCreatorTypes::perChannel, CreatorConstPtr(new PerChannelCreator) }, | ||
{ TensorDescCreatorTypes::channelBlocked8, CreatorConstPtr(new ChannelBlockedCreator(8)) }, | ||
{ TensorDescCreatorTypes::channelBlocked16, CreatorConstPtr(new ChannelBlockedCreator(16)) } }; | ||
} |
47 changes: 47 additions & 0 deletions
47
inference-engine/src/mkldnn_plugin/nodes/common/tensor_desc_creator.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// Copyright (C) 2020 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <ie_layouts.h> | ||
|
||
namespace MKLDNNPlugin { | ||
|
||
enum class TensorDescCreatorTypes { | ||
plain, | ||
perChannel, | ||
channelBlocked8, | ||
channelBlocked16 | ||
}; | ||
|
||
class TensorDescCreator { | ||
public: | ||
typedef std::shared_ptr<TensorDescCreator> CreatorPtr; | ||
typedef std::shared_ptr<const TensorDescCreator> CreatorConstPtr; | ||
|
||
public: | ||
static std::map<TensorDescCreatorTypes, CreatorConstPtr> getCommonCreators(); | ||
virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const = 0; | ||
virtual ~TensorDescCreator() {} | ||
}; | ||
|
||
class PlainFormatCreator : public TensorDescCreator { | ||
public: | ||
virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const; | ||
}; | ||
|
||
class PerChannelCreator : public TensorDescCreator { | ||
public: | ||
virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const; | ||
}; | ||
|
||
class ChannelBlockedCreator : public TensorDescCreator { | ||
public: | ||
ChannelBlockedCreator(size_t blockSize) : _blockSize(blockSize) {} | ||
virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const; | ||
|
||
private: | ||
size_t _blockSize; | ||
}; | ||
} // namespace MKLDNNPlugin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters