Skip to content

Commit

Permalink
[LPT] separateInStandaloneBranch fix (openvinotoolkit#26333)
Browse files Browse the repository at this point in the history
### Details:
- *Added `normalizeDequantization` in `separateInStandaloneBranch`: it
is needed since `separateInStandaloneBranch` works correctly only with
normalized dequantization*
- *Handle all inputs (not only 0's as before) in
`separateInStandaloneBranch`*
- *Removed `normalizeDequantization` in transformations since this logic
is placed inside `separateInStandaloneBranch` now*

### Tickets:
 - *CVS-150001*
  • Loading branch information
v-Golubev authored Sep 5, 2024
1 parent 20ee134 commit 8a604a6
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 31 deletions.
3 changes: 0 additions & 3 deletions src/common/low_precision_transformations/src/add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ bool AddTransformation::transform(TransformationContext& context, ov::pass::patt
return false;
}

NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(op, defaultPrecisions, 0));
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(op, defaultPrecisions, 1));

std::shared_ptr<Node> addNode = NetworkHelper::separateInStandaloneBranch(op, defaultPrecisions);
std::shared_ptr<ov::opset1::Add> add = ov::as_type_ptr<ov::opset1::Add>(addNode);

Expand Down
4 changes: 2 additions & 2 deletions src/common/low_precision_transformations/src/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ ConcatTransformation::ConcatTransformation(const Params& params) : LayerTransfor
}

bool ConcatTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher &m) {
std::shared_ptr<ov::opset1::Concat> concat = ov::as_type_ptr<ov::opset1::Concat>(m.get_match_root());
if (!canBeTransformed(context, concat)) {
if (!canBeTransformed(context, m.get_match_root())) {
return false;
}

const auto concat = ov::as_type_ptr<ov::opset1::Concat>(NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions));
std::vector<FakeQuantizeDequantization> layerDequantizations;
layerDequantizations.reserve(concat->get_input_size());
for (size_t parentIndex = 0ul; parentIndex < concat->get_input_size(); parentIndex++) {
Expand Down
3 changes: 0 additions & 3 deletions src/common/low_precision_transformations/src/multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass:
multiply = NetworkHelper::separateInStandaloneBranch(multiply, defaultPrecisions);
decomposeFakeQuantizeForWeightsPath(multiply);

NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0));
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1));

const auto dequantization1 = NetworkHelper::getDequantization(multiply, defaultPrecisions, 0);
const auto dequantization2 = NetworkHelper::getDequantization(multiply, defaultPrecisions, 1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ bool MultiplyPartialTransformation::transform(TransformationContext& context, ov
return false;
}

NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0));
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1));

multiply = NetworkHelper::separateInStandaloneBranch(multiply, defaultPrecisions);
auto newMultiply = multiply;

Expand Down
6 changes: 1 addition & 5 deletions src/common/low_precision_transformations/src/mvn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,7 @@ bool MVNTransformation::transform(TransformationContext &context, ov::pass::patt
return false;
}

std::shared_ptr<Node> mvn = ov::as_type_ptr<op::v0::MVN>(operation);
if (!mvn) {
mvn = ov::as_type_ptr<opset6::MVN>(operation);
}

