Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Sep 15, 2023
1 parent 91887f3 commit f1afdfc
Show file tree
Hide file tree
Showing 5 changed files with 356 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/common/low_precision_transformations/src/multiply.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down Expand Up @@ -44,7 +44,6 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass:
return false;
}

// TODO: normalizeDequantization + fold_fake_quantizes + foldDequantization <= ???
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0));
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "simple_low_precision_transformer.hpp"
#include "lpt_ngraph_functions/multiply_function.hpp"

#include "ngraph/pass/serialize.hpp"

namespace {
using namespace testing;
using namespace ov;
Expand Down Expand Up @@ -109,13 +111,22 @@ class MultiplyTransformation : public LayerTransformation, public testing::WithP
model_precision,
to_multiply_values(testParams.actual));

ngraph::pass::Serialize("svg/test.original.xml", "svg/test.original.bin").run_on_model(actualFunction);
ngraph::pass::VisualizeTree("svg/test.original.svg").run_on_model(actualFunction);

SimpleLowPrecisionTransformer transform;
transform.add<ov::pass::low_precision::MultiplyTransformation, ov::op::v1::Multiply>(testParams.transformationParams);
transform.transform(actualFunction);

ngraph::pass::Serialize("svg/test.actual.xml", "svg/test.actual.bin").run_on_model(actualFunction);
ngraph::pass::VisualizeTree("svg/test.actual.svg").run_on_model(actualFunction);

referenceFunction = MultiplyFunction::get(
model_precision,
to_multiply_values(testParams.expected));

ngraph::pass::Serialize("svg/test.reference.xml", "svg/test.reference.bin").run_on_model(referenceFunction);
ngraph::pass::VisualizeTree("svg/test.reference.svg").run_on_model(referenceFunction);
}

static std::string getTestCaseName(testing::TestParamInfo<MultiplyTransformationParams> obj) {
Expand Down Expand Up @@ -192,7 +203,7 @@ const std::vector<std::pair<ov::element::Type, ov::element::Type>> input_precisi
{ ov::element::i8, ov::element::u8 },
};

// PartialShape inputShape;
namespace broadcast_no {
const std::vector<std::pair<PartialShape, PartialShape>> input_shapes = {
{{ 1, 3, 8, 16 }, { 1, 3, 8, 16 }},
{{ 1, 3, 8, 16 }, { 1, 3, 1, 1 }},
Expand Down Expand Up @@ -363,4 +374,332 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(input_precisions),
::testing::ValuesIn(multiplyTransformationTestValues)),
MultiplyTransformation::getTestCaseName);
} // namespace
} // namespace broadcast_no

namespace broadcast_right {
const std::vector<std::pair<PartialShape, PartialShape>> input_shapes = {
{{ 1, 3, 8, 16 }, { 1, 1, 1, 1 }}
};

const std::vector<MultiplyTransformationTestValues> multiplyTransformationTestValues = {
{
LayerTransformation::createParamsU8I8(),
{
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, { 2.f }, { 10.f }}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, { 3.f }, { 7.f }}
},
},
{
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{ 2.f }, ov::element::f32}, {}}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{ 3.f }, ov::element::f32}, {}}
},
{{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}}
}
},

{
LayerTransformation::createParamsU8I8(),
{
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, {}, { 10.f }}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, {}, { 7.f }}
}
},
{
{
{},
MultiplyTransformationTestValues::input_precision,
{}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{}
},
{{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}}
}
},

{
LayerTransformation::createParamsU8I8(),
{
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, {{ 1.f, 2.f, 3.f }}, {{ 10.f, 11.f, 12.f }}}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, { 3.f }, { 7.f }}
}
},
{
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{ 1.f, 2.f, 3.f }, ov::element::f32}, {}}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{ 3.f }, ov::element::f32}, {}}
},
{{}, {}, {{70.f, 77.f, 84.f}, MultiplyTransformationTestValues::model_precision}}
}
},

{
LayerTransformation::createParamsU8I8(),
{
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, { 2.f }, { 10.f }}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, {}, { 7.f }}
}
},
{
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{2.f}, ov::element::f32}, {}}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{}
},
{{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}}
}
},

{
LayerTransformation::createParamsU8I8(),
{
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, {}, { 10.f }}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, { 3.f }, { 7.f }}
}
},
{
{
{},
MultiplyTransformationTestValues::input_precision,
{}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{3.f}, ov::element::f32}, {}}
},
{{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}}
}
},
};

INSTANTIATE_TEST_SUITE_P(
smoke_LPT,
MultiplyTransformation,
::testing::Combine(
::testing::ValuesIn(model_precisions),
::testing::ValuesIn(input_shapes),
::testing::ValuesIn(input_precisions),
::testing::ValuesIn(multiplyTransformationTestValues)),
MultiplyTransformation::getTestCaseName);
} // namespace broadcast_right

namespace broadcast_left {
const std::vector<std::pair<PartialShape, PartialShape>> input_shapes = {
{{ 1, 1, 1, 1 }, { 1, 3, 8, 16 }}
};

const std::vector<MultiplyTransformationTestValues> multiplyTransformationTestValues = {
{
LayerTransformation::createParamsU8I8(),
{
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, { 2.f }, { 10.f }}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, { 3.f }, { 7.f }}
},
},
{
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{ 2.f }, ov::element::f32}, {}}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{ 3.f }, ov::element::f32}, {}}
},
{{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}}
}
},

{
LayerTransformation::createParamsU8I8(),
{
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, {}, { 10.f }}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, {}, { 7.f }}
}
},
{
{
{},
MultiplyTransformationTestValues::input_precision,
{}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{}
},
{{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}}
}
},

{
LayerTransformation::createParamsU8I8(),
{
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, { 2.f }, { 10.f }}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, {{ 3.f, 4.f, 5.f }}, {{ 7.f, 8.f, 9.f }}}
}
},
{
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{ 2.f }, ov::element::f32}, {}}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{ 3.f, 4.f, 5.f }, ov::element::f32}, {}}
},
{{}, {}, {{70.f, 80.f, 90.f}, MultiplyTransformationTestValues::model_precision}}
}
},

{
LayerTransformation::createParamsU8I8(),
{
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, { 2.f }, { 10.f }}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, {}, { 7.f }}
}
},
{
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{2.f}, ov::element::f32}, {}}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{}
},
{{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}}
}
},

{
LayerTransformation::createParamsU8I8(),
{
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, {}, { 10.f }}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{ov::element::f32, { 3.f }, { 7.f }}
}
},
{
{
{},
MultiplyTransformationTestValues::input_precision,
{}
},
{
{},
MultiplyTransformationTestValues::input_precision,
{{}, {{3.f}, ov::element::f32}, {}}
},
{{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}}
}
},
};

INSTANTIATE_TEST_SUITE_P(
smoke_LPT,
MultiplyTransformation,
::testing::Combine(
::testing::ValuesIn(model_precisions),
::testing::ValuesIn(input_shapes),
::testing::ValuesIn(input_precisions),
::testing::ValuesIn(multiplyTransformationTestValues)),
MultiplyTransformation::getTestCaseName);
} // namespace broadcast_left

} // namespace
Loading

0 comments on commit f1afdfc

Please sign in to comment.