[LPT] EliminateFakeQuantizeTransformation: bf16 support (openvinotoolkit#24755)

### Details:
 - *[LPT] EliminateFakeQuantizeTransformation: bf16 support*

### Tickets:
 - *CVS-138910*
eshoguli authored May 29, 2024
1 parent b16108d commit 45e219b
Showing 4 changed files with 68 additions and 35 deletions.
@@ -95,6 +95,7 @@ class LP_TRANSFORMATIONS_API DataPrecision {
}
}

// the lowest value (e.g., for signed symmetric types: -max)
static float getMinValue(const element::Type precision, const size_t levels) {
switch (precision) {
case element::u4:
@@ -134,6 +135,8 @@ class LP_TRANSFORMATIONS_API DataPrecision {
break;
case element::f16:
return -1.0e15f;
case element::bf16:
return -3.38953139e38f;
case element::f32:
return std::numeric_limits<float>::lowest();
default:
@@ -172,6 +175,8 @@ class LP_TRANSFORMATIONS_API DataPrecision {
return 2147483648.f; // 2147483648.f == 2147483647.f
case element::f16:
return 1.0e15f;
case element::bf16:
return 3.38953139e38f;
case element::f32:
return std::numeric_limits<float>::max();
default:
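Context for the new bf16 constants, not part of the diff: bf16 keeps float32's 8-bit exponent but only 7 mantissa bits, so its largest finite value is (2 - 2^-7) * 2^127 ≈ 3.38953139e38, and getMinValue returns its negation. A standalone C++ sketch that reproduces the constant from the bit pattern:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    // Largest finite bf16: sign 0, exponent 0xFE, 7-bit mantissa 0x7F.
    // Widened to a float32 bit pattern, the low 16 bits are zero: 0x7F7F0000.
    const uint32_t max_bits = 0x7F7F0000u;
    float max_bf16;
    std::memcpy(&max_bf16, &max_bits, sizeof(max_bf16));
    std::printf("%.8e\n", max_bf16);  // prints 3.38953139e+38, matching getMaxValue(bf16)
    return 0;
}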
@@ -93,9 +93,13 @@ bool check_interval(const std::shared_ptr<ov::opset1::FakeQuantize>& fq,
bool check_intervals(const std::shared_ptr<ov::opset1::FakeQuantize>& fakeQuantize) {
const auto& element_type = fakeQuantize->get_output_element_type(0);
const auto levels = fakeQuantize->get_levels();
if (levels == 0) {
return false;
}
const auto min_value = DataPrecision::getMinValue(element_type, levels);
const auto max_value = DataPrecision::getMaxValue(element_type, levels);
const auto max_diff = (max_value - min_value) / levels;
// divide first to avoid overflow
const auto max_diff = max_value / levels - min_value / levels;
// input intervals may differ from the type intervals only for low precision
const auto exact_comparison = !element_type.is_integral();

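Why the new code divides before subtracting, illustration only: for bf16 the span max_value - min_value is about 6.78e38, which exceeds FLT_MAX (about 3.40e38), so the subtraction overflows to +inf in float arithmetic; dividing each bound by levels first keeps every intermediate finite. A standalone demonstration:

#include <cmath>
#include <cstdio>

int main() {
    const float max_value = 3.38953139e38f;   // DataPrecision::getMaxValue for bf16
    const float min_value = -3.38953139e38f;  // DataPrecision::getMinValue for bf16
    const float levels = 256.f;

    const float subtract_first = (max_value - min_value) / levels;       // overflows to inf
    const float divide_first = max_value / levels - min_value / levels;  // about 2.65e36, finite

    std::printf("subtract first: %g (isinf=%d)\n", subtract_first, static_cast<int>(std::isinf(subtract_first)));
    std::printf("divide first:   %g\n", divide_first);
    return 0;
}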
@@ -25,14 +25,12 @@ class TransformationTestValues {
public:
class Actual {
public:
ov::element::Type precisionBefore;
ov::builder::subgraph::FakeQuantizeOnData fakeQuantizeOnData1;
ov::builder::subgraph::FakeQuantizeOnData fakeQuantizeOnData2;
};

class Expected {
public:
ov::element::Type precisionBefore;
ov::builder::subgraph::FakeQuantizeOnData fakeQuantizeOnData1;
ov::builder::subgraph::FakeQuantizeOnData fakeQuantizeOnData2;
ov::builder::subgraph::DequantizationOperations dequantizationOperations2;
@@ -44,17 +42,28 @@ class TransformationTestValues {
Expected expected;
};

typedef std::tuple<
ov::element::Type,
TransformationTestValues
> EliminateFakeQuantizeTransformationParams;

class EliminateFakeQuantizeTransformation : public LayerTransformation,
public testing::WithParamInterface<TransformationTestValues> {
public testing::WithParamInterface<EliminateFakeQuantizeTransformationParams> {
public:
void SetUp() override {
const TransformationTestValues testValues = GetParam();
const ov::element::Type execPrecision = std::get<0>(GetParam());
TransformationTestValues testValues = std::get<1>(GetParam());

if (!testValues.expected.dequantizationOperations2.multiply.empty()) {
testValues.expected.dequantizationOperations2.multiply.outPrecision = execPrecision;
}

actualFunction = ov::builder::subgraph::FuseFakeQuantizeFunction::get(testValues.inputShape,
testValues.actual.precisionBefore,
testValues.actual.fakeQuantizeOnData1,
testValues.actual.fakeQuantizeOnData2,
{});
execPrecision,
testValues.actual.fakeQuantizeOnData1,
testValues.actual.fakeQuantizeOnData2,
{});

SimpleLowPrecisionTransformer transformer;
transformer.add<ov::pass::low_precision::FakeQuantizeDecompositionTransformation, ov::op::v0::FakeQuantize>(
testValues.params);
@@ -67,20 +76,28 @@ class EliminateFakeQuantizeTransformation : public LayerTransformation,

referenceFunction =
ov::builder::subgraph::FuseFakeQuantizeFunction::get(testValues.inputShape,
testValues.expected.precisionBefore,
testValues.expected.fakeQuantizeOnData1,
testValues.expected.fakeQuantizeOnData2,
testValues.expected.dequantizationOperations2);
execPrecision,
testValues.expected.fakeQuantizeOnData1,
testValues.expected.fakeQuantizeOnData2,
testValues.expected.dequantizationOperations2);

}

static std::string getTestCaseName(testing::TestParamInfo<TransformationTestValues> obj) {
const TransformationTestValues testValues = obj.param;
static std::string getTestCaseName(testing::TestParamInfo<EliminateFakeQuantizeTransformationParams> obj) {
const ov::element::Type execPrecision = std::get<0>(obj.param);
TransformationTestValues testValues = std::get<1>(obj.param);

if (!testValues.expected.dequantizationOperations2.multiply.empty()) {
testValues.expected.dequantizationOperations2.multiply.outPrecision = execPrecision;
}

std::ostringstream result;
result << testValues.inputShape << "_" << testValues.params.updatePrecisions << "_"
<< testValues.actual.precisionBefore << "_" << testValues.actual.fakeQuantizeOnData1 << "_"
<< testValues.actual.fakeQuantizeOnData2 << "_" << testValues.expected.precisionBefore << "_"
<< testValues.expected.fakeQuantizeOnData1 << "_" << testValues.expected.fakeQuantizeOnData2 << "_"
<< execPrecision << "_"
<< testValues.actual.fakeQuantizeOnData1 << "_"
<< testValues.actual.fakeQuantizeOnData2 << "_"
<< testValues.expected.fakeQuantizeOnData1 << "_"
<< testValues.expected.fakeQuantizeOnData2 << "_"
<< testValues.expected.dequantizationOperations2;
return result.str();
}
@@ -100,12 +117,10 @@ const std::vector<TransformationTestValues> testValues = {
{1, 3, 16, 16},
TestTransformationParams(true, {ov::element::u8}, {ov::element::i8}),
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}
},
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, element::u8},
{},
{ ov::element::f32, {}, {{0.01f}, ov::element::f32, {}} }
@@ -115,12 +130,10 @@ const std::vector<TransformationTestValues> testValues = {
{1, 3, 16, 16},
TestTransformationParams(true, {ov::element::u8}, {ov::element::i8}),
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
{256ul, {}, {0.f}, {2.549f}, {0.f}, {2.55f}}
},
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, element::u8},
{},
{ ov::element::f32, {}, {{0.01f}, ov::element::f32, {}} }
@@ -130,27 +143,35 @@ const std::vector<TransformationTestValues> testValues = {
{1, 3, 16, 16},
TestTransformationParams(true, {ov::element::u8}, {ov::element::i8}),
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f / 2.f}}
},
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, element::u8},
{},
{ ov::element::f32, {}, {{0.005f}, ov::element::f32, {}} }
}
},
}
};
// clang-format on

INSTANTIATE_TEST_SUITE_P(smoke_LPT,
EliminateFakeQuantizeTransformation,
::testing::Combine(
::testing::ValuesIn({ov::element::f32, ov::element::bf16}),
::testing::ValuesIn(testValues)),
EliminateFakeQuantizeTransformation::getTestCaseName);

// clang-format off
const std::vector<TransformationTestValues> testValuesDiffFq = {
{
{1, 3, 16, 16},
TestTransformationParams(true, {ov::element::u8}, {ov::element::i8}),
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
{256ul, {}, {0.f}, {2.55f / 2.f}, {0.f}, {2.55f / 2.f}}
},
{
element::f32,
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, element::u8},
{256ul, {}, {0.f}, {127.5f}, {0.f}, {255.f}, element::u8},
{ ov::element::f32, {}, {{0.005f}, ov::element::f32, {}} }
@@ -159,9 +180,11 @@ const std::vector<TransformationTestValues> testValues = {
};
// clang-format on

INSTANTIATE_TEST_SUITE_P(smoke_LPT,
INSTANTIATE_TEST_SUITE_P(smoke_LPT_DiffFq,
EliminateFakeQuantizeTransformation,
::testing::ValuesIn(testValues),
::testing::Combine(
::testing::ValuesIn({ov::element::f32}),
::testing::ValuesIn(testValuesDiffFq)),
EliminateFakeQuantizeTransformation::getTestCaseName);

} // namespace
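A note on the parameterization change above: switching to ::testing::Combine makes the suite a cross product, so every testValues entry now runs once per execution precision, and GetParam() yields a std::tuple unpacked with std::get. A minimal self-contained sketch of the same idiom; all names below are illustrative, not from this PR:

#include <gtest/gtest.h>

#include <string>
#include <tuple>
#include <vector>

// Illustrative stand-in for TransformationTestValues.
struct CaseValues {
    int levels;
};

using ComboParams = std::tuple<std::string /*precision*/, CaseValues>;

class ComboTest : public testing::TestWithParam<ComboParams> {};

TEST_P(ComboTest, UnpacksTuple) {
    const std::string precision = std::get<0>(GetParam());
    const CaseValues values = std::get<1>(GetParam());
    EXPECT_FALSE(precision.empty());
    EXPECT_GT(values.levels, 0);
}

// 2 precisions x 2 cases = 4 test instances.
INSTANTIATE_TEST_SUITE_P(
    smoke_Example,
    ComboTest,
    ::testing::Combine(::testing::Values(std::string("f32"), std::string("bf16")),
                       ::testing::ValuesIn(std::vector<CaseValues>{{256}, {16}})),
    [](const testing::TestParamInfo<ComboParams>& info) {
        return std::get<0>(info.param) + "_" + std::to_string(std::get<1>(info.param).levels);
    });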
13 changes: 7 additions & 6 deletions src/tests/ov_helpers/ov_lpt_models/src/fuse_fake_quantize.cpp
@@ -49,23 +49,24 @@ std::shared_ptr<ov::Model> FuseFakeQuantizeFunction::getOriginal(
namespace {
std::shared_ptr<ov::opset1::Convolution> make_convolution(
const ov::PartialShape& inputShape,
const ov::element::Type precisionBefore,
const ov::element::Type precisionData,
const ov::element::Type precisionWeights,
const std::shared_ptr<Node>& parent,
const size_t index) {
const ov::Shape shape = inputShape.to_shape();
const ov::Shape weightsShape({ shape[1], shape[1], 1ull, 1ull });
auto weightsConstant = std::make_shared<ov::op::v0::Constant>(ov::element::f32, weightsShape, std::vector<float>(9, 1.f));
auto weightsConstant = std::make_shared<ov::op::v0::Constant>(precisionWeights, weightsShape, std::vector<float>(9, 1.f));
auto weights = makeFakeQuantize(
weightsConstant,
precisionBefore,
precisionData,
FakeQuantizeOnData(
255,
ov::Shape({ shape[1], 1ull, 1ull, 1ull }),
{ -1.27f, -1.27f, -1.27f },
{ 1.28f, 1.28f, 1.28f },
{ -1.27f, -1.27f, -1.27f },
{ 1.28f, 1.28f, 1.28f },
precisionBefore));
precisionData));

auto convolution = std::make_shared<ov::opset1::Convolution>(
parent,
@@ -160,8 +161,8 @@ std::shared_ptr<ov::Model> FuseFakeQuantizeFunction::get(
}

ov::ResultVector results{
std::make_shared<ov::opset1::Result>(make_convolution(inputShape, precisionBefore, parent, 0)),
std::make_shared<ov::opset1::Result>(make_convolution(inputShape, precisionBefore, parent, 1))
std::make_shared<ov::opset1::Result>(make_convolution(inputShape, precisionBefore, precisionBefore, parent, 0)),
std::make_shared<ov::opset1::Result>(make_convolution(inputShape, precisionBefore, precisionBefore, parent, 1))
};
return std::make_shared<ov::Model>(results, ov::ParameterVector{ input }, "FuseFakeQuantizeFunction");
}
