Skip to content

Commit

Permalink
[GNA] Fix bug with broadcasting constant layer with fq layer (openvin…
Browse files Browse the repository at this point in the history
…otoolkit#5766)

* fix bug with broadcasting constant layer with fq layer

* BroadcastConstWithFakeQuantizePass is removed; BroadcastConstPass is moved up in pass list

* constLayer->outData.front()->setDims is moved to conditions

* prevLayer->outData.front()->setLayout(nextLayer->outData.front()->getLayout()); is added
  • Loading branch information
dmitriikhurtin authored and Alexey Lebedev committed May 27, 2021
1 parent 2dfb17f commit 93413ad
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 19 deletions.
2 changes: 1 addition & 1 deletion inference-engine/src/gna_plugin/gna_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -766,14 +766,14 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
passes->registerPass<InsertIdentityLayerPass>();
passes->registerPass<BreakFusingOfOutputLayersPass>();
passes->registerPass<InsertCopyLayerPass>();
passes->registerPass<BroadcastConstPass>();
passes->registerPass<InsertDiagonalLayerPass>();
passes->registerPass<HandleMultipleActivationsForTheLayerPass>();
#if GNA_LIB_VER == 2
passes->registerPass<ForbidActivationFusingPass>();
#endif
passes->registerPass<SubstituteScaleShiftBroadCastPass>();
passes->registerPass<FuseMultipleIdentitiesPass>();
passes->registerPass<BroadcastConstPass>();
passIdx = passes->run(passIdx);
};

Expand Down
43 changes: 25 additions & 18 deletions inference-engine/src/gna_plugin/optimizer/gna_pass_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1609,46 +1609,53 @@ void SubstituteScaleShiftBroadCastPass::run() {
}

