Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ndemashov committed Apr 4, 2022
1 parent 90858bd commit 7634c44
Show file tree
Hide file tree
Showing 13 changed files with 164 additions and 314 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,6 @@ class LP_TRANSFORMATIONS_API NetworkHelper {
float& updatedOutputLowValue,
float& updatedOutputHighValue);

static std::shared_ptr<ov::Node> fakeQuantizeWraper
(const std::shared_ptr<ov::Node> parameter);

private:
static std::shared_ptr<Node> foldFakeQuantize(
const std::shared_ptr<opset1::FakeQuantize>& fq,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class LP_TRANSFORMATIONS_API RecurrentCellTransformation : public LayerTransform
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
void propagateSkipCleanupAttribute(std::shared_ptr<Node> dequantization_multiply);
static std::shared_ptr<ov::Node> wrap_fake_quantize(const std::shared_ptr<ov::Node> parameter);
static std::shared_ptr<ov::Node> wrap_quantization(const std::shared_ptr<ov::Node> parameter);
static std::shared_ptr<ov::Node> wrap_dequantization(const std::shared_ptr<ov::Node> parameter, const bool with_subtract);
};

} // namespace low_precision
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,8 @@

#pragma once

#include <memory>
#include <set>
#include <unordered_set>
#include <vector>

#include <ngraph/node.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/variant.hpp>

#include "low_precision/lpt_visibility.hpp"
#include "low_precision/rt_info/attribute_parameters.hpp"

namespace ngraph {
Expand Down
15 changes: 0 additions & 15 deletions src/common/low_precision_transformations/src/network_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1982,21 +1982,6 @@ void NetworkHelper::insertDequantizationAfter(
}
}
}

// Builds a matcher pattern for a FakeQuantize node whose data input is
// `parameter` and whose four range inputs (input low/high, output low/high)
// are Constant nodes.
std::shared_ptr<ov::Node> NetworkHelper::fakeQuantizeWraper(
    const std::shared_ptr<ov::Node> parameter) {
    const auto in_low   = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto in_high  = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto out_low  = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto out_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    return ngraph::pattern::wrap_type<opset1::FakeQuantize>(
        {parameter, in_low, in_high, out_low, out_high});
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph
121 changes: 87 additions & 34 deletions src/common/low_precision_transformations/src/recurrent_cell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,47 +28,38 @@ RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) :
const auto R = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
const auto B = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();

const auto fq_X = NetworkHelper::fakeQuantizeWraper(X);
const auto fq_H = NetworkHelper::fakeQuantizeWraper(H);
const auto fq_W = NetworkHelper::fakeQuantizeWraper(W);
const auto fq_R = NetworkHelper::fakeQuantizeWraper(R);
const auto fq_X = wrap_fake_quantize(X);
const auto fq_H = wrap_fake_quantize(H);
const auto fq_W = wrap_fake_quantize(W);
const auto fq_R = wrap_fake_quantize(R);

const auto dequantization_convert_X = ngraph::pattern::wrap_type<ngraph::opset1::Convert>({ngraph::pattern::any_input()});
const auto dequantization_convert_H = ngraph::pattern::wrap_type<ngraph::opset1::Convert>({ngraph::pattern::any_input()});
const auto subtract_constant = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
const auto dequantization_subtract_X = ngraph::pattern::wrap_type<ngraph::opset1::Subtract>(
{dequantization_convert_X, subtract_constant});
const auto dequantization_subtract_H = ngraph::pattern::wrap_type<ngraph::opset1::Subtract>(
{dequantization_convert_H, subtract_constant});
const auto multiply_constant = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
const auto dequantization_multiply_X = ngraph::pattern::wrap_type<ngraph::opset1::Multiply>(
{dequantization_subtract_X, multiply_constant});
const auto quantization_X = wrap_quantization(X);
const auto quantization_H = wrap_quantization(H);

const auto dequantization_X = wrap_dequantization(quantization_X, true);
const auto dequantization_H = wrap_dequantization(quantization_H, true);

const auto dequantization_multiply_without_subtract_X = ngraph::pattern::wrap_type<ngraph::opset1::Multiply>(
{dequantization_convert_X, multiply_constant});
const auto dequantization_multiply_H = ngraph::pattern::wrap_type<ngraph::opset1::Multiply>(
{dequantization_subtract_H, multiply_constant});
const auto dequantization_multiply_without_subtract_H = ngraph::pattern::wrap_type<ngraph::opset1::Multiply>(
{dequantization_convert_H, multiply_constant});
const auto dequantization_without_subtract_X = wrap_dequantization(quantization_X, false);
const auto dequantization_without_subtract_H = wrap_dequantization(quantization_H, false);

