Skip to content

Commit

Permalink
[LPT] Handle scale with convert in MatMulTransformation (openvinotool…
Browse files Browse the repository at this point in the history
…kit#26398)

### Details:
This PR handles Matmul dequantization subgraphs with convert on scales.

In general, LPT doesn't expect a Convert on the dequantization scale: the
convert is usually folded at earlier steps of the transformation pipeline.
However, such subgraphs may arise on MatMul weights (in the CPU plugin),
since the decompression handling logic keeps these subgraphs unchanged.
Since the decompression handling logic concerns only MatMuls, changing
only the MatMul LPT transformation is enough.

### Tickets:
 - *CVS-151589*
  • Loading branch information
v-Golubev authored Sep 5, 2024
1 parent 5ba9884 commit 20ee134
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 48 deletions.
13 changes: 13 additions & 0 deletions src/common/low_precision_transformations/src/mat_mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,19 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context
return false;
}

// WA: LPT doesn't expect a Convert on the dequantization scale: the convert is usually folded at earlier transformation pipeline steps.
// However, such subgraphs may arise on MatMul weights since the decompression handling logic keeps these subgraphs unchanged.
// TODO: remove this logic when
// 1. CompressedMatmul is implemented
// 2. Or convert on scales is supported across the whole LPT pipeline
if (auto mul = ov::as_type_ptr<ov::op::v1::Multiply>(layer->get_input_node_shared_ptr(1))) {
if (auto convert = ov::as_type_ptr<ov::op::v0::Convert>(mul->get_input_node_shared_ptr(1))) {
if (auto constant = ov::as_type_ptr<ov::op::v0::Constant>(convert->get_input_node_shared_ptr(0))) {
auto new_constant = foldConvert(constant, convert->get_destination_type());
ov::replace_node_update_name(convert, new_constant);
}
}
}
const auto dequantization2 = NetworkHelper::getDequantization(layer, defaultPrecisions, 1);
if (!dequantization2.empty()) {
if ((updatePrecisions && !dequantization2.isLowPrecision())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,22 @@ std::vector<MatMulWithConstantTransformationTestValues> testValues = {
"FullyConnected",
"u8"
},
// 4D with Dq on weights, with convert on scales
{
{ 1, 1, 3, 4 },
{ 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 3, 1}, {1, 3, 1}}, {0.f}, {255.f}, {0.f, 0.f, 0.f}, {255.f, 25.5f, 255.f} },
{ std::vector<float>(4 * 2, 2.f), ov::element::i8, ov::Shape{ 2, 4 } },
{},
{
ov::element::f32,
{},
ov::builder::subgraph::DequantizationOperations::Multiply({0.1f, 0.01f}, ov::element::f32, ov::Shape{ 2, 1 })
.setConstantPrecision(ov::element::f16)
.setAddConvert(true)
},
"FullyConnected",
"u8"
},
// 3D with the same values
{
{ 1, 3, 4 },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class DequantizationOperations {
isEmpty = true;
}
Subtract& setConstantPrecision(const ov::element::Type& precision);
Subtract& setAddConvert(bool value);

std::vector<float> values;
ov::element::Type outPrecision = ov::element::undefined;
Expand Down Expand Up @@ -81,20 +82,23 @@ class DequantizationOperations {
const ov::Shape& constantShape,
const bool toRemove = false,
const size_t constantIndex = 1ul,
const ov::element::Type constantPrecision = ov::element::undefined);
const ov::element::Type constantPrecision = ov::element::undefined,
const bool addConvert = false);
bool empty() const noexcept;
bool equal(const DequantizationOperations::Multiply& value) const noexcept;
// Equality operator: forwards the comparison to equal().
bool operator==(const Multiply& value) const noexcept {
    const bool isSame = this->equal(value);
    return isSame;
}
Multiply& setConstantPrecision(const ov::element::Type& precision);
Multiply& setAddConvert(bool value);

