diff --git a/inference-engine/src/mkldnn_plugin/nodes/common/tensor_desc_creator.cpp b/inference-engine/src/mkldnn_plugin/nodes/common/tensor_desc_creator.cpp index 57eb047f282317..9675dc4072e166 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/common/tensor_desc_creator.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/common/tensor_desc_creator.cpp @@ -9,57 +9,74 @@ using namespace InferenceEngine; using namespace MKLDNNPlugin; namespace { - constexpr size_t channelsPos = 1lu; -} - -InferenceEngine::TensorDesc PlainFormatCreator::createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const { - SizeVector order(srcDims.size()); - std::iota(order.begin(), order.end(), 0); - return TensorDesc(precision, srcDims, {srcDims, order}); -} - -InferenceEngine::TensorDesc PerChannelCreator::createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const { - SizeVector order(srcDims.size()); - std::iota(order.begin(), order.end(), 0); - SizeVector blkDims = srcDims; - if (srcDims.size() > 2) { - auto moveElementBack = [](SizeVector& vector, size_t indx) { - auto itr = vector.begin() + indx; - std::rotate(itr, itr + 1, vector.end()); - }; - - moveElementBack(order, channelsPos); - moveElementBack(blkDims, channelsPos); +constexpr size_t channelsPos = 1lu; + +class PlainFormatCreator : public TensorDescCreator { +public: + virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const { + SizeVector order(srcDims.size()); + std::iota(order.begin(), order.end(), 0); + return TensorDesc(precision, srcDims, {srcDims, order}); } + virtual size_t getMinimalRank() const { return 0lu; } +}; + +class PerChannelCreator : public TensorDescCreator { +public: + virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const { + SizeVector order(srcDims.size()); + std::iota(order.begin(), order.end(), 0); + SizeVector blkDims = srcDims; + if (srcDims.size() > 2) { + auto moveElementBack = [](SizeVector& vector, size_t indx) { + auto itr = vector.begin() + indx; + std::rotate(itr, itr + 1, vector.end()); + }; + + moveElementBack(order, channelsPos); + moveElementBack(blkDims, channelsPos); + } - return TensorDesc(precision, srcDims, {blkDims, order}); -} - -InferenceEngine::TensorDesc ChannelBlockedCreator::createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const { - if (srcDims.size() < 2) { - THROW_IE_EXCEPTION << "Can't create blocked tensor descriptor!"; + return TensorDesc(precision, srcDims, {blkDims, order}); } + virtual size_t getMinimalRank() const { return 3lu; } +}; + +class ChannelBlockedCreator : public TensorDescCreator { +public: + ChannelBlockedCreator(size_t blockSize) : _blockSize(blockSize) {} + virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const { + if (srcDims.size() < 2) { + THROW_IE_EXCEPTION << "Can't create blocked tensor descriptor!"; + } - SizeVector order(srcDims.size()); - std::iota(order.begin(), order.end(), 0); - order.push_back(channelsPos); + SizeVector order(srcDims.size()); + std::iota(order.begin(), order.end(), 0); + order.push_back(channelsPos); - SizeVector blkDims = srcDims; - blkDims[channelsPos] = blkDims[channelsPos] / _blockSize + (blkDims[channelsPos] % _blockSize ? 1 : 0); - blkDims.push_back(_blockSize); + SizeVector blkDims = srcDims; + blkDims[channelsPos] = blkDims[channelsPos] / _blockSize + (blkDims[channelsPos] % _blockSize ? 1 : 0); + blkDims.push_back(_blockSize); - return TensorDesc(precision, srcDims, {blkDims, order}); -} - -std::map TensorDescCreator::getCommonCreators() { - return { { TensorDescCreatorTypes::nspc, CreatorConstPtr(new PerChannelCreator) }, - { TensorDescCreatorTypes::nCsp8c, CreatorConstPtr(new ChannelBlockedCreator(8)) }, - { TensorDescCreatorTypes::nCsp16c, CreatorConstPtr(new ChannelBlockedCreator(16)) }, - { TensorDescCreatorTypes::ncsp, CreatorConstPtr(new PlainFormatCreator) } }; + return TensorDesc(precision, srcDims, {blkDims, order}); + } + virtual size_t getMinimalRank() const { return 2lu; } + +private: + size_t _blockSize; +}; +} // namespace + +const TensorDescCreator::CreatorsMap& TensorDescCreator::getCommonCreators() { + static const CreatorsMap map{ { TensorDescCreatorTypes::nspc, CreatorConstPtr(new PerChannelCreator) }, + { TensorDescCreatorTypes::nCsp8c, CreatorConstPtr(new ChannelBlockedCreator(8)) }, + { TensorDescCreatorTypes::nCsp16c, CreatorConstPtr(new ChannelBlockedCreator(16)) }, + { TensorDescCreatorTypes::ncsp, CreatorConstPtr(new PlainFormatCreator) } }; + return map; } -std::pair -TensorDescCreator::makeFilteredRange(TensorDescCreator::CreatorsMap &map, unsigned int rank) { +std::pair +TensorDescCreator::makeFilteredRange(const CreatorsMap &map, unsigned int rank) { auto rankFilter = [rank](const CreatorsMap::value_type& item) { if (item.second->getMinimalRank() > rank) { return false; @@ -67,15 +84,20 @@ TensorDescCreator::makeFilteredRange(TensorDescCreator::CreatorsMap &map, unsign return true; }; - auto first = CreatorsMapFilterIterator(std::move(rankFilter), map.begin(), map.end()); + auto first = CreatorsMapFilterConstIterator(std::move(rankFilter), map.begin(), map.end()); auto last = first.end(); return std::make_pair(first, last); } -std::pair -TensorDescCreator::makeFilteredRange(TensorDescCreator::CreatorsMap &map, unsigned int rank, std::set supportedTypes) { - auto rankTypesFilter = [rank, supportedTypes](const CreatorsMap::value_type& item) { - if (!supportedTypes.count(item.first)) { +std::pair +TensorDescCreator::makeFilteredRange(const CreatorsMap& map, unsigned rank, const std::vector& supportedTypes) { + size_t bitMask = 0ul; + for (auto& item : supportedTypes) { + bitMask |= 1 << static_cast(item); + } + + auto rankTypesFilter = [rank, bitMask](const CreatorsMap::value_type& item) { + if (!(bitMask & (1 << static_cast(item.first)))) { return false; } if (item.second->getMinimalRank() > rank) { @@ -84,7 +106,14 @@ TensorDescCreator::makeFilteredRange(TensorDescCreator::CreatorsMap &map, unsign return true; }; - auto first = CreatorsMapFilterIterator(std::move(rankTypesFilter), map.begin(), map.end()); + auto first = CreatorsMapFilterConstIterator(std::move(rankTypesFilter), map.begin(), map.end()); + auto last = first.end(); + return std::make_pair(first, last); +} + +std::pair +TensorDescCreator::makeFilteredRange(const CreatorsMap &map, TensorDescCreator::Predicate predicate) { + auto first = CreatorsMapFilterConstIterator(std::move(predicate), map.begin(), map.end()); auto last = first.end(); return std::make_pair(first, last); } diff --git a/inference-engine/src/mkldnn_plugin/nodes/common/tensor_desc_creator.h b/inference-engine/src/mkldnn_plugin/nodes/common/tensor_desc_creator.h index 513167607aa162..0707dc69d329f5 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/common/tensor_desc_creator.h +++ b/inference-engine/src/mkldnn_plugin/nodes/common/tensor_desc_creator.h @@ -5,61 +5,41 @@ #pragma once #include -#include namespace MKLDNNPlugin { -enum class TensorDescCreatorTypes { - nspc, - nCsp8c, - nCsp16c, - ncsp +enum class TensorDescCreatorTypes : unsigned { + nspc, // general per channels format + nCsp8c, // general channels blocked by 8 + nCsp16c, // general channels blocked by 16 + ncsp // general planar }; -class CreatorsMapFilterIterator; +class CreatorsMapFilterConstIterator; class TensorDescCreator { public: typedef std::shared_ptr CreatorPtr; typedef std::shared_ptr CreatorConstPtr; typedef std::map CreatorsMap; + typedef std::function Predicate; public: - static CreatorsMap getCommonCreators(); - static std::pair - makeFilteredRange(CreatorsMap& map, unsigned rank); - static std::pair - makeFilteredRange(CreatorsMap& map, unsigned rank, std::set supportedTypes); + static const CreatorsMap& getCommonCreators(); + static std::pair + makeFilteredRange(const CreatorsMap &map, unsigned rank); + static std::pair + makeFilteredRange(const CreatorsMap& map, unsigned rank, const std::vector& supportedTypes); + static std::pair + makeFilteredRange(const CreatorsMap& map, Predicate predicate); virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const = 0; virtual size_t getMinimalRank() const = 0; - virtual ~TensorDescCreator() {} + virtual ~TensorDescCreator() = default; }; -class PlainFormatCreator : public TensorDescCreator { +class CreatorsMapFilterConstIterator { public: - virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const; - virtual size_t getMinimalRank() const { return 0lu; } -}; - -class PerChannelCreator : public TensorDescCreator { -public: - virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const; - virtual size_t getMinimalRank() const { return 3lu; } -}; - -class ChannelBlockedCreator : public TensorDescCreator { -public: - ChannelBlockedCreator(size_t blockSize) : _blockSize(blockSize) {} - virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const; - virtual size_t getMinimalRank() const { return 2lu; } - -private: - size_t _blockSize; -}; - -class CreatorsMapFilterIterator { -public: - typedef TensorDescCreator::CreatorsMap::iterator Iterator; + typedef TensorDescCreator::CreatorsMap::const_iterator Iterator; typedef std::iterator_traits::value_type value_type; typedef std::iterator_traits::reference reference; typedef std::iterator_traits::pointer pointer; @@ -68,20 +48,24 @@ class CreatorsMapFilterIterator { typedef std::function predicate_type; public: - CreatorsMapFilterIterator(predicate_type filter, Iterator begin, Iterator end) : _filter(std::move(filter)), _iter(begin), _end(end) {} - CreatorsMapFilterIterator& operator++() { + CreatorsMapFilterConstIterator(predicate_type filter, Iterator begin, Iterator end) : _filter(std::move(filter)), _iter(begin), _end(end) { + while (_iter != _end && !_filter(*_iter)) { + ++_iter; + } + } + CreatorsMapFilterConstIterator& operator++() { do { ++_iter; } while (_iter != _end && !_filter(*_iter)); return *this; } - CreatorsMapFilterIterator end() const { - return CreatorsMapFilterIterator(predicate_type(), _end, _end); + CreatorsMapFilterConstIterator end() const { + return CreatorsMapFilterConstIterator(predicate_type(), _end, _end); } - CreatorsMapFilterIterator operator++(int) { - CreatorsMapFilterIterator temp(*this); + CreatorsMapFilterConstIterator operator++(int) { + CreatorsMapFilterConstIterator temp(*this); ++*this; return temp; } @@ -94,11 +78,11 @@ class CreatorsMapFilterIterator { return std::addressof(*_iter); } - friend bool operator==(const CreatorsMapFilterIterator& lhs, const CreatorsMapFilterIterator& rhs) { + friend bool operator==(const CreatorsMapFilterConstIterator& lhs, const CreatorsMapFilterConstIterator& rhs) { return lhs._iter == rhs._iter; } - friend bool operator!=(const CreatorsMapFilterIterator& lhs, const CreatorsMapFilterIterator& rhs) { + friend bool operator!=(const CreatorsMapFilterConstIterator& lhs, const CreatorsMapFilterConstIterator& rhs) { return !(lhs == rhs); }