Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GNA] Fix bug with broadcasting constant layer with fq layer #5766

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion inference-engine/src/gna_plugin/gna_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,14 +764,14 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
passes->registerPass<InsertIdentityLayerPass>();
passes->registerPass<BreakFusingOfOutputLayersPass>();
passes->registerPass<InsertCopyLayerPass>();
passes->registerPass<BroadcastConstPass>();
passes->registerPass<InsertDiagonalLayerPass>();
passes->registerPass<HandleMultipleActivationsForTheLayerPass>();
#if GNA_LIB_VER == 2
passes->registerPass<ForbidActivationFusingPass>();
#endif
passes->registerPass<SubstituteScaleShiftBroadCastPass>();
passes->registerPass<FuseMultipleIdentitiesPass>();
passes->registerPass<BroadcastConstPass>();
passIdx = passes->run(passIdx);
};

Expand Down
43 changes: 25 additions & 18 deletions inference-engine/src/gna_plugin/optimizer/gna_pass_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1609,46 +1609,53 @@ void SubstituteScaleShiftBroadCastPass::run() {
}

void BroadcastConstPass::run() {
for (auto& constLayer : *pLayers) {
for (auto constLayer : *pLayers) {
if (!LayerInfo(constLayer).isConst()) {
continue;
}
auto isNonFunctional = [](CNNLayerPtr l) {
return LayerInfo(l).isNonFunctional();

auto isNonFunctional = [](CNNLayerPtr layer) {
return LayerInfo(layer).isNonFunctional();
};
if (!CNNNetHasNextLayerSkipCertain(constLayer, 0, 0, isNonFunctional)) {

auto nextLayer = CNNNetCheckNextLayerSkipCertain(constLayer, 0, 0, true, isNonFunctional).first;
if (!nextLayer || !LayerInfo(nextLayer).isEltwise() && !LayerInfo(nextLayer).isFakeQuantize()) {
continue;
}

auto nextLayer = CNNNetGetNextLayerSkipCertain(constLayer, 0, 0, isNonFunctional).first;
auto prevLayer = nextLayer;
if (LayerInfo(nextLayer).isFakeQuantize()) {
if (CNNNetPrevLayer(nextLayer, 0) != constLayer) {
continue;
}

if (!LayerInfo(nextLayer).isEltwise()) {
continue;
nextLayer = CNNNetCheckNextLayerSkipCertain(nextLayer, 0, 0, true, isNonFunctional).first;
if (!nextLayer || !LayerInfo(nextLayer).isEltwise()) {
continue;
}
}

auto constDims = constLayer->outData.front()->getTensorDesc().getDims();
auto constDimsSize = product(constDims.begin(), constDims.end());
auto eltwiseDims = nextLayer->outData.front()->getTensorDesc().getDims();
auto eltwiseDimsSize = product(eltwiseDims.begin(), eltwiseDims.end());

if (constDimsSize == eltwiseDimsSize) {
continue;
}

if (eltwiseDimsSize % constDimsSize) {
if (constDimsSize == eltwiseDimsSize || eltwiseDimsSize % constDimsSize) {
continue;
}

if (constLayer->blobs.find("custom") == constLayer->blobs.end()) {
auto blobsIter = constLayer->blobs.find("custom");
if (blobsIter == constLayer->blobs.end()) {
THROW_GNA_LAYER_EXCEPTION(constLayer) << "Const layer " << constLayer->name << " is missing 'custom' parameter";
}

auto currentConstBlob = constLayer->blobs.find("custom")->second;

constLayer->blobs.find("custom")->second = tileBlob(currentConstBlob, eltwiseDimsSize);

auto currentConstBlob = blobsIter->second;
blobsIter->second = tileBlob(currentConstBlob, eltwiseDimsSize);
constLayer->outData.front()->setDims(nextLayer->outData.front()->getDims());
constLayer->outData.front()->setLayout(nextLayer->outData.front()->getLayout());
if (prevLayer != nextLayer) {
prevLayer->insData.front().lock()->setDims(nextLayer->outData.front()->getDims());
elilobanova marked this conversation as resolved.
Show resolved Hide resolved
prevLayer->outData.front()->setDims(nextLayer->outData.front()->getDims());
}
gnalog() << "Const layer '" << constLayer->name << "' was changed to match output of '" << nextLayer->name << "'\n";
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
//
#include <vector>
#include <tuple>
#include <string>

#include <ie_core.hpp>

#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/pass/convert_prc.hpp"

using BroadcastConstWithFqParamsTuple = typename std::tuple<
InferenceEngine::Precision, // Network Precision
std::vector<size_t>, // Input shapes for Params Layer
std::vector<size_t>, // Input shapes for Constant Layer
size_t, // Quantization level
std::map<std::string, std::string>, // Configuration
std::string>; // Device name

namespace LayerTestsDefinitions {

class BroadcastConstWithFq : public testing::WithParamInterface<BroadcastConstWithFqParamsTuple>,
public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<BroadcastConstWithFqParamsTuple> obj) {
InferenceEngine::Precision netPrecision;
std::vector<size_t> inputShape1;
std::vector<size_t> inputShape2;
size_t level{0};
std::map<std::string, std::string> configuration;
std::string targetDevice;
std::tie(netPrecision, inputShape1, inputShape2, level, configuration, targetDevice) = obj.param;
std::ostringstream result;
result << "netPRC=" << netPrecision.name() << "_";
result << "targetDevice=" << targetDevice << "_";
for (auto const& configItem : configuration) {
result << "configItem=" << configItem.first << "_" << configItem.second << "_";
}
result << "inputShape1=" << CommonTestUtils::vec2str(inputShape1) << "_";
result << "inputShape2=" << CommonTestUtils::vec2str(inputShape2) << "_";
result << "level=" << level;
return result.str();
}

protected:
void SetUp() override {
size_t level{0};
InferenceEngine::Precision netPrecision;
std::vector<size_t> inputShape1;
std::vector<size_t> inputShape2;
std::tie(netPrecision, inputShape1, inputShape2, level, configuration, targetDevice) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShape1});
auto fakeQuantize1 = ngraph::builder::makeFakeQuantize(params[0], ngPrc, level, {}, {-0.5}, {0.5}, {-0.5}, {0.5});
auto constant = ngraph::builder::makeConstant<float>(ngPrc, inputShape2, {}, true);
auto fakeQuantize2 = ngraph::builder::makeFakeQuantize(constant, ngPrc, level, {}, {-0.5}, {0.5}, {-0.5}, {0.5});
auto add = std::make_shared<ngraph::opset1::Add>(fakeQuantize1, fakeQuantize2);
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add)};
function = std::make_shared<ngraph::Function>(results, params, "BroadcastConstWithFq");
}
};

TEST_P(BroadcastConstWithFq, CompareWithRefImpl) {
Run();
};

std::vector<std::vector<size_t>> inputShapes1 = { {1, 1, 21, 160} };
std::vector<std::vector<size_t>> inputShapes2 = { {1, 1, 1, 160} };
const std::vector<size_t> level = { 65535 };
const std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32, InferenceEngine::Precision::FP16};
const std::vector<std::map<std::string, std::string>> configs = {
{ {"GNA_DEVICE_MODE", "GNA_SW_FP32"} },
{ {"GNA_DEVICE_MODE", "GNA_SW_EXACT"} }
};

INSTANTIATE_TEST_CASE_P(smoke_broadcast_const_with_fq, BroadcastConstWithFq,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::ValuesIn(inputShapes1),
::testing::ValuesIn(inputShapes2),
::testing::ValuesIn(level),
::testing::ValuesIn(configs),
::testing::Values(CommonTestUtils::DEVICE_GNA)),
BroadcastConstWithFq::getTestCaseName);
} // namespace LayerTestsDefinitions