std::vector<float> values;
ov::element::Type outPrecision = ov::element::undefined;
ov::Shape constantShape;
bool constantShapeIsDefined = false;
size_t constantIndex = 1ul;
ov::element::Type constantPrecision = ov::element::undefined;
bool addConvert = false;

private:
bool isEmpty;
Expand Down
71 changes: 27 additions & 44 deletions src/tests/ov_helpers/ov_lpt_models/src/common/builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,24 +81,13 @@ std::shared_ptr<Node> makeDequantization(
(((dequantizationOperations.subtract.constantPrecision == ov::element::undefined) ||
(dequantizationOperations.subtract.constantPrecision == parent.get_element_type())) ||
dequantizationOperations.subtract.addConvert)) {
subtract = dequantizationOperations.subtract.constantIndex == 1ul ?
std::make_shared<ov::opset1::Subtract>(parent, subtractConst) :
subtract = std::make_shared<ov::opset1::Subtract>(subtractConst, parent);
subtract = std::make_shared<ov::opset1::Subtract>(leftBranchParent, rightBranchParent);
} else {
if (dequantizationOperations.subtract.constantIndex == 1ul) {
subtract = std::make_shared<ov::op::TypeRelaxed<ov::opset1::Subtract>>(
std::vector<ov::element::Type>{ov::element::f32, ov::element::f32},
std::vector<ov::element::Type>{ov::element::f32},
ov::op::TemporaryReplaceOutputType(parent, ov::element::f32).get(),
ov::op::TemporaryReplaceOutputType(subtractConst, ov::element::f32).get());
} else {
subtract = std::make_shared<ov::op::TypeRelaxed<ov::opset1::Subtract>>(
std::vector<ov::element::Type>{ov::element::f32, ov::element::f32},
std::vector<ov::element::Type>{ov::element::f32},
ov::op::TemporaryReplaceOutputType(subtractConst, ov::element::f32).get(),
ov::op::TemporaryReplaceOutputType(parent, ov::element::f32).get());
}

subtract = std::make_shared<ov::op::TypeRelaxed<ov::opset1::Subtract>>(
std::vector<ov::element::Type>{ov::element::f32, ov::element::f32},
std::vector<ov::element::Type>{ov::element::f32},
ov::op::TemporaryReplaceOutputType(leftBranchParent, ov::element::f32).get(),
ov::op::TemporaryReplaceOutputType(rightBranchParent, ov::element::f32).get());
ov::pass::low_precision::NetworkHelper::setOutDataPrecision(subtract, dequantizationOperations.subtract.outPrecision);
}

Expand Down Expand Up @@ -143,38 +132,32 @@ std::shared_ptr<Node> makeMultiply(const ov::Output<Node>& parent, const Dequant
}
}

std::shared_ptr<Node> constant = std::make_shared<ov::opset1::Constant>(
multiply.constantPrecision != ov::element::undefined ? multiply.constantPrecision : parent.get_element_type(),
shape,
values);
if (multiply.addConvert) {
constant = std::make_shared<ov::opset1::Convert>(
constant,
multiply.outPrecision == ov::element::undefined ? parent.get_element_type() : multiply.outPrecision);
}

ov::Output<Node> leftBranchParent = multiply.constantIndex == 1 ? parent : constant;
ov::Output<Node> rightBranchParent = multiply.constantIndex == 1 ? constant : parent;

