Skip to content

Commit

Permalink
Refactoring after offline discussion.
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Jan 19, 2021
1 parent dcdff34 commit 7b35e35
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,73 +9,95 @@ using namespace InferenceEngine;
using namespace MKLDNNPlugin;

namespace {
constexpr size_t channelsPos = 1lu;
}

InferenceEngine::TensorDesc PlainFormatCreator::createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const {
SizeVector order(srcDims.size());
std::iota(order.begin(), order.end(), 0);
return TensorDesc(precision, srcDims, {srcDims, order});
}

InferenceEngine::TensorDesc PerChannelCreator::createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const {
SizeVector order(srcDims.size());
std::iota(order.begin(), order.end(), 0);
SizeVector blkDims = srcDims;
if (srcDims.size() > 2) {
auto moveElementBack = [](SizeVector& vector, size_t indx) {
auto itr = vector.begin() + indx;
std::rotate(itr, itr + 1, vector.end());
};

moveElementBack(order, channelsPos);
moveElementBack(blkDims, channelsPos);
constexpr size_t channelsPos = 1lu;

class PlainFormatCreator : public TensorDescCreator {
public:
virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const {
SizeVector order(srcDims.size());
std::iota(order.begin(), order.end(), 0);
return TensorDesc(precision, srcDims, {srcDims, order});
}
virtual size_t getMinimalRank() const { return 0lu; }
};

class PerChannelCreator : public TensorDescCreator {
public:
virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const {
SizeVector order(srcDims.size());
std::iota(order.begin(), order.end(), 0);
SizeVector blkDims = srcDims;
if (srcDims.size() > 2) {
auto moveElementBack = [](SizeVector& vector, size_t indx) {
auto itr = vector.begin() + indx;
std::rotate(itr, itr + 1, vector.end());
};

moveElementBack(order, channelsPos);
moveElementBack(blkDims, channelsPos);
}

return TensorDesc(precision, srcDims, {blkDims, order});
}

InferenceEngine::TensorDesc ChannelBlockedCreator::createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const {
if (srcDims.size() < 2) {
THROW_IE_EXCEPTION << "Can't create blocked tensor descriptor!";
return TensorDesc(precision, srcDims, {blkDims, order});
}
virtual size_t getMinimalRank() const { return 3lu; }
};

class ChannelBlockedCreator : public TensorDescCreator {
public:
ChannelBlockedCreator(size_t blockSize) : _blockSize(blockSize) {}
virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const {
if (srcDims.size() < 2) {
THROW_IE_EXCEPTION << "Can't create blocked tensor descriptor!";
}

SizeVector order(srcDims.size());
std::iota(order.begin(), order.end(), 0);
order.push_back(channelsPos);
SizeVector order(srcDims.size());
std::iota(order.begin(), order.end(), 0);
order.push_back(channelsPos);

SizeVector blkDims = srcDims;
blkDims[channelsPos] = blkDims[channelsPos] / _blockSize + (blkDims[channelsPos] % _blockSize ? 1 : 0);
blkDims.push_back(_blockSize);
SizeVector blkDims = srcDims;
blkDims[channelsPos] = blkDims[channelsPos] / _blockSize + (blkDims[channelsPos] % _blockSize ? 1 : 0);
blkDims.push_back(_blockSize);

return TensorDesc(precision, srcDims, {blkDims, order});
}

std::map<TensorDescCreatorTypes, TensorDescCreator::CreatorConstPtr> TensorDescCreator::getCommonCreators() {
return { { TensorDescCreatorTypes::nspc, CreatorConstPtr(new PerChannelCreator) },
{ TensorDescCreatorTypes::nCsp8c, CreatorConstPtr(new ChannelBlockedCreator(8)) },
{ TensorDescCreatorTypes::nCsp16c, CreatorConstPtr(new ChannelBlockedCreator(16)) },
{ TensorDescCreatorTypes::ncsp, CreatorConstPtr(new PlainFormatCreator) } };
return TensorDesc(precision, srcDims, {blkDims, order});
}
virtual size_t getMinimalRank() const { return 2lu; }

private:
size_t _blockSize;
};
} // namespace

const TensorDescCreator::CreatorsMap& TensorDescCreator::getCommonCreators() {
static const CreatorsMap map{ { TensorDescCreatorTypes::nspc, CreatorConstPtr(new PerChannelCreator) },
{ TensorDescCreatorTypes::nCsp8c, CreatorConstPtr(new ChannelBlockedCreator(8)) },
{ TensorDescCreatorTypes::nCsp16c, CreatorConstPtr(new ChannelBlockedCreator(16)) },
{ TensorDescCreatorTypes::ncsp, CreatorConstPtr(new PlainFormatCreator) } };
return map;
}

std::pair<CreatorsMapFilterIterator, CreatorsMapFilterIterator>
TensorDescCreator::makeFilteredRange(TensorDescCreator::CreatorsMap &map, unsigned int rank) {
std::pair<CreatorsMapFilterConstIterator, CreatorsMapFilterConstIterator>
TensorDescCreator::makeFilteredRange(const CreatorsMap &map, unsigned int rank) {
auto rankFilter = [rank](const CreatorsMap::value_type& item) {
if (item.second->getMinimalRank() > rank) {
return false;
}
return true;
};

auto first = CreatorsMapFilterIterator(std::move(rankFilter), map.begin(), map.end());
auto first = CreatorsMapFilterConstIterator(std::move(rankFilter), map.begin(), map.end());
auto last = first.end();
return std::make_pair(first, last);
}

std::pair<CreatorsMapFilterIterator, CreatorsMapFilterIterator>
TensorDescCreator::makeFilteredRange(TensorDescCreator::CreatorsMap &map, unsigned int rank, std::set<TensorDescCreatorTypes> supportedTypes) {
auto rankTypesFilter = [rank, supportedTypes](const CreatorsMap::value_type& item) {
if (!supportedTypes.count(item.first)) {
std::pair<CreatorsMapFilterConstIterator, CreatorsMapFilterConstIterator>
TensorDescCreator::makeFilteredRange(const CreatorsMap& map, unsigned rank, const std::vector<TensorDescCreatorTypes>& supportedTypes) {
size_t bitMask = 0ul;
for (auto& item : supportedTypes) {
bitMask |= 1 << static_cast<unsigned>(item);
}

auto rankTypesFilter = [rank, bitMask](const CreatorsMap::value_type& item) {
if (!(bitMask & (1 << static_cast<unsigned>(item.first)))) {
return false;
}
if (item.second->getMinimalRank() > rank) {
Expand All @@ -84,7 +106,14 @@ TensorDescCreator::makeFilteredRange(TensorDescCreator::CreatorsMap &map, unsign
return true;
};

auto first = CreatorsMapFilterIterator(std::move(rankTypesFilter), map.begin(), map.end());
auto first = CreatorsMapFilterConstIterator(std::move(rankTypesFilter), map.begin(), map.end());
auto last = first.end();
return std::make_pair(first, last);
}

std::pair<CreatorsMapFilterConstIterator, CreatorsMapFilterConstIterator>
TensorDescCreator::makeFilteredRange(const CreatorsMap &map, TensorDescCreator::Predicate predicate) {
auto first = CreatorsMapFilterConstIterator(std::move(predicate), map.begin(), map.end());
auto last = first.end();
return std::make_pair(first, last);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,41 @@
#pragma once

#include <ie_layouts.h>
#include <set>

namespace MKLDNNPlugin {

enum class TensorDescCreatorTypes {
nspc,
nCsp8c,
nCsp16c,
ncsp
enum class TensorDescCreatorTypes : unsigned {
nspc, // general per channels format
nCsp8c, // general channels blocked by 8
nCsp16c, // general channels blocked by 16
ncsp // general planar
};

class CreatorsMapFilterIterator;
class CreatorsMapFilterConstIterator;

class TensorDescCreator {
public:
typedef std::shared_ptr<TensorDescCreator> CreatorPtr;
typedef std::shared_ptr<const TensorDescCreator> CreatorConstPtr;
typedef std::map<TensorDescCreatorTypes, CreatorConstPtr> CreatorsMap;
typedef std::function<bool(const CreatorsMap::value_type&)> Predicate;

public:
static CreatorsMap getCommonCreators();
static std::pair<CreatorsMapFilterIterator, CreatorsMapFilterIterator>
makeFilteredRange(CreatorsMap& map, unsigned rank);
static std::pair<CreatorsMapFilterIterator, CreatorsMapFilterIterator>
makeFilteredRange(CreatorsMap& map, unsigned rank, std::set<TensorDescCreatorTypes> supportedTypes);
static const CreatorsMap& getCommonCreators();
static std::pair<CreatorsMapFilterConstIterator, CreatorsMapFilterConstIterator>
makeFilteredRange(const CreatorsMap &map, unsigned rank);
static std::pair<CreatorsMapFilterConstIterator, CreatorsMapFilterConstIterator>
makeFilteredRange(const CreatorsMap& map, unsigned rank, const std::vector<TensorDescCreatorTypes>& supportedTypes);
static std::pair<CreatorsMapFilterConstIterator, CreatorsMapFilterConstIterator>
makeFilteredRange(const CreatorsMap& map, Predicate predicate);
virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const = 0;
virtual size_t getMinimalRank() const = 0;
virtual ~TensorDescCreator() {}
virtual ~TensorDescCreator() = default;
};

class PlainFormatCreator : public TensorDescCreator {
class CreatorsMapFilterConstIterator {
public:
virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const;
virtual size_t getMinimalRank() const { return 0lu; }
};

class PerChannelCreator : public TensorDescCreator {
public:
virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision &precision, const InferenceEngine::SizeVector &srcDims) const;
virtual size_t getMinimalRank() const { return 3lu; }
};

class ChannelBlockedCreator : public TensorDescCreator {
public:
ChannelBlockedCreator(size_t blockSize) : _blockSize(blockSize) {}
virtual InferenceEngine::TensorDesc createDesc(const InferenceEngine::Precision& precision, const InferenceEngine::SizeVector& srcDims) const;
virtual size_t getMinimalRank() const { return 2lu; }

private:
size_t _blockSize;
};

class CreatorsMapFilterIterator {
public:
typedef TensorDescCreator::CreatorsMap::iterator Iterator;
typedef TensorDescCreator::CreatorsMap::const_iterator Iterator;
typedef std::iterator_traits<Iterator>::value_type value_type;
typedef std::iterator_traits<Iterator>::reference reference;
typedef std::iterator_traits<Iterator>::pointer pointer;
Expand All @@ -68,20 +48,24 @@ class CreatorsMapFilterIterator {
typedef std::function<bool(const value_type&)> predicate_type;

public:
CreatorsMapFilterIterator(predicate_type filter, Iterator begin, Iterator end) : _filter(std::move(filter)), _iter(begin), _end(end) {}
CreatorsMapFilterIterator& operator++() {
CreatorsMapFilterConstIterator(predicate_type filter, Iterator begin, Iterator end) : _filter(std::move(filter)), _iter(begin), _end(end) {
while (_iter != _end && !_filter(*_iter)) {
++_iter;
}
}
CreatorsMapFilterConstIterator& operator++() {
do {
++_iter;
} while (_iter != _end && !_filter(*_iter));
return *this;
}

CreatorsMapFilterIterator end() const {
return CreatorsMapFilterIterator(predicate_type(), _end, _end);
CreatorsMapFilterConstIterator end() const {
return CreatorsMapFilterConstIterator(predicate_type(), _end, _end);
}

CreatorsMapFilterIterator operator++(int) {
CreatorsMapFilterIterator temp(*this);
CreatorsMapFilterConstIterator operator++(int) {
CreatorsMapFilterConstIterator temp(*this);
++*this;
return temp;
}
Expand All @@ -94,11 +78,11 @@ class CreatorsMapFilterIterator {
return std::addressof(*_iter);
}

friend bool operator==(const CreatorsMapFilterIterator& lhs, const CreatorsMapFilterIterator& rhs) {
friend bool operator==(const CreatorsMapFilterConstIterator& lhs, const CreatorsMapFilterConstIterator& rhs) {
return lhs._iter == rhs._iter;
}

friend bool operator!=(const CreatorsMapFilterIterator& lhs, const CreatorsMapFilterIterator& rhs) {
friend bool operator!=(const CreatorsMapFilterConstIterator& lhs, const CreatorsMapFilterConstIterator& rhs) {
return !(lhs == rhs);
}

Expand Down

0 comments on commit 7b35e35

Please sign in to comment.