const auto lstm_cell = ngraph::pattern::wrap_type<ngraph::opset5::LSTMCell>(
{fq_X, fq_H, C, fq_W, fq_R, B});
const auto lstm_cell_with_dequantizations = ngraph::pattern::wrap_type<ngraph::opset5::LSTMCell>(
{dequantization_multiply_X, dequantization_multiply_H, C, fq_W, fq_R, B});
{dequantization_X, dequantization_H, C, fq_W, fq_R, B});
const auto lstm_cell_with_dequantizations_without_subtract = ngraph::pattern::wrap_type<ngraph::opset5::LSTMCell>(
{dequantization_multiply_without_subtract_X, dequantization_multiply_without_subtract_H, C, fq_W, fq_R, B});
{dequantization_without_subtract_X, dequantization_without_subtract_H, C, fq_W, fq_R, B});

const auto gru_cell = ngraph::pattern::wrap_type<ngraph::opset4::GRUCell>({fq_X, fq_H, fq_W, fq_R, B});
const auto gru_cell_with_dequantizations = ngraph::pattern::wrap_type<ngraph::opset4::GRUCell>(
{dequantization_multiply_X, dequantization_multiply_X, fq_W, fq_R, B});
{dequantization_X, dequantization_H, fq_W, fq_R, B});
const auto gru_cell_with_dequantizations_without_subtract = ngraph::pattern::wrap_type<ngraph::opset4::GRUCell>(
{dequantization_multiply_without_subtract_X, dequantization_multiply_without_subtract_H, fq_W, fq_R, B});
{dequantization_without_subtract_X, dequantization_without_subtract_H, fq_W, fq_R, B});

const auto rnn_cell = ngraph::pattern::wrap_type<ngraph::opset4::RNNCell>({fq_X, fq_H, fq_W, fq_R, B});
const auto rnn_cell_with_dequantizations = ngraph::pattern::wrap_type<ngraph::opset4::RNNCell>(
{dequantization_multiply_X, dequantization_multiply_X, fq_W, fq_R, B});
{dequantization_X, dequantization_H, fq_W, fq_R, B});
const auto rnn_cell_with_dequantizations_without_subtract = ngraph::pattern::wrap_type<ngraph::opset4::RNNCell>(
{dequantization_multiply_without_subtract_X, dequantization_multiply_without_subtract_H, fq_W, fq_R, B});
{dequantization_without_subtract_X, dequantization_without_subtract_H, fq_W, fq_R, B});

ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto op = m.get_match_root();
Expand All @@ -94,7 +85,7 @@ RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) :
}

bool RecurrentCellTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
const auto lstm = m.get_match_root();
const auto lstm = m.get_match_root();
if (!canBeTransformed(context, lstm)) {
return false;
}
Expand All @@ -118,13 +109,7 @@ bool RecurrentCellTransformation::transform(TransformationContext& context, ngra
updatePrecisions);
std::shared_ptr<ngraph::Node> new_fq = std::get<0>(QDQ);
std::shared_ptr<ngraph::Node> deq_multiply = std::get<1>(QDQ);
if (deq_multiply == nullptr || new_fq == nullptr) {
return false;
}
auto multiply_parent = deq_multiply->get_input_node_shared_ptr(0);
if (is_type<ngraph::opset1::Subtract>(multiply_parent)) {
return false;
}
ov::disable_constant_folding(multiply_parent);
propagateSkipCleanupAttribute(deq_multiply);
this->register_new_node(new_fq);
Expand All @@ -142,7 +127,39 @@ bool RecurrentCellTransformation::transform(TransformationContext& context, ngra
return true;
}