std::shared_ptr<ov::opset1::Multiply> newMultiply;
if (((multiply.outPrecision == ov::element::undefined) || (multiply.outPrecision == parent.get_element_type())) &&
((multiply.constantPrecision == ov::element::undefined) ||
(multiply.constantPrecision == parent.get_element_type()))) {
const std::shared_ptr<ov::opset1::Constant> constant = std::make_shared<ov::opset1::Constant>(
multiply.constantPrecision != ov::element::undefined ? multiply.constantPrecision
: parent.get_element_type(),
shape,
values);

newMultiply = multiply.constantIndex == 1ul ?
std::make_shared<ov::opset1::Multiply>(parent, constant) :
std::make_shared<ov::opset1::Multiply>(constant, parent);
(multiply.constantPrecision == parent.get_element_type())) ||
multiply.addConvert) {
newMultiply = std::make_shared<ov::opset1::Multiply>(leftBranchParent, rightBranchParent);
} else {
const std::shared_ptr<ov::opset1::Constant> constant = std::make_shared<ov::opset1::Constant>(
multiply.constantPrecision != ov::element::undefined ? multiply.constantPrecision
: parent.get_element_type(),
shape,
values);

// TODO: use templates
newMultiply = multiply.constantIndex == 1ul
? std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
std::vector<ov::element::Type>{ov::element::f32, ov::element::f32},
std::vector<ov::element::Type>{multiply.outPrecision},
ov::op::TemporaryReplaceOutputType(parent, ov::element::f32).get(),
ov::op::TemporaryReplaceOutputType(constant, ov::element::f32).get())
: std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
std::vector<ov::element::Type>{ov::element::f32, ov::element::f32},
std::vector<ov::element::Type>{multiply.outPrecision},
ov::op::TemporaryReplaceOutputType(constant, ov::element::f32).get(),
ov::op::TemporaryReplaceOutputType(parent, ov::element::f32).get());
newMultiply = std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
std::vector<ov::element::Type>{ov::element::f32, ov::element::f32},
std::vector<ov::element::Type>{multiply.outPrecision},
ov::op::TemporaryReplaceOutputType(leftBranchParent, ov::element::f32).get(),
ov::op::TemporaryReplaceOutputType(rightBranchParent, ov::element::f32).get());
}

return newMultiply;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ DequantizationOperations::Subtract& DequantizationOperations::Subtract::setConst
return *this;
}

// Stores `value` in the `addConvert` flag of this Subtract description.
// Returns a reference to this object so that setter calls can be chained.
DequantizationOperations::Subtract& DequantizationOperations::Subtract::setAddConvert(bool value) {
    this->addConvert = value;
    return *this;
}

DequantizationOperations::Multiply::Multiply() :
isEmpty(true),
outPrecision(ov::element::undefined),
Expand Down Expand Up @@ -129,13 +134,15 @@ DequantizationOperations::Multiply::Multiply(
const ov::Shape& constantShape,
const bool toRemove,
const size_t constantIndex,
ov::element::Type constantPrecision) :
ov::element::Type constantPrecision,
const bool addConvert) :
isEmpty(false),
values(values),
outPrecision(outPrecision),
constantShape(constantShape),
constantIndex(constantIndex),
constantPrecision(constantPrecision),
addConvert(addConvert),
constantShapeIsDefined(true) {
}

Expand Down Expand Up @@ -166,6 +173,11 @@ DequantizationOperations::Multiply& DequantizationOperations::Multiply::setConst
return *this;
}

// Stores `value` in the `addConvert` flag of this Multiply description.
// Returns a reference to this object so that setter calls can be chained.
DequantizationOperations::Multiply& DequantizationOperations::Multiply::setAddConvert(bool value) {
    this->addConvert = value;
    return *this;
}

DequantizationOperations::DequantizationOperations() {}

DequantizationOperations::DequantizationOperations(
Expand All @@ -179,9 +191,11 @@ DequantizationOperations::DequantizationOperations(

// Propagates `type` to the convert, subtract and multiply descriptions.
//
// The output precisions are always overwritten. The constant precision of
// subtract/multiply is overwritten only when `addConvert` is false for that
// operation, so a constant precision configured together with `addConvert`
// is preserved instead of being clobbered by an unconditional assignment.
void DequantizationOperations::setPrecision(const ov::element::Type& type) noexcept {
    convert.outPrecision = type;

    if (!subtract.addConvert) {
        subtract.constantPrecision = type;
    }
    subtract.outPrecision = type;

    if (!multiply.addConvert) {
        multiply.constantPrecision = type;
    }
    multiply.outPrecision = type;
}

Expand Down

0 comments on commit 20ee134

Please sign in to comment.