[GNA] KSOFunction test fix (openvinotoolkit#6678)
* [GNA] KSOFunction test fix

* Lambda for dimensions matching check
andreysapozhn authored and akuporos committed Sep 29, 2021
1 parent 1ec6ff4 commit a2aed39
Showing 4 changed files with 27 additions and 31 deletions.
5 changes: 3 additions & 2 deletions inference-engine/src/gna_plugin/gna_plugin.cpp
@@ -752,12 +752,14 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
         passes->registerPass<FuseFQIntoWeightsPass>();
         passes->registerPass<MoveFakeQuantizeLayerIntoQuantParamsPass>();
 
+        passes->registerPass<SubstituteScaleShiftBroadCastPass>();
+        passes->registerPass<BroadcastConstPass>();
+
         passes->registerPass<TransposeWeightsFromNCHWToNHWCPass>();
 
         passes->registerPass<SubstitutePReluPass>();
         passes->registerPass<SubstituteSoftSignPass>();
 
-        passes->registerPass<BroadcastConstPass>();
         passes->registerPass<ReorderMaxPoolPass>();
         passes->registerPass<EltwiseSplitOverChannelsPass>();
         passes->registerPass<InsertSplitAligningFilterPass>();
@@ -775,7 +777,6 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
 #if GNA_LIB_VER == 2
         passes->registerPass<ForbidActivationFusingPass>();
 #endif
-        passes->registerPass<SubstituteScaleShiftBroadCastPass>();
         passes->registerPass<FuseMultipleIdentitiesPass>();
         passIdx = passes->run(passIdx);
     };
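Note on this file: the functional change is ordering. SubstituteScaleShiftBroadCastPass and BroadcastConstPass are now registered, and therefore executed, before TransposeWeightsFromNCHWToNHWCPass instead of after it. A minimal sketch of why registration order matters, assuming a simplified pass manager with a name-plus-callable registerPass (the real GNA PassManager registers pass types via templates, as in the diff above):

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Simplified pass manager: passes run strictly in registration order,
// which is why moving a registerPass call changes pipeline behavior.
class PassManager {
    std::vector<std::pair<std::string, std::function<void()>>> passes_;
public:
    void registerPass(std::string name, std::function<void()> body) {
        passes_.emplace_back(std::move(name), std::move(body));
    }
    void run() {
        for (auto &p : passes_) {
            std::cout << "running " << p.first << "\n";
            p.second();
        }
    }
};

int main() {
    PassManager pm;
    // Mirrors the new order in this commit: broadcast substitution
    // now happens before the NCHW -> NHWC weight transposition.
    pm.registerPass("SubstituteScaleShiftBroadCastPass", [] {});
    pm.registerPass("BroadcastConstPass", [] {});
    pm.registerPass("TransposeWeightsFromNCHWToNHWCPass", [] {});
    pm.run();
    return 0;
}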
48 changes: 22 additions & 26 deletions inference-engine/src/gna_plugin/optimizer/gna_pass_manager.cpp
@@ -1530,16 +1530,7 @@ void SubstituteScaleShiftBroadCastPass::run() {
             continue;
         }
 
-        // only 3d scaleshift supported where number of c is arbitrary
-        auto lastD = reshape_batch ? dataDims[1] : dataDims.back();
-        if (lastD != weightsElements) {
-            THROW_GNA_EXCEPTION << "Unsupported layer: " << l->name
-                                << " should have last dim(" << lastD << ") equal to weights(" << weightsElements << ") length";
-        }
-        if (dataDims.size() == 2) {
-            THROW_GNA_EXCEPTION << "For layer: " << l->name
-                                << " weights size(" << weightsElements << ") invalid: should match input size of(" << lastD << ")";
-        }
+        // TODO: add broadcasting rules checks
 
         gnalog() << "Substitution ScaleShift broadcast for layer: " << l->name << "\n";
         if (nElements % scaleShift->_weights->size()) {
@@ -2220,6 +2211,17 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
         }
     };
 
+    auto transpInfoMatchWeightsSize = [](const std::vector<TranspositionInfo> &transpositionInfo, size_t weightsSize, const std::string &layerName) {
+        size_t totalElements = 0;
+        for (auto && transpositionInfoPart : transpositionInfo) {
+            totalElements += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
+        }
+        if (totalElements != weightsSize) {
+            THROW_GNA_EXCEPTION << layerName << " weights elements from transposition info (" << totalElements
+                                << ") don't match input dimensions (" << weightsSize << ")";
+        }
+    };
+
     for (auto &&l : *pLayers) {
         if (LayerInfo(l).isScaleShift()) {
             std::vector<TranspositionInfo> transpositionInfo;
@@ -2237,6 +2239,10 @@
             }
             auto weightable = dynamic_cast<WeightableLayer*>(l.get());
             IE_ASSERT(weightable != nullptr);
+
+            size_t totalWeights = weightable->_weights->size();
+            transpInfoMatchWeightsSize(transpositionInfo, totalWeights, l->name);
+
             ConvertTensorFromNCHWToNHWC(weightable->precision.size(), 1, weightable->_weights->size(),
                                         weightable->_weights->cbuffer().as<uint8_t*>(), true, transpositionInfo);
             if (weightable->_biases) {
@@ -2270,14 +2276,9 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
                 // If we found a split it's not possible to rotate data
                 THROW_GNA_EXCEPTION << l->name << " won't be transposed due to a split before it";
             }
-            size_t totalColumns = 0;
-            for (auto && transpositionInfoPart : transpositionInfo) {
-                totalColumns += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
-            }
-            if (weightsColumns != totalColumns) {
-                THROW_GNA_EXCEPTION << l->name << " weights columns from transposition info (" << totalColumns
-                                    << ") don't match input dimensions (" << weightsColumns << ")";
-            }
+
+            transpInfoMatchWeightsSize(transpositionInfo, weightsColumns, l->name);
+
             ConvertTensorFromNCHWToNHWC(precision, weightsRows, weightsColumns, weightable->_weights->cbuffer().as<uint8_t*>(),
                                         true, transpositionInfo);
             gnalog() << l->name << " weights rows transposition info:\n";
@@ -2297,14 +2298,9 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
                 // If we found a concat it's not possible to rotate data
                 THROW_GNA_EXCEPTION << l->name << " won't be transposed due to a concat after it";
             }
-            size_t totalRows = 0;
-            for (const auto& transpositionInfoPart : transpositionInfo) {
-                totalRows += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
-            }
-            if (weightsRows != totalRows) {
-                THROW_GNA_EXCEPTION << l->name << " weights rows from transposition info (" << totalRows
-                                    << ") don't match output dimensions (" << weightsRows << ")";
-            }
+
+            transpInfoMatchWeightsSize(transpositionInfo, weightsRows, l->name);
+
             ConvertTensorFromNCHWToNHWC(precision, weightsRows, weightsColumns, weightable->_weights->cbuffer().as<uint8_t*>(),
                                         false, transpositionInfo);
             gnalog() << l->name << " weights columns transposition info:\n";
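Note on this file: the recurring pattern, summing num_transpose_rows * num_transpose_columns over all TranspositionInfo parts and comparing the total against a weights dimension, previously appeared three times with slightly different wording. The commit extracts it into the transpInfoMatchWeightsSize lambda, called with the full weights size for ScaleShift layers and with the column or row count for weightable layers. A self-contained sketch of the same check, with std::runtime_error standing in for THROW_GNA_EXCEPTION and a simplified stand-in TranspositionInfo:

#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Stand-in for the plugin's TranspositionInfo (illustration only).
struct TranspositionInfo {
    size_t num_transpose_rows;
    size_t num_transpose_columns;
};

int main() {
    // The extracted check: the element count described by the transposition
    // info must equal whichever weights dimension the call site passes in.
    auto transpInfoMatchWeightsSize = [](const std::vector<TranspositionInfo> &transpositionInfo,
                                         size_t weightsSize, const std::string &layerName) {
        size_t totalElements = 0;
        for (const auto &part : transpositionInfo) {
            totalElements += part.num_transpose_rows * part.num_transpose_columns;
        }
        if (totalElements != weightsSize) {
            throw std::runtime_error(layerName + ": transposition info describes " +
                                     std::to_string(totalElements) + " elements, expected " +
                                     std::to_string(weightsSize));
        }
    };

    std::vector<TranspositionInfo> info = {{8, 64}, {8, 64}};
    transpInfoMatchWeightsSize(info, 8 * 64 * 2, "fc_layer");  // passes: 1024 == 1024
    return 0;
}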
@@ -44,8 +44,6 @@ std::vector<std::string> disabledTestPatterns() {
         R"(.*ConstantResultSubgraphTest.*inPrc=(U8|I8|I32|U64|I64|BOOL).*)",
         // TODO: Issue 51528
         R"(.*CachingSupport.*_(u8|i16)_.*)",
-        // TODO: Issue 51525
-        R"(.*CachingSupport.*KSOFunction.*)",
         // TODO: Issue 57363 (Param -> Result subgraphs)
         R"(.*smoke_MemoryTest.*LOW_LATENCY.*iteration_count=1_.*)",
         // TODO: Issue 57368 (accuracy)
@@ -13,7 +13,8 @@ const std::vector<std::vector<std::vector<size_t>>> shapes = {
     {{1, 64}, {64, 1}},
     {{8, 256}, {16, 128}},
     {{6, 384}, {18, 128}},
-    {{8, 2048}, {32, 512}}
+    {{8, 2048}, {32, 512}},
+    {{2, 4, 64, 64}, {1, 8, 64, 64}}
 };
 
 const std::vector<InferenceEngine::Precision> netPrecisions = {
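Note on this file: the new case extends the existing 2D pairs with a 4D pair. Each pair in this table appears to be an input/output shape combination with equal element counts (2*4*64*64 = 32768 = 1*8*64*64), as a reshape-style test requires; that reading is an assumption from the surrounding pairs. A quick standalone check under that assumption:

#include <cassert>
#include <cstddef>
#include <vector>

// A reshape pair is only valid when both shapes describe the same
// number of elements.
size_t numElements(const std::vector<size_t> &shape) {
    size_t n = 1;
    for (size_t d : shape) n *= d;
    return n;
}

int main() {
    const std::vector<size_t> in = {2, 4, 64, 64};   // 32768 elements
    const std::vector<size_t> out = {1, 8, 64, 64};  // 32768 elements
    assert(numElements(in) == numElements(out));
    return 0;
}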