[LPT] matMul 3D: support Q/DQ on weights
v-Golubev committed Feb 10, 2021
1 parent 249e8f7 commit 3382e11
Showing 1 changed file with 38 additions and 38 deletions.
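
For context, the pattern this commit extends is a MatMul whose weights input carries quantize/dequantize (Q/DQ) operations: either a raw FakeQuantize on a constant, or an already-decomposed low-precision constant followed by dequantization Subtract/Multiply. Below is a minimal sketch of the newly supported 3D-activations case, assuming the ngraph opset1 API of this period; all shapes, values, and the helper name buildQDQMatMul are illustrative, not taken from the commit.

#include <memory>

#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>

using namespace ngraph;

// Hypothetical builder for the pattern handled here: 3D activations,
// constant weights wrapped in FakeQuantize (Q/DQ on weights).
std::shared_ptr<Function> buildQDQMatMul() {
    // 3D activations: [batch, sequence, features]
    const auto input = std::make_shared<opset1::Parameter>(element::f32, Shape{ 1, 16, 32 });

    // Constant weights, quantized via FakeQuantize to 256 levels (8-bit range)
    const auto weights = opset1::Constant::create(element::f32, Shape{ 32, 64 }, { 1.f });
    const auto low = opset1::Constant::create(element::f32, Shape{}, { -12.7f });
    const auto high = opset1::Constant::create(element::f32, Shape{}, { 12.7f });
    const auto fq = std::make_shared<opset1::FakeQuantize>(weights, low, high, low, high, 256);

    // [1, 16, 32] x [32, 64] -> [1, 16, 64]
    const auto matMul = std::make_shared<opset1::MatMul>(input, fq, false, false);
    return std::make_shared<Function>(NodeVector{ matMul }, ParameterVector{ input });
}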
76 changes: 38 additions & 38 deletions inference-engine/src/low_precision_transformations/src/mat_mul.cpp
@@ -17,14 +17,16 @@ using namespace ngraph::pass;
using namespace ngraph::pass::low_precision;

bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
- std::shared_ptr<ngraph::opset1::MatMul> matMul = as_type_ptr<ngraph::opset1::MatMul>(m.get_match_root());
+ std::shared_ptr<opset1::MatMul> matMul = as_type_ptr<opset1::MatMul>(m.get_match_root());
if ((matMul == nullptr) || !canBeTransformed(context, matMul)) {
return false;
}

- matMul = as_type_ptr<ngraph::opset1::MatMul>(NetworkHelper::separateInStandaloneBranch(matMul));
+ matMul = as_type_ptr<opset1::MatMul>(NetworkHelper::separateInStandaloneBranch(matMul));

+ const auto dequantization1 = NetworkHelper::getDequantization(matMul, 0);
+ auto dequantization2 = NetworkHelper::getDequantization(matMul, 1);

- FakeQuantizeDequantization dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
if (dequantization2.empty()) {
const std::shared_ptr<opset1::FakeQuantize> fakeQuantize =
as_type_ptr<opset1::FakeQuantize>(dequantization2.data.get_node_shared_ptr());
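
If the weights still arrive as a raw FakeQuantize (dequantization2.empty()), the transform first decomposes it into a low-precision weight constant plus dequantization operations, then re-queries getDequantization. As a reminder of the standard quantization algebra behind that decomposition (the general scheme, not the exact NetworkHelper internals):

$$\mathrm{FQ}(w) \approx s\,(q - z), \qquad q = \mathrm{clamp}\!\left(\mathrm{round}\!\left(\frac{w}{s}\right) + z,\ q_{\min},\ q_{\max}\right)$$

Here q becomes the stored integer weight constant, while the scale s and zero point z survive as the Multiply (and possibly Subtract) that the second getDequantization call picks up.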
@@ -40,21 +42,19 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
dataPrecision.hasZeroPoint,
updatePrecisions);

- dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
+ dequantization2 = NetworkHelper::getDequantization(matMul, 1);
}
}

- const FakeQuantizeDequantization dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 0);
-
if (dequantization2.subtract != nullptr) {
NetworkHelper::optimizeSubtract(dequantization2.subtract);
- dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1);
+ dequantization2 = NetworkHelper::getDequantization(matMul, 1);
}

- const std::shared_ptr<opset1::MatMul> newMatMul = std::make_shared<ngraph::op::TypeRelaxed<opset1::MatMul>>(
+ const std::shared_ptr<opset1::MatMul> newMatMul = std::make_shared<op::TypeRelaxed<opset1::MatMul>>(
std::vector<element::Type>({ element::f32, element::f32 }), std::vector<element::Type>({}),
-     ngraph::op::TemporaryReplaceOutputType(dequantization1.data, element::f32).get(),
-     ngraph::op::TemporaryReplaceOutputType(dequantization2.data, element::f32).get(),
+     op::TemporaryReplaceOutputType(dequantization1.data, element::f32).get(),
+     op::TemporaryReplaceOutputType(dequantization2.data, element::f32).get(),
matMul->get_transpose_a(),
matMul->get_transpose_b());
NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMatMul, matMul->get_output_element_type(0));
@@ -64,15 +64,15 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {

// dequantization with subtract on activations & constant weights
if (dequantization1.subtract) {
- auto broadcastShape = NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(dequantization1.subtract->get_input_node_shared_ptr(1))) ?
-     ngraph::Shape(dequantization1.subtract->get_shape().size(), 1) :
-     dequantization1.subtract->get_input_node_shared_ptr(1)->get_shape();
+ auto broadcastShape = NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(dequantization1.subtractConstant)) ?
+     Shape(dequantization1.subtract->get_shape().size(), 1) :
+     dequantization1.subtractConstant->get_shape();
const size_t lastIdx = matMul->get_transpose_a() ? broadcastShape.size() - 2 : broadcastShape.size() - 1;
broadcastShape[lastIdx] = dequantization1.subtract->get_shape()[lastIdx];

// broadcasted sub const to form [1, ..., 1, Y]
const auto broadcastedConst = fold<opset1::Broadcast>(
-     dequantization1.subtract->get_input_node_shared_ptr(1),
+     dequantization1.subtractConstant,
opset1::Constant::create(ngraph::element::i32, { broadcastShape.size() }, broadcastShape));

// multiply by weights: [1, ..., 1, Y] x [Y, Z] => [1, ..., 1, Z]
@@ -84,7 +84,7 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {

const auto newSubtract = std::make_shared<DequantizationSubtract>(newMatMul, newSubConst);
newSubtract->set_friendly_name(newMatMul->get_friendly_name() + "/DequantizationSubtract");
- ngraph::copy_runtime_info({ newSubtract, matMul }, newSubtract);
+ copy_runtime_info({ newSubtract, matMul }, newSubtract);

parent = newSubtract;
}
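
The algebra behind this block: with the zero-point constant S broadcast to [1, ..., 1, Y] and constant weights W of shape [Y, Z],

$$(X - S)\,W = XW - SW,$$

so the Subtract on activations is replaced by a single precomputed constant SW of shape [1, ..., 1, Z] applied after the low-precision MatMul, which is exactly what the folded Broadcast, constant MatMul, and DequantizationSubtract sequence constructs.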
@@ -100,17 +100,12 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
std::swap(*(transposeConstant.end() - 1), *(transposeConstant.end() - 2));

auto order = opset1::Constant::create(element::u32, Shape{ transposeConstant.size() }, transposeConstant);
- std::shared_ptr<Node> transposedConstant = fold<ngraph::opset1::Transpose>(node, order);
+ std::shared_ptr<Node> transposedConstant = fold<opset1::Transpose>(node, order);
return transposedConstant;
};

- const auto mulConst1 = matMul->get_transpose_a() ?
-     transpose(dequantization1.multiply->get_input_node_shared_ptr(1)) :
-     dequantization1.multiply->get_input_node_shared_ptr(1);
-
- auto mulConst2 = matMul->get_transpose_b() ?
-     transpose(dequantization2.multiply->get_input_node_shared_ptr(1)) :
-     dequantization2.multiply->get_input_node_shared_ptr(1);
+ const auto mulConst1 = matMul->get_transpose_a() ? transpose(dequantization1.multiplyConstant) : dequantization1.multiplyConstant;
+ auto mulConst2 = matMul->get_transpose_b() ? transpose(dequantization2.multiplyConstant) : dequantization2.multiplyConstant;

if (NetworkHelper::isScalarLike(as_type_ptr<opset1::Constant>(mulConst2))) {
mulConst2 = NetworkHelper::toScalar(as_type_ptr<opset1::Constant>(mulConst2));
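
Folding both Multiply constants into one relies on the scales commuting with the matrix product. When c1 is constant along the reduction axis of the activations and c2 along the reduction axis of the weights (the shape conditions canBeTransformed enforces below),

$$(c_1 \odot X)(c_2 \odot W) = (c_1 c_2) \odot (XW),$$

so a single newMulConst = c1 * c2 suffices; the transpose lambda handles transpose_a/transpose_b, and the Unsqueeze in the next hunk aligns ranks for the 3D case before the elementwise fold.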
@@ -125,16 +120,16 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {

mulConst2 = fold<opset1::Unsqueeze>(
mulConst2,
-     op::Constant::create(ngraph::element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape));
+     op::Constant::create(element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape));
}
}

- const auto newMulConst = NetworkHelper::toScalarIfPossible(fold<ngraph::opset1::Multiply>(mulConst1, mulConst2));
+ const auto newMulConst = NetworkHelper::toScalarIfPossible(fold<opset1::Multiply>(mulConst1, mulConst2));
const std::shared_ptr<opset1::Multiply> newMultiply = std::make_shared<DequantizationMultiply>(parent, newMulConst);
newMultiply->set_friendly_name(newMatMul->get_friendly_name() + "/DequantizationMultiply");

replace_node(matMul, newMultiply);
- ngraph::copy_runtime_info({ newMultiply, matMul }, newMultiply);
+ copy_runtime_info({ newMultiply, matMul }, newMultiply);

updateOutput(context, newMultiply, matMul);

@@ -145,12 +140,12 @@ void MatMulTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
addPattern(
pass,
context,
-     make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::Multiply>() }));
+     make_op_pattern<opset1::MatMul>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::Multiply>() }));

addPattern(
pass,
context,
-     make_op_pattern<opset1::MatMul>({ make_op_label<ngraph::opset1::Multiply>(), make_op_label<ngraph::opset1::FakeQuantize>() }));
+     make_op_pattern<opset1::MatMul>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::FakeQuantize>() }));
}

bool MatMulTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
@@ -167,15 +162,14 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
return false;
}

- const auto dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer);
+ const auto dequantization1 = NetworkHelper::getDequantization(layer, 0);
if (!dequantization1.empty()) {
if (updatePrecisions && !dequantization1.isLowPrecision()) {
return false;
}

- const auto mulConst = as_type_ptr<opset1::Constant>(dequantization1.multiply->get_input_node_shared_ptr(1));
- if (!NetworkHelper::isScalarLike(mulConst)) {
-     const auto constantShape = mulConst->get_shape();
+ if (!NetworkHelper::isScalarLike(dequantization1.multiplyConstant)) {
+     const auto constantShape = dequantization1.multiplyConstant->get_shape();
const auto mulShape = dequantization1.multiply->get_shape();
const size_t columnsIdx = matMul->get_transpose_a() ? mulShape.size() - 2ul : mulShape.size() - 1ul;

@@ -186,15 +180,21 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
}
}

- const auto dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer, 1);
+ const auto dequantization2 = NetworkHelper::getDequantization(layer, 1);
if (!dequantization2.empty()) {
- if ((updatePrecisions && !dequantization2.isLowPrecision()) || (dequantization2.subtract)) {
+ if ((updatePrecisions && !dequantization2.isLowPrecision())) {
return false;
}

- const auto mulConst = as_type_ptr<opset1::Constant>(dequantization2.multiply->get_input_node_shared_ptr(1));
- if (!NetworkHelper::isScalarLike(mulConst)) {
-     const auto constantShape = mulConst->get_shape();
+ if (dequantization2.subtract) {
+     const auto roundedConst = NetworkHelper::round(dequantization2.subtractConstant, dequantization2.data.get_element_type());
+     if (!NetworkHelper::isZeroConst(roundedConst)) {
+         return false;
+     }
+ }
+
+ if (!NetworkHelper::isScalarLike(dequantization2.multiplyConstant)) {
+     const auto constantShape = dequantization2.multiplyConstant->get_shape();
const auto mulShape = dequantization2.multiply->get_shape();
const size_t rowsIdx = matMul->get_transpose_b() ? mulShape.size() - 1ul : mulShape.size() - 2ul;

@@ -229,7 +229,7 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
}
}

- if (fakeQuantize == nullptr && dequantization1.subtract) {
+ if ((!NetworkHelper::isConstantPath(layer->get_input_node_shared_ptr(1))) && (dequantization1.subtract)) {
return false;
}