const auto mvn = NetworkHelper::separateInStandaloneBranch(operation, defaultPrecisions);
bool normalizeVariance;
if (ov::is_type<op::v0::MVN>(mvn)) {
normalizeVariance = ov::as_type_ptr<op::v0::MVN>(mvn)->get_normalize_variance();
Expand Down
29 changes: 19 additions & 10 deletions src/common/low_precision_transformations/src/network_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,8 +585,12 @@ FakeQuantizeDequantization NetworkHelper::foldDequantization(const std::shared_p

std::shared_ptr<ov::Node> NetworkHelper::separateInStandaloneBranch(std::shared_ptr<ov::Node> node,
const std::vector<ov::element::Type>& defaultPrecisions) {
FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(node, defaultPrecisions);
if (dequantization.isShared() && !dequantization.empty()) {
auto inputs = node->input_values();
auto separate_branch = [&](size_t input_idx) {
const auto dequantization = NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(node, defaultPrecisions, input_idx));
if (dequantization.empty() || !dequantization.isShared())
return false;

Output<Node> parent = dequantization.data;
if (dequantization.convert != nullptr) {
auto convert = dequantization.convert->clone_with_new_inputs({ parent });
Expand Down Expand Up @@ -619,22 +623,27 @@ std::shared_ptr<ov::Node> NetworkHelper::separateInStandaloneBranch(std::shared_
parent = multiply->output(0);
}

std::vector<Output<Node>> inputs = node->input_values();
const auto originalParent = dequantization.multiply ?
dequantization.multiply->shared_from_this() :
dequantization.subtract->shared_from_this();

const size_t inputIndex = NetworkHelper::getChildInputIndex(originalParent, node);
inputs[inputIndex] = parent;
const std::shared_ptr<Node> newNode = node->clone_with_new_inputs(inputs);
copy_runtime_info(node, newNode);
replace_node(node, newNode);
newNode->set_friendly_name(node->get_friendly_name());
return true;
};

return newNode;
}
bool branch_separation_happened = false;
for (size_t i = 0; i < node->get_input_size(); ++i)
branch_separation_happened |= separate_branch(i);

if (!branch_separation_happened)
return node;

return node;
const auto newNode = node->clone_with_new_inputs(inputs);
copy_runtime_info(node, newNode);
replace_node(node, newNode);
newNode->set_friendly_name(node->get_friendly_name());
return newNode;
}

std::shared_ptr<ov::opset1::FakeQuantize> NetworkHelper::fuseConvert(const std::shared_ptr<ov::opset1::FakeQuantize>& fakeQuantize) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ bool ReduceBaseTransformation::transform(TransformationContext& context, ov::pas
}

const auto reduce = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions);
auto dequantization = NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(reduce, defaultPrecisions));
auto dequantization = NetworkHelper::getDequantization(reduce, defaultPrecisions);

// prepare dequantization to propagate
changeDequantizationValues(reduce, dequantization);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,13 @@ class SeparateInStandaloneBranchTransformation :
const auto createReferenceFunction = [](
const ov::element::Type precision,
const ov::Shape& inputShape,
const ov::builder::subgraph::DequantizationOperations& dequantization) -> std::shared_ptr<ov::Model> {
ov::builder::subgraph::DequantizationOperations dequantization) -> std::shared_ptr<ov::Model> {
// Note: separateInStandaloneBranch normalizes dequantization so constant indexes become equal to 1
if (!dequantization.subtract.empty())
dequantization.subtract.constantIndex = 1;
if (!dequantization.multiply.empty())
dequantization.multiply.constantIndex = 1;

const std::shared_ptr<ov::op::v0::Parameter> input = std::make_shared<ov::op::v0::Parameter>(precision, inputShape);
const auto relu = std::make_shared<ov::op::v0::Relu>(input);

Expand Down Expand Up @@ -118,7 +124,7 @@ class SeparateInStandaloneBranchTransformation :
std::tie(shapes, testValues) = obj.param;

std::stringstream ss;
ss << shapes << "_" << "_" << testValues;
ss << shapes << "_" << testValues;
return ss.str();
}
};
Expand All @@ -133,7 +139,6 @@ TEST_P(SeparateInStandaloneBranchTransformation, CompareFunctions) {

const std::vector<ov::Shape> shapes = {
{ 1, 3, 9, 9 },
{ 4, 3, 9, 9 }
};

std::vector<SeparateInStandaloneBranchTransformationTestValues> testValues = {
Expand All @@ -155,7 +160,16 @@ std::vector<SeparateInStandaloneBranchTransformationTestValues> testValues = {
{ {127.f}, ov::element::f32, {}, true, 1ul, ov::element::u8, true},
{ 0.02f }
}
}
},
{
LayerTransformation::createParamsU8U8(),
ov::element::u8,
{
ov::element::f32,
{ {127.f}, ov::element::f32, {}, false, 0ul},
{ {0.02f}, ov::element::f32, {}, false, 0ul }
}
},
};

INSTANTIATE_TEST_SUITE_P(
Expand Down

0 comments on commit 8a604a6

Please sign in to comment.