void BroadcastConstPass::run() {
for (auto& constLayer : *pLayers) {
for (auto constLayer : *pLayers) {
if (!LayerInfo(constLayer).isConst()) {
continue;
}
auto isNonFunctional = [](CNNLayerPtr l) {
return LayerInfo(l).isNonFunctional();

auto isNonFunctional = [](CNNLayerPtr layer) {
return LayerInfo(layer).isNonFunctional();
};
if (!CNNNetHasNextLayerSkipCertain(constLayer, 0, 0, isNonFunctional)) {

auto nextLayer = CNNNetCheckNextLayerSkipCertain(constLayer, 0, 0, true, isNonFunctional).first;
if (!nextLayer || !LayerInfo(nextLayer).isEltwise() && !LayerInfo(nextLayer).isFakeQuantize()) {
continue;
}

auto nextLayer = CNNNetGetNextLayerSkipCertain(constLayer, 0, 0, isNonFunctional).first;
auto prevLayer = nextLayer;
if (LayerInfo(nextLayer).isFakeQuantize()) {
if (CNNNetPrevLayer(nextLayer, 0) != constLayer) {
continue;
}

if (!LayerInfo(nextLayer).isEltwise()) {
continue;
nextLayer = CNNNetCheckNextLayerSkipCertain(nextLayer, 0, 0, true, isNonFunctional).first;
if (!nextLayer || !LayerInfo(nextLayer).isEltwise()) {
continue;
}
}

auto constDims = constLayer->outData.front()->getTensorDesc().getDims();
auto constDimsSize = product(constDims.begin(), constDims.end());
auto eltwiseDims = nextLayer->outData.front()->getTensorDesc().getDims();
auto eltwiseDimsSize = product(eltwiseDims.begin(), eltwiseDims.end());

if (constDimsSize == eltwiseDimsSize) {
continue;
}

if (eltwiseDimsSize % constDimsSize) {
if (constDimsSize == eltwiseDimsSize || eltwiseDimsSize % constDimsSize) {
continue;
}

if (constLayer->blobs.find("custom") == constLayer->blobs.end()) {
auto blobsIter = constLayer->blobs.find("custom");
if (blobsIter == constLayer->blobs.end()) {
THROW_GNA_LAYER_EXCEPTION(constLayer) << "Const layer " << constLayer->name << " is missing 'custom' parameter";
}

auto currentConstBlob = constLayer->blobs.find("custom")->second;

constLayer->blobs.find("custom")->second = tileBlob(currentConstBlob, eltwiseDimsSize);

auto currentConstBlob = blobsIter->second;
blobsIter->second = tileBlob(currentConstBlob, eltwiseDimsSize);
constLayer->outData.front()->setDims(nextLayer->outData.front()->getDims());
constLayer->outData.front()->setLayout(nextLayer->outData.front()->getLayout());
if (prevLayer != nextLayer) {
prevLayer->outData.front()->setDims(nextLayer->outData.front()->getDims());
prevLayer->outData.front()->setLayout(nextLayer->outData.front()->getLayout());
}
gnalog() << "Const layer '" << constLayer->name << "' was changed to match output of '" << nextLayer->name << "'\n";
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
//
#include <vector>
#include <tuple>
#include <string>

#include <ie_core.hpp>

#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/pass/convert_prc.hpp"

using BroadcastConstWithFqParamsTuple = typename std::tuple<
InferenceEngine::Precision, // Network Precision
std::vector<size_t>, // Input shapes for Params Layer
std::vector<size_t>, // Input shapes for Constant Layer
size_t, // Quantization level
std::map<std::string, std::string>, // Configuration
std::string>; // Device name

namespace LayerTestsDefinitions {

class BroadcastConstWithFq : public testing::WithParamInterface<BroadcastConstWithFqParamsTuple>,
public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<BroadcastConstWithFqParamsTuple> obj) {
InferenceEngine::Precision netPrecision;
std::vector<size_t> inputShape1;
std::vector<size_t> inputShape2;
size_t level{0};
std::map<std::string, std::string> configuration;
std::string targetDevice;
std::tie(netPrecision, inputShape1, inputShape2, level, configuration, targetDevice) = obj.param;
std::ostringstream result;
result << "netPRC=" << netPrecision.name() << "_";
result << "targetDevice=" << targetDevice << "_";
for (auto const& configItem : configuration) {
result << "configItem=" << configItem.first << "_" << configItem.second << "_";
}
result << "inputShape1=" << CommonTestUtils::vec2str(inputShape1) << "_";
result << "inputShape2=" << CommonTestUtils::vec2str(inputShape2) << "_";
result << "level=" << level;
return result.str();
}

protected:
void SetUp() override {
size_t level{0};
InferenceEngine::Precision netPrecision;
std::vector<size_t> inputShape1;
std::vector<size_t> inputShape2;
std::tie(netPrecision, inputShape1, inputShape2, level, configuration, targetDevice) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShape1});
auto fakeQuantize1 = ngraph::builder::makeFakeQuantize(params[0], ngPrc, level, {}, {-0.5}, {0.5}, {-0.5}, {0.5});
auto constant = ngraph::builder::makeConstant<float>(ngPrc, inputShape2, {}, true);
auto fakeQuantize2 = ngraph::builder::makeFakeQuantize(constant, ngPrc, level, {}, {-0.5}, {0.5}, {-0.5}, {0.5});
auto add = std::make_shared<ngraph::opset1::Add>(fakeQuantize1, fakeQuantize2);
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add)};
function = std::make_shared<ngraph::Function>(results, params, "BroadcastConstWithFq");
}
};

TEST_P(BroadcastConstWithFq, CompareWithRefImpl) {
Run();
};

std::vector<std::vector<size_t>> inputShapes1 = { {1, 1, 21, 160} };
std::vector<std::vector<size_t>> inputShapes2 = { {1, 1, 1, 160} };
const std::vector<size_t> level = { 65535 };
const std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32, InferenceEngine::Precision::FP16};
const std::vector<std::map<std::string, std::string>> configs = {
{ {"GNA_DEVICE_MODE", "GNA_SW_FP32"} },
{ {"GNA_DEVICE_MODE", "GNA_SW_EXACT"} }
};

INSTANTIATE_TEST_CASE_P(smoke_broadcast_const_with_fq, BroadcastConstWithFq,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::ValuesIn(inputShapes1),
::testing::ValuesIn(inputShapes2),
::testing::ValuesIn(level),
::testing::ValuesIn(configs),
::testing::Values(CommonTestUtils::DEVICE_GNA)),
BroadcastConstWithFq::getTestCaseName);
} // namespace LayerTestsDefinitions

0 comments on commit 93413ad

Please sign in to comment.