Skip to content

Commit

Permalink
[GNA] Fixed import of model with several inputs (#7277)
Browse files Browse the repository at this point in the history
  • Loading branch information
mryzhov authored Sep 2, 2021
1 parent 6cbeb18 commit 2cf7065
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 1 deletion.
2 changes: 1 addition & 1 deletion inference-engine/src/gna_plugin/gna_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1612,7 +1612,7 @@ InferenceEngine::IExecutableNetworkInternal::Ptr GNAPlugin::ImportNetwork(std::i
// If scale factors are defined in configuration we still need to use them instead of imported values,
// for example to change the scale factors for the old models.
if (!config.inputScaleFactors.empty()) {
IE_ASSERT(config.inputScaleFactors.size() == inputsDesc->inputScaleFactors.size());
IE_ASSERT(config.inputScaleFactors.size() <= inputsDesc->inputScaleFactors.size());
for (size_t i = 0; i < config.inputScaleFactors.size(); ++i) {
if (config.inputScaleFactors[i] != GNAPluginNS::kScaleFactorDefault) {
gnalog() << "[Import Network] Using input scale factor defined in configuration for input " << i << std::endl;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include <memory>
#include <tuple>
#include <vector>
#include <string>
#include <fstream>

#include "ngraph_functions/builders.hpp"
#include "base/import_export_base/import_export_base.hpp"

namespace LayerTestsDefinitions {

class ImportMultiInput : public FuncTestUtils::ImportNetworkTestBase {
protected:
void SetUp() override {
InferenceEngine::Precision netPrecision;
std::tie(netPrecision, targetDevice, exportConfiguration, importConfiguration, applicationHeader) = this->GetParam();

auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto input = ngraph::builder::makeParams(ngPrc, {{1, 10}, {1, 10}});
auto mul1 = ngraph::builder::makeEltwise(input[0], input[1], ngraph::helpers::EltwiseTypes::ADD);
auto result = std::make_shared<ngraph::opset7::Result>(mul1);

function = std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, input, "multiple_input");
}
};

class ImportMultiInputChanged : public ImportMultiInput {};
class ImportMultiInputUnchanged : public ImportMultiInput {};

TEST_P(ImportMultiInputUnchanged, CompareWithRefImpl) {
TestRun(false);
};

TEST_P(ImportMultiInputChanged, CompareWithRefImpl) {
TestRun(true);
};

const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32
};

const std::vector<std::map<std::string, std::string>> exportConfigs = {
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_0", "327.67"},
{"GNA_SCALE_FACTOR_1", "327.67"}
}
};

const std::vector<std::map<std::string, std::string>> importConfigsChanged = {
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_0", "32767"}
},
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_1", "32767"}
},
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_0", "32767"},
{"GNA_SCALE_FACTOR_1", "32767"}
},
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_0", "1"},
{"GNA_SCALE_FACTOR_1", "32767"}
}
};

const std::vector<std::map<std::string, std::string>> importConfigsUnchanged = {
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_0", "327.67"}
},
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_0", "1"}
},
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"}
},
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_0", "327.67"},
{"GNA_SCALE_FACTOR_1", "327.67"}
},
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_1", "327.67"}
},
};

INSTANTIATE_TEST_CASE_P(smoke_ImportNetworkGNA, ImportMultiInputUnchanged,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(exportConfigs),
::testing::ValuesIn(importConfigsUnchanged),
::testing::Values("")),
ImportMultiInputUnchanged::getTestCaseName);

INSTANTIATE_TEST_CASE_P(smoke_ImportNetworkGNA, ImportMultiInputChanged,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(exportConfigs),
::testing::ValuesIn(importConfigsChanged),
::testing::Values("")),
ImportMultiInputChanged::getTestCaseName);

} // namespace LayerTestsDefinitions

0 comments on commit 2cf7065

Please sign in to comment.