Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
ndemashov committed Aug 2, 2021
1 parent 0cc7412 commit 8654a22
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,31 @@ MoveFakeQuantize::MoveFakeQuantize(const Params& params) : LayerTransformation(p

bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
auto fq = m.get_match_root();
auto relu = fq->get_input_node_shared_ptr(0);
auto concat = relu->get_input_node_shared_ptr(0);
auto result = *fq->output(0).get_target_inputs().begin();
auto input1 = concat->get_input_node_shared_ptr(0);
auto input2 = concat->get_input_node_shared_ptr(1);
auto relu1 = std::make_shared<ngraph::opset1::Relu>(input1->output(0));
auto relu2 = std::make_shared<ngraph::opset1::Relu>(input2->output(0));
auto fq1 = std::make_shared<opset1::FakeQuantize>(relu1,
auto operation = fq->get_input_node_shared_ptr(0);
auto type = operation->get_type_name();
std::shared_ptr<ngraph::Node> concat, fq1input, fq2input;
if (strcmp(type, "Concat") == 0) {
concat = operation;
fq1input = operation->get_input_node_shared_ptr(0);
fq2input = operation->get_input_node_shared_ptr(1);
}
else {
concat = operation->get_input_node_shared_ptr(0);
auto input1 = concat->get_input_node_shared_ptr(0);
auto input2 = concat->get_input_node_shared_ptr(1);
if (strcmp(type, "Relu") == 0) {
fq1input = std::make_shared<ngraph::opset1::Relu>(input1->output(0));
fq2input = std::make_shared<ngraph::opset1::Relu>(input2->output(0));
}
}
auto fq1 = std::make_shared<opset1::FakeQuantize>(fq1input,
fq->get_input_node_shared_ptr(1),
fq->get_input_node_shared_ptr(2),
fq->get_input_node_shared_ptr(3),
fq->get_input_node_shared_ptr(4),
as_type_ptr<opset1::FakeQuantize>(fq)->get_levels());
auto fq2 = std::make_shared<opset1::FakeQuantize>(relu2,
auto fq2 = std::make_shared<opset1::FakeQuantize>(fq2input,
fq->get_input_node_shared_ptr(1),
fq->get_input_node_shared_ptr(2),
fq->get_input_node_shared_ptr(3),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,51 +43,61 @@ namespace {

class MoveFakeQuantizeActualValues {
public:
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize1;
ngraph::builder::subgraph::DequantizationOperations::Convert convert1;
ngraph::builder::subgraph::DequantizationOperations dequantization1;
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize2;
ngraph::builder::subgraph::DequantizationOperations::Convert convert2;
ngraph::builder::subgraph::DequantizationOperations dequantization2;
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize3;
ngraph::builder::subgraph::DequantizationOperations::Convert convert3;
ngraph::builder::subgraph::DequantizationOperations dequantization3;
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore1; //before1
ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore1;
ngraph::builder::subgraph::DequantizationOperations dequantizationBefore1;
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore2; //before1
ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore2;
ngraph::builder::subgraph::DequantizationOperations dequantizationBefore2;
std::string operation;
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeAfter; // after
ngraph::builder::subgraph::DequantizationOperations::Convert convertAfter;
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
};

inline std::ostream& operator<<(std::ostream& out, const MoveFakeQuantizeActualValues& values) {
return out << "_" <<
values.fakeQuantize1 << "_" <<
values.convert1.outPrecision << "_" <<
values.dequantization1 << "_" <<
values.fakeQuantize2 << "_" <<
values.convert2.outPrecision << "_" <<
values.dequantization2;
values.fakeQuantizeBefore1 << "_" <<
values.convertBefore1.outPrecision << "_" <<
values.dequantizationBefore1 << "_" <<
values.fakeQuantizeBefore2 << "_" <<
values.convertBefore2.outPrecision << "_" <<
values.dequantizationBefore2 << "_" <<
values.operation << "_" <<
values.fakeQuantizeAfter << "_" <<
values.convertAfter.outPrecision << "_" <<
values.dequantizationAfter;
}

class MoveFakeQuantizeResultValues {
public:
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize1;
ngraph::builder::subgraph::DequantizationOperations::Convert convert1;
ngraph::builder::subgraph::DequantizationOperations dequantization1;
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize2;
ngraph::builder::subgraph::DequantizationOperations::Convert convert2;
ngraph::builder::subgraph::DequantizationOperations dequantization2;
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize3;
ngraph::builder::subgraph::DequantizationOperations::Convert convert3;
ngraph::builder::subgraph::DequantizationOperations dequantization3;
ngraph::element::Type precisionAfterOperation;
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore1;
ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore1;
ngraph::builder::subgraph::DequantizationOperations dequantizationBefore1;
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore2;
ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore2;
ngraph::builder::subgraph::DequantizationOperations dequantizationBefore2;
std::string operation;
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeAfter;
ngraph::builder::subgraph::DequantizationOperations::Convert convertAfter;
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
ngraph::element::Type precisionAfterOperation;
ngraph::builder::subgraph::DequantizationOperations dequantizationAfterNotFQ;
};

inline std::ostream& operator<<(std::ostream& out, const MoveFakeQuantizeResultValues& values) {
return out << "_" <<
values.fakeQuantize1 << "_" <<
values.convert1.outPrecision << "_" <<
values.dequantization1 << "_" <<
values.fakeQuantize2 << "_" <<
values.convert2.outPrecision << "_" <<
values.dequantization2 << "_" <<
values.dequantizationAfter;
values.fakeQuantizeBefore1 << "_" <<
values.convertBefore1.outPrecision << "_" <<
values.dequantizationBefore1 << "_" <<
values.fakeQuantizeBefore2 << "_" <<
values.convertBefore2.outPrecision << "_" <<
values.dequantizationBefore2 << "_" <<
values.operation << "_" <<
values.fakeQuantizeAfter << "_" <<
values.convertAfter << "_" <<
values.dequantizationAfter << "_" <<
values.dequantizationAfterNotFQ;
}

class MoveFakeQuantizeTestValues {
Expand Down Expand Up @@ -139,24 +149,25 @@ class MoveFakeQuantize : public LayerTransformation, public testing::WithParamIn

// dequantization output precision depends on input precision
// to avoid huge amount of tests cases let's define dequantization output precision as input precision
if (!testValues.actual.dequantization1.multiply.empty()) {
testValues.actual.dequantization1.multiply.outPrecision = precision;
if (!testValues.actual.dequantizationBefore1.multiply.empty()) {
testValues.actual.dequantizationBefore1.multiply.outPrecision = precision;
}
if (!testValues.actual.dequantization2.multiply.empty()) {
testValues.actual.dequantization2.multiply.outPrecision = precision;
if (!testValues.actual.dequantizationBefore2.multiply.empty()) {
testValues.actual.dequantizationBefore2.multiply.outPrecision = precision;
}
actualFunction = ngraph::builder::subgraph::MoveFakeQuantize::get(
precision,
shape,
testValues.actual.fakeQuantize1,
testValues.actual.convert1,
testValues.actual.dequantization1,
testValues.actual.fakeQuantize2,
testValues.actual.convert2,
testValues.actual.dequantization2,
testValues.actual.fakeQuantize3,
testValues.actual.convert3,
testValues.actual.dequantization3,
testValues.actual.fakeQuantizeBefore1,
testValues.actual.convertBefore1,
testValues.actual.dequantizationBefore1,
testValues.actual.fakeQuantizeBefore2,
testValues.actual.convertBefore2,
testValues.actual.dequantizationBefore2,
testValues.actual.operation,
testValues.actual.fakeQuantizeAfter,
testValues.actual.convertAfter,
testValues.actual.dequantizationAfter,
ngraph::element::undefined,
{},
testValues.axis);
Expand Down Expand Up @@ -196,15 +207,16 @@ class MoveFakeQuantize : public LayerTransformation, public testing::WithParamIn
referenceFunction = ngraph::builder::subgraph::MoveFakeQuantize::get(
precision,
shape,
testValues.result.fakeQuantize1,
testValues.result.convert1,
testValues.result.dequantization1,
testValues.result.fakeQuantize2,
testValues.result.convert2,
testValues.result.dequantization2,
testValues.result.fakeQuantize3,
testValues.result.convert3,
testValues.result.dequantization3,
testValues.result.fakeQuantizeBefore1,
testValues.result.convertBefore1,
testValues.result.dequantizationBefore1,
testValues.result.fakeQuantizeBefore2,
testValues.result.convertBefore2,
testValues.result.dequantizationBefore2,
testValues.result.operation,
testValues.result.fakeQuantizeAfter,
testValues.result.convertAfter,
testValues.result.dequantizationAfter,
testValues.result.precisionAfterOperation,
{},
testValues.axis);
Expand Down Expand Up @@ -270,6 +282,38 @@ const std::vector<MoveFakeQuantizeTestValues> testValues = {
{},
{},
{},
"",
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{},
{}
},
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{},
{},
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{},
{},
"",
{},
{},
{},
},
false,
false
},
{
LayerTransformation::createParamsU8I8(),
false,
1,
{
{},
{},
{},
{},
{},
{},
"relu",
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{},
{}
Expand All @@ -281,13 +325,14 @@ const std::vector<MoveFakeQuantizeTestValues> testValues = {
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{},
{},
"relu",
{},
{},
{},
},
false,
false
},
}
};
INSTANTIATE_TEST_SUITE_P(
smoke_LPT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class MoveFakeQuantize {
const FakeQuantizeOnDataWithConstant& fakeQuantize2,
const DequantizationOperations::Convert& convert2,
const DequantizationOperations& dequantization2,
const std::string operation,
const FakeQuantizeOnDataWithConstant& fakeQuantize3,
const DequantizationOperations::Convert& convert3,
const DequantizationOperations& dequantization3,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ std::shared_ptr<ngraph::Function> MoveFakeQuantize::get(
const FakeQuantizeOnDataWithConstant& fqOnData2,
const DequantizationOperations::Convert& convert2,
const DequantizationOperations& dequantization2,
const std::string operation,
const FakeQuantizeOnDataWithConstant& fqOnData3,
const DequantizationOperations::Convert& convert3,
const DequantizationOperations& dequantization3,
Expand All @@ -43,17 +44,25 @@ std::shared_ptr<ngraph::Function> MoveFakeQuantize::get(


if (fqOnData3.empty()) {
auto relu1 = std::make_shared<ngraph::opset1::Relu>(input1->output(0));
auto relu2 = std::make_shared<ngraph::opset1::Relu>(input2->output(0));
std::shared_ptr<Node> parent1 = makeFakeQuantizeTypeRelaxed(relu1, inputPrecision, fqOnData1);
std::shared_ptr<Node> parent1, parent2;
if (operation == "relu") {
auto relu1 = std::make_shared<ngraph::opset1::Relu>(input1->output(0));
auto relu2 = std::make_shared<ngraph::opset1::Relu>(input2->output(0));
parent1 = makeFakeQuantizeTypeRelaxed(relu1, inputPrecision, fqOnData1);
parent2 = makeFakeQuantizeTypeRelaxed(relu2, inputPrecision, fqOnData2);
}
else {
parent1 = makeFakeQuantizeTypeRelaxed(input1, inputPrecision, fqOnData1);
parent2 = makeFakeQuantizeTypeRelaxed(input1, inputPrecision, fqOnData2);
}

if (!convert1.empty()) {
parent1 = std::make_shared<opset1::Convert>(parent1, convert1.outPrecision);
}
if (!dequantization1.empty()) {
parent1 = makeDequantization(parent1, dequantization1);
}

std::shared_ptr<Node> parent2 = makeFakeQuantizeTypeRelaxed(relu2, inputPrecision, fqOnData2);
if (!convert2.empty()) {
parent2 = std::make_shared<opset1::Convert>(parent2, convert2.outPrecision);
}
Expand Down Expand Up @@ -81,8 +90,14 @@ std::shared_ptr<ngraph::Function> MoveFakeQuantize::get(
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");

concat->set_friendly_name("output");
auto relu = std::make_shared<ngraph::opset1::Relu>(concat->output(0));
std::shared_ptr<Node> fq = makeFakeQuantize(relu, inputPrecision, fqOnData3);
std::shared_ptr<Node> fq;
if (operation == "relu") {
auto relu = std::make_shared<ngraph::opset1::Relu>(concat->output(0));
fq = makeFakeQuantize(relu, inputPrecision, fqOnData3);
}
else {
fq = makeFakeQuantize(concat, inputPrecision, fqOnData3);
}

ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(fq) };
std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
Expand Down

0 comments on commit 8654a22

Please sign in to comment.