Skip to content

Commit

Permalink
[GNA] KSOFunction test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
andreysapozhn committed Jul 16, 2021
1 parent fa8f45b commit af8c018
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
5 changes: 3 additions & 2 deletions inference-engine/src/gna_plugin/gna_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -752,12 +752,14 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
passes->registerPass<FuseFQIntoWeightsPass>();
passes->registerPass<MoveFakeQuantizeLayerIntoQuantParamsPass>();

passes->registerPass<SubstituteScaleShiftBroadCastPass>();
passes->registerPass<BroadcastConstPass>();

passes->registerPass<TransposeWeightsFromNCHWToNHWCPass>();

passes->registerPass<SubstitutePReluPass>();
passes->registerPass<SubstituteSoftSignPass>();

passes->registerPass<BroadcastConstPass>();
passes->registerPass<ReorderMaxPoolPass>();
passes->registerPass<EltwiseSplitOverChannelsPass>();
passes->registerPass<InsertSplitAligningFilterPass>();
Expand All @@ -775,7 +777,6 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
#if GNA_LIB_VER == 2
passes->registerPass<ForbidActivationFusingPass>();
#endif
passes->registerPass<SubstituteScaleShiftBroadCastPass>();
passes->registerPass<FuseMultipleIdentitiesPass>();
passIdx = passes->run(passIdx);
};
Expand Down
20 changes: 10 additions & 10 deletions inference-engine/src/gna_plugin/optimizer/gna_pass_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1530,16 +1530,7 @@ void SubstituteScaleShiftBroadCastPass::run() {
continue;
}

// only 3d scaleshift supported where number of c is arbitrary
auto lastD = reshape_batch ? dataDims[1] : dataDims.back();
if (lastD != weightsElements) {
THROW_GNA_EXCEPTION << "Unsupported layer: " << l->name
<< " should have last dim(" << lastD << ") equal to weights(" << weightsElements << ") length";
}
if (dataDims.size() == 2) {
THROW_GNA_EXCEPTION << "For layer: " << l->name
<< " weights size(" << weightsElements<< ") invalid: should match input size of(" << lastD << ")";
}
// TODO: add broadcasting rules checks

gnalog() << "Substitution ScaleShift broadcast for layer: " << l->name << "\n";
if (nElements % scaleShift->_weights->size()) {
Expand Down Expand Up @@ -2237,6 +2228,15 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
}
auto weightable = dynamic_cast<WeightableLayer*>(l.get());
IE_ASSERT(weightable != nullptr);
size_t totalElements = 0;
auto totalWeights = weightable->_weights->size();
for (auto && transpositionInfoPart : transpositionInfo) {
totalElements += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
}
if (totalElements != totalWeights) {
THROW_GNA_EXCEPTION << l->name << " weights columns from transposition info (" << totalElements
<< ") don't match input dimensions (" << totalWeights << ")";
}
ConvertTensorFromNCHWToNHWC(weightable->precision.size(), 1, weightable->_weights->size(),
weightable->_weights->cbuffer().as<uint8_t*>(), true, transpositionInfo);
if (weightable->_biases) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ std::vector<std::string> disabledTestPatterns() {
R"(.*ConstantResultSubgraphTest.*inPrc=(U8|I8|I32|U64|I64|BOOL).*)",
// TODO: Issue 51528
R"(.*CachingSupport.*_(u8|i16)_.*)",
// TODO: Issue 51525
R"(.*CachingSupport.*KSOFunction.*)",
// TODO: Issue 57363 (Param -> Result subgraphs)
R"(.*smoke_MemoryTest.*LOW_LATENCY.*iteration_count=1_.*)",
// TODO: Issue 57368 (accuracy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ const std::vector<std::vector<std::vector<size_t>>> shapes = {
{{1, 64}, {64, 1}},
{{8, 256}, {16, 128}},
{{6, 384}, {18, 128}},
{{8, 2048}, {32, 512}}
{{8, 2048}, {32, 512}},
{{2, 4, 64, 64}, {1, 8, 64, 64}}
};

const std::vector<InferenceEngine::Precision> netPrecisions = {
Expand Down

0 comments on commit af8c018

Please sign in to comment.