Skip to content

Commit

Permalink
[LPT][TESTS] Concat with split tests: added verification of output names
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Apr 30, 2021
1 parent 7a4caa9 commit 5e69f3f
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,23 @@ inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTes

typedef std::tuple <
ngraph::element::Type,
ConcatTransformationTestValues
ConcatTransformationTestValues,
bool // additional Convolution after Split
> ConcatTransformationParams;

class ConcatWithSplitTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
public:
void SetUp() override {
const ngraph::element::Type precision = std::get<0>(GetParam());
ConcatTransformationTestValues testValues = std::get<1>(GetParam());
const ConcatTransformationTestValues testValues = std::get<1>(GetParam());
const bool addConvolution = std::get<2>(GetParam());

actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginalWithSplitedIntermediate(
precision,
testValues.inputShape,
testValues.actual.fakeQuantize1,
testValues.actual.fakeQuantize2);
testValues.actual.fakeQuantize2,
addConvolution);

SimpleLowPrecisionTransformer transform;
if (testValues.multiChannels) {
Expand All @@ -107,18 +110,21 @@ class ConcatWithSplitTransformation : public LayerTransformation, public testing
testValues.result.dequantizationBefore1,
testValues.result.dequantizationBefore2,
testValues.result.precisionAfterOperation,
addConvolution,
testValues.result.dequantizationOperations1,
testValues.result.dequantizationOperations2);
}

static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
const ngraph::element::Type precision = std::get<0>(obj.param);
const ConcatTransformationTestValues testValues = std::get<1>(obj.param);
const bool addConvolution = std::get<2>(obj.param);

std::ostringstream result;
result <<
LayerTransformation::getTestCaseNameByParams(precision, testValues.inputShape, testValues.params) << "_" <<
(testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
(addConvolution ? "" : "without_convolution_") <<
testValues.actual << "_" <<
testValues.result << "_";
return result.str();
Expand All @@ -127,7 +133,7 @@ class ConcatWithSplitTransformation : public LayerTransformation, public testing

TEST_P(ConcatWithSplitTransformation, CompareFunctions) {
actualFunction->validate_nodes_and_infer_types();
auto res = compare_functions(referenceFunction, actualFunction, true);
auto res = compare_functions(referenceFunction, actualFunction, true, true);
ASSERT_TRUE(res.first) << res.second;
}

Expand All @@ -136,6 +142,7 @@ const std::vector<ngraph::element::Type> precisions = {
// ngraph::element::f16
};

namespace casesWithConvolution {
const std::vector<ConcatTransformationTestValues> testValues = {
// U8: concat
{
Expand Down Expand Up @@ -298,6 +305,43 @@ INSTANTIATE_TEST_CASE_P(
ConcatWithSplitTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),
::testing::ValuesIn(testValues)),
::testing::ValuesIn(testValues),
::testing::Values(true)),
ConcatWithSplitTransformation::getTestCaseName);
} // namespace casesWithConvolution

// test cases to check output names
namespace casesWithoutConvolution {
const std::vector<ConcatTransformationTestValues> testValues = {
{
{ 1, 6, 10, 10 },
LayerTransformation::createParamsU8I8(),
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f}},
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f / 2.f}, {0.f}, { 255.f}},
ngraph::element::u8,
{{}, {}, {}},
{{}, {}, {}},
ngraph::element::u8,
{ ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
{ ngraph::element::f32, {}, { 0.005f } }
}
},
};

INSTANTIATE_TEST_CASE_P(
smoke_LPT,
ConcatWithSplitTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),
::testing::ValuesIn(testValues),
::testing::Values(false)),
ConcatWithSplitTransformation::getTestCaseName);
} // namespace casesWithoutConvolution

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ void ConcatWithSplitTransformation::SetUp() {
netPrecision,
inputShapes,
param.fqOnData1,
param.fqOnData2);
param.fqOnData2,
true);
}

TEST_P(ConcatWithSplitTransformation, CompareWithRefImpl) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class ConcatFunction {
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2);
const FakeQuantizeOnData& fqOnData2,
const bool addConvolution);

static std::shared_ptr<ngraph::Function> getOriginalSelectionWithIntermediate(
const ngraph::element::Type precision,
Expand Down Expand Up @@ -151,6 +152,7 @@ class ConcatFunction {
const DequantizationOperations& dequantizationBefore1,
const DequantizationOperations& dequantizationBefore2,
const ngraph::element::Type precisionAfterOperation,
const bool addConvolution,
const DequantizationOperations& dequantizationOperations1,
const DequantizationOperations& dequantizationOperations2);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithSplitedIntermed
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2) {
const FakeQuantizeOnData& fqOnData2,
const bool addConvolution) {
size_t numSplit = 2;
size_t splitedAxis = 1;

Expand Down Expand Up @@ -272,24 +273,28 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithSplitedIntermed

const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
ngraph::OutputVector{ fakeQuantize1->output(0), intermediateOp->output(0) }, splitedAxis);
concat->set_friendly_name("concat");
concat->set_friendly_name("output_1");

auto& rtInfo = concat->get_rt_info();
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");

auto weights = ngraph::opset1::Constant::create(precision, ngraph::Shape{ inputShape[1] / numSplit, inputShape[1] / numSplit, 1, 1 }, { 1 });
auto convolution = std::make_shared<ngraph::opset1::Convolution>(
intermediateOp->output(1),
weights,
ngraph::Strides{ 1, 1 },
ngraph::CoordinateDiff{ 0, 0 },
ngraph::CoordinateDiff{ 0, 0 },
ngraph::Strides{ 1, 1 });
convolution->set_friendly_name("convolution");
Output<Node> lastOutput = intermediateOp->output(1);
if (addConvolution) {
auto weights = ngraph::opset1::Constant::create(precision, ngraph::Shape{ inputShape[1] / numSplit, inputShape[1] / numSplit, 1, 1 }, { 1 });
auto convolution = std::make_shared<ngraph::opset1::Convolution>(
intermediateOp->output(1),
weights,
ngraph::Strides{ 1, 1 },
ngraph::CoordinateDiff{ 0, 0 },
ngraph::CoordinateDiff{ 0, 0 },
ngraph::Strides{ 1, 1 });
lastOutput = convolution->output(0);
}
lastOutput.get_node_shared_ptr()->set_friendly_name("output_2");

ngraph::ResultVector results{
std::make_shared<ngraph::opset1::Result>(concat),
std::make_shared<ngraph::opset1::Result>(convolution),
std::make_shared<ngraph::opset1::Result>(lastOutput),
};

std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
Expand Down Expand Up @@ -964,6 +969,7 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithSplitedInterme
const DequantizationOperations& dequantizationBefore1,
const DequantizationOperations& dequantizationBefore2,
const ngraph::element::Type precisionAfterOperation,
const bool addConvolution,
const DequantizationOperations& dequantizationOperations1,
const DequantizationOperations& dequantizationOperations2) {
size_t numSplit = 2;
Expand Down Expand Up @@ -1005,7 +1011,6 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithSplitedInterme

const auto constant = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ }, splitedAxis);
intermediateOp = std::make_shared<ngraph::opset1::Split>(deqBefore2, constant, numSplit);

intermediateOp->set_friendly_name("intermediate");

const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
Expand All @@ -1017,23 +1022,30 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithSplitedInterme

const auto lastDequantization1 = makeDequantization(concat, dequantizationOperations1);
const auto lastDequantization2 = makeDequantization(intermediateOp->output(1), dequantizationOperations2);
lastDequantization1->set_friendly_name("output_1");

auto weights = ngraph::opset1::Constant::create(
precision,
ngraph::Shape{ inputShape[1] / numSplit, inputShape[1] / numSplit, 1, 1 }, { 1 });
Output<Node> lastOutput = lastDequantization2;
if (addConvolution) {
auto weights = ngraph::opset1::Constant::create(
precision,
ngraph::Shape{ inputShape[1] / numSplit, inputShape[1] / numSplit, 1, 1 }, { 1 });

auto convolution = std::make_shared<ngraph::opset1::Convolution>(
lastDequantization2,
weights,
ngraph::Strides{ 1, 1 },
ngraph::CoordinateDiff{ 0, 0 },
ngraph::CoordinateDiff{ 0, 0 },
ngraph::Strides{ 1, 1 });
convolution->set_friendly_name("convolution");
auto convolution = std::make_shared<ngraph::opset1::Convolution>(
lastDequantization2,
weights,
ngraph::Strides{ 1, 1 },
ngraph::CoordinateDiff{ 0, 0 },
ngraph::CoordinateDiff{ 0, 0 },
ngraph::Strides{ 1, 1 });
convolution->set_friendly_name("output_2");
lastOutput = convolution->output(0);
} else {
lastOutput.get_node_shared_ptr()->set_friendly_name("output_2.1");
}

ngraph::ResultVector results{
std::make_shared<ngraph::opset1::Result>(lastDequantization1),
std::make_shared<ngraph::opset1::Result>(convolution)
std::make_shared<ngraph::opset1::Result>(lastOutput)
};

std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
Expand Down

0 comments on commit 5e69f3f

Please sign in to comment.