bool RecurrentCellTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
bool RecurrentCellTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> lstm) const {
std::shared_ptr<ov::Node> W, R;
if (is_type<opset5::LSTMCell>(lstm)) {
W = lstm->get_input_node_shared_ptr(3);
R = lstm->get_input_node_shared_ptr(4);
} else {
W = lstm->get_input_node_shared_ptr(2);
R = lstm->get_input_node_shared_ptr(3);
}
for (auto fq_on_weight : {W, R}) {
auto fq_node = as_type_ptr<ngraph::opset1::FakeQuantize>(fq_on_weight);
const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(fq_node);
const auto precisionsAttribute = getAttributeFromOutput<PrecisionsAttribute>(fq_on_weight);
const auto precisions = precisionsAttribute.empty()
? defaultPrecisions
: precisionsAttribute.as<PrecisionsAttribute>().value();
const DataPrecision dataPrecision = getDataPrecision(fq_on_weight, quantizationDetails, precisions);
auto QDQ = NetworkHelper::decomposeFakeQuantize(fq_node,
dataPrecision.precision,
dataPrecision.min,
dataPrecision.max,
dataPrecision.hasZeroPoint,
updatePrecisions);
std::shared_ptr<ngraph::Node> new_fq = std::get<0>(QDQ);
std::shared_ptr<ngraph::Node> deq_multiply = std::get<1>(QDQ);
if (deq_multiply == nullptr || new_fq == nullptr) {
return false;
}
auto multiply_parent = deq_multiply->get_input_node_shared_ptr(0);
if (is_type<ngraph::opset1::Subtract>(multiply_parent)) {
return false;
}
}
return true;
}

Expand All @@ -160,6 +177,42 @@ void RecurrentCellTransformation::propagateSkipCleanupAttribute(std::shared_ptr<
}
}

// Pattern helper: a FakeQuantize fed by `parameter` with four Constant range
// inputs (input low/high, output low/high).
std::shared_ptr<ov::Node> RecurrentCellTransformation::wrap_fake_quantize(
    const std::shared_ptr<ov::Node> parameter) {
    const auto in_low   = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto in_high  = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto out_low  = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto out_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    return ngraph::pattern::wrap_type<opset1::FakeQuantize>(
        {parameter, in_low, in_high, out_low, out_high});
}

// Pattern helper: a quantization chain — FakeQuantize on `parameter`
// followed by a Convert.
std::shared_ptr<ov::Node> RecurrentCellTransformation::wrap_quantization(
    const std::shared_ptr<ov::Node> parameter) {
    return ngraph::pattern::wrap_type<ngraph::opset1::Convert>(
        {wrap_fake_quantize(parameter)});
}

// Pattern helper: a dequantization chain on `parameter` —
// Convert [-> Subtract(zero-point Constant)] -> Multiply(scale Constant).
// `with_subtract` selects whether the Subtract stage is part of the pattern.
std::shared_ptr<ov::Node> RecurrentCellTransformation::wrap_dequantization(
    const std::shared_ptr<ov::Node> parameter,
    const bool with_subtract) {
    const auto convert = ngraph::pattern::wrap_type<ngraph::opset1::Convert>({parameter});
    const auto zero_point = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto subtract = ngraph::pattern::wrap_type<ngraph::opset1::Subtract>(
        {convert, zero_point});
    const auto scale = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    if (with_subtract) {
        return ngraph::pattern::wrap_type<ngraph::opset1::Multiply>({subtract, scale});
    }
    return ngraph::pattern::wrap_type<ngraph::opset1::Multiply>({convert, scale});
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph
7 changes: 0 additions & 7 deletions src/plugins/intel_cpu/src/nodes/tensoriterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,6 @@ class TensorIterator : public Node {

void setExtManager(const ExtensionManager::Ptr& extMgr) { ext_mng = extMgr; }

// Returns the node's inner sub-graph by value (copies the `sub_graph` member).
Graph getSubGraph() const {
return sub_graph;
}
// Returns the original ngraph operation this node was built from
// (shares ownership of the `ngraphOp` member).
std::shared_ptr<ov::Node> getOriginalOp() const {
return ngraphOp;
}

protected:
// needShapeInfer() should return false
// because we cannot resolve the output dimensions before the inference is completed
Expand Down
13 changes: 13 additions & 0 deletions src/plugins/intel_cpu/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/pass/manager.hpp>
Expand Down Expand Up @@ -461,6 +462,18 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
{0, {ngraph::element::u8, ngraph::element::i8}},
{1, {ngraph::element::i8}}
}),
PrecisionsRestriction::create<ngraph::opset5::LSTMCell>({
{0, {ngraph::element::u8}},
{1, {ngraph::element::i8}},
}),
PrecisionsRestriction::create<ngraph::opset4::GRUCell>({
{0, {ngraph::element::u8}},
{1, {ngraph::element::i8}},
}),
PrecisionsRestriction::create<ngraph::opset4::RNNCell>({
{0, {ngraph::element::u8}},
{1, {ngraph::element::i8}},
}),
});

auto quantizationRestrictions = std::vector<QuantizationGranularityRestriction>({
Expand Down
Loading

0 comments on commit 7634c44

Please sign in to comment.