Skip to content

Commit

Permalink
Lambda for dimensions matching check
Browse files Browse the repository at this point in the history
  • Loading branch information
andreysapozhn committed Jul 16, 2021
1 parent af8c018 commit 7f9d816
Showing 1 changed file with 21 additions and 25 deletions.
46 changes: 21 additions & 25 deletions inference-engine/src/gna_plugin/optimizer/gna_pass_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2211,6 +2211,17 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
}
};

auto transpInfoMatchWeightsSize = [](const std::vector<TranspositionInfo> &transpositionInfo, size_t weightsSize, const std::string &layerName) {
size_t totalElements = 0;
for (auto && transpositionInfoPart : transpositionInfo) {
totalElements += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
}
if (totalElements != weightsSize) {
THROW_GNA_EXCEPTION << layerName << " weights elements from transposition info (" << totalElements
<< ") don't match input dimensions (" << weightsSize << ")";
}
};

for (auto &&l : *pLayers) {
if (LayerInfo(l).isScaleShift()) {
std::vector<TranspositionInfo> transpositionInfo;
Expand All @@ -2228,15 +2239,10 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
}
auto weightable = dynamic_cast<WeightableLayer*>(l.get());
IE_ASSERT(weightable != nullptr);
size_t totalElements = 0;
auto totalWeights = weightable->_weights->size();
for (auto && transpositionInfoPart : transpositionInfo) {
totalElements += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
}
if (totalElements != totalWeights) {
THROW_GNA_EXCEPTION << l->name << " weights columns from transposition info (" << totalElements
<< ") don't match input dimensions (" << totalWeights << ")";
}

size_t totalWeights = weightable->_weights->size();
transpInfoMatchWeightsSize(transpositionInfo, totalWeights, l->name);

ConvertTensorFromNCHWToNHWC(weightable->precision.size(), 1, weightable->_weights->size(),
weightable->_weights->cbuffer().as<uint8_t*>(), true, transpositionInfo);
if (weightable->_biases) {
Expand Down Expand Up @@ -2270,14 +2276,9 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
// If we found a split it's not possible to rotate data
THROW_GNA_EXCEPTION << l->name << " won't be transposed due to a split before it";
}
size_t totalColumns = 0;
for (auto && transpositionInfoPart : transpositionInfo) {
totalColumns += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
}
if (weightsColumns != totalColumns) {
THROW_GNA_EXCEPTION << l->name << " weights columns from transposition info (" << totalColumns
<< ") don't match input dimensions (" << weightsColumns << ")";
}

transpInfoMatchWeightsSize(transpositionInfo, weightsColumns, l->name);

ConvertTensorFromNCHWToNHWC(precision, weightsRows, weightsColumns, weightable->_weights->cbuffer().as<uint8_t*>(),
true, transpositionInfo);
gnalog() << l->name << " weights rows transposition info:\n";
Expand All @@ -2297,14 +2298,9 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
// If we found a concat it's not possible to rotate data
THROW_GNA_EXCEPTION << l->name << " won't be transposed due to a concat after it";
}
size_t totalRows = 0;
for (const auto& transpositionInfoPart : transpositionInfo) {
totalRows += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
}
if (weightsRows != totalRows) {
THROW_GNA_EXCEPTION << l->name << " weights rows from transposition info (" << totalRows
<< ") don't match output dimensions (" << weightsRows << ")";
}

transpInfoMatchWeightsSize(transpositionInfo, weightsRows, l->name);

ConvertTensorFromNCHWToNHWC(precision, weightsRows, weightsColumns, weightable->_weights->cbuffer().as<uint8_t*>(),
false, transpositionInfo);
gnalog() << l->name << " weights columns transposition info:\n";
Expand Down

0 comments on commit 7f9d816

Please sign in to comment.