Commit 90e33cf: refactoring
ndemashov committed Apr 4, 2022
1 parent 26765ac
Showing 10 changed files with 25 additions and 72 deletions.
@@ -253,9 +253,6 @@ class LP_TRANSFORMATIONS_API NetworkHelper {
float& updatedOutputLowValue,
float& updatedOutputHighValue);

static std::shared_ptr<ov::Node> fakeQuantizeWraper
(const std::shared_ptr<ov::Node> parameter);

private:
static std::shared_ptr<Node> foldFakeQuantize(
const std::shared_ptr<opset1::FakeQuantize>& fq,
@@ -20,6 +20,7 @@ class LP_TRANSFORMATIONS_API RecurrentCellTransformation : public LayerTransform
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
void propagateSkipCleanupAttribute(std::shared_ptr<Node> dequantization_multiply);
static std::shared_ptr<ov::Node> wrap_fake_quantize(const std::shared_ptr<ov::Node> parameter);
};

} // namespace low_precision
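
The FakeQuantize pattern helper is now a static member of RecurrentCellTransformation, declared above as wrap_fake_quantize; it replaces NetworkHelper::fakeQuantizeWraper, which is removed further down in this commit. A minimal sketch of how such a static helper can be turned into a standalone matcher is shown below; the include set and the matcher name are illustrative assumptions, not part of the commit.

#include <memory>

#include <ngraph/pattern/matcher.hpp>
#include <low_precision/recurrent_cell.hpp>

using namespace ngraph;

// Sketch: wrap one recurrent-cell input (activation or weight branch) into
// a FakeQuantize pattern and build a standalone matcher from it.
std::shared_ptr<pattern::Matcher> make_fq_matcher() {
    const auto branch = pattern::any_input();
    const auto fq = pass::low_precision::RecurrentCellTransformation::wrap_fake_quantize(branch);
    return std::make_shared<pattern::Matcher>(fq, "RecurrentCellFakeQuantize");  // name is arbitrary
}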
@@ -4,16 +4,8 @@

#pragma once

#include <memory>
#include <set>
#include <unordered_set>
#include <vector>

#include <ngraph/node.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/variant.hpp>

#include "low_precision/lpt_visibility.hpp"
#include "low_precision/rt_info/attribute_parameters.hpp"

namespace ngraph {
15 changes: 0 additions & 15 deletions src/common/low_precision_transformations/src/network_helper.cpp
@@ -1982,21 +1982,6 @@ void NetworkHelper::insertDequantizationAfter(
}
}
}

std::shared_ptr<ov::Node> NetworkHelper::fakeQuantizeWraper(
const std::shared_ptr<ov::Node> parameter) {
const auto input_low = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
const auto input_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
const auto output_low = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
const auto output_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
return ngraph::pattern::wrap_type<opset1::FakeQuantize>({
parameter,
input_low,
input_high,
output_low,
output_high});
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph
22 changes: 18 additions & 4 deletions src/common/low_precision_transformations/src/recurrent_cell.cpp
@@ -28,10 +28,10 @@ RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) :
const auto R = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
const auto B = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();

const auto fq_X = NetworkHelper::fakeQuantizeWraper(X);
const auto fq_H = NetworkHelper::fakeQuantizeWraper(H);
const auto fq_W = NetworkHelper::fakeQuantizeWraper(W);
const auto fq_R = NetworkHelper::fakeQuantizeWraper(R);
const auto fq_X = wrap_fake_quantize(X);
const auto fq_H = wrap_fake_quantize(H);
const auto fq_W = wrap_fake_quantize(W);
const auto fq_R = wrap_fake_quantize(R);

const auto dequantization_convert_X = ngraph::pattern::wrap_type<ngraph::opset1::Convert>({ngraph::pattern::any_input()});
const auto dequantization_convert_H = ngraph::pattern::wrap_type<ngraph::opset1::Convert>({ngraph::pattern::any_input()});
@@ -160,6 +160,20 @@ void RecurrentCellTransformation::propagateSkipCleanupAttribute(std::shared_ptr<
}
}

std::shared_ptr<ov::Node> RecurrentCellTransformation::wrap_fake_quantize(
const std::shared_ptr<ov::Node> parameter) {
const auto input_low = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
const auto input_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
const auto output_low = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
const auto output_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
return ngraph::pattern::wrap_type<opset1::FakeQuantize>({
parameter,
input_low,
input_high,
output_low,
output_high});
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph
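
To illustrate what the pattern returned by wrap_fake_quantize is meant to match, the following self-contained sketch builds a concrete FakeQuantize with Constant interval inputs and checks it against the pattern. The interval values, shape, quantization levels, and include paths are illustrative assumptions only.

#include <memory>

#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pattern/matcher.hpp>
#include <low_precision/recurrent_cell.hpp>

using namespace ngraph;

int main() {
    // Concrete subgraph: a FakeQuantize whose interval inputs are Constants,
    // as expected on each quantized recurrent-cell input (X, H, W, R).
    const auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 16});
    const auto il = opset1::Constant::create(element::f32, Shape{}, {0.f});
    const auto ih = opset1::Constant::create(element::f32, Shape{}, {2.55f});
    const auto ol = opset1::Constant::create(element::f32, Shape{}, {0.f});
    const auto oh = opset1::Constant::create(element::f32, Shape{}, {2.55f});
    const auto fq = std::make_shared<opset1::FakeQuantize>(data, il, ih, ol, oh, 256);

    // Pattern side: any_input() stands in for the data branch feeding the FakeQuantize.
    const auto branch = pattern::any_input();
    const auto fq_pattern =
        pass::low_precision::RecurrentCellTransformation::wrap_fake_quantize(branch);

    pattern::Matcher matcher(fq_pattern, "RecurrentCellFakeQuantize");
    return matcher.match(fq->output(0)) ? 0 : 1;  // 0 when the pattern matches
}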
7 changes: 0 additions & 7 deletions src/plugins/intel_cpu/src/nodes/tensoriterator.h
@@ -108,13 +108,6 @@ class TensorIterator : public Node {

void setExtManager(const ExtensionManager::Ptr& extMgr) { ext_mng = extMgr; }

Graph getSubGraph() const {
return sub_graph;
}
std::shared_ptr<ov::Node> getOriginalOp() const {
return ngraphOp;
}

protected:
// needShapeInfer() should return false
// because we cannot resolve the output dimensions before the inference is completed
@@ -4,7 +4,6 @@

#include <gtest/gtest.h>

#include <low_precision/common/operation_per_tensor_quantization_restriction.hpp>
#include <low_precision/common/operation_precision_restriction.hpp>
#include <low_precision/recurrent_cell.hpp>
#include <low_precision/fold_convert.hpp>
@@ -69,18 +68,12 @@ class RecurrentCellTransformationTestValues {
: params(params),
type(type),
actual(actual),
result(result),
addNotPrecisionPreservedOperation(addNotPrecisionPreservedOperation),
checkIntervalsAlignmentAttributes(checkIntervalsAlignmentAttributes) {}
result(result) {}

TestTransformationParams params;
RecurrentCellFunction::RNNType type;
RecurrentCellTransformationValues actual;
RecurrentCellTransformationValues result;
// add not precision preserved operation to set output precision for FakeQuantize
// don't set to 'true' by default to keep test cases with tested operation as output
bool addNotPrecisionPreservedOperation;
bool checkIntervalsAlignmentAttributes;
};

inline std::ostream& operator<<(std::ostream& out, const RecurrentCellTransformationTestValues& values) {
@@ -119,10 +112,7 @@ class RecurrentCellTransformation : public LayerTransformation, public testing::
testValues.actual.dequantization_H,
testValues.actual.dequantization_W,
testValues.actual.dequantization_R
},
{},
ngraph::element::undefined,
{});
});

const auto params = TestTransformationParams::toParams(testValues.params);

@@ -143,13 +133,6 @@ class RecurrentCellTransformation : public LayerTransformation, public testing::
testValues.result.dequantizationAfter.multiply.outPrecision = precision;
}

if (!testValues.params.updatePrecisions && (precision == ngraph::element::f32) &&
!testValues.result.dequantizationAfter.convert.empty()) {
testValues.result.dequantizationAfter.convert = {};
}

IntervalsAlignmentSharedValue::Interval interval{-1.28f, 2.55f};

referenceFunction =
ngraph::builder::subgraph::RecurrentCellFunction::get(precision,
activations_shapes,
@@ -172,10 +155,7 @@ class RecurrentCellTransformation : public LayerTransformation, public testing::
testValues.result.dequantization_H,
testValues.result.dequantization_W,
testValues.result.dequantization_R
},
{},
testValues.result.precisionAfterOperation,
testValues.result.dequantizationAfter);
});
}

static std::string getTestCaseName(testing::TestParamInfo<RecurrentCellTransformationParams> obj) {
@@ -66,10 +66,7 @@ void RecurrentCellTransformation::SetUp() {
param.dequantization_H,
param.dequantization_W,
param.dequantization_R
},
{},
ngraph::element::undefined,
{});
});
}

void RecurrentCellTransformation::Run() {
@@ -26,10 +26,7 @@ class RecurrentCellFunction {
const RNNType type,
const std::vector<FakeQuantizeOnDataWithConstant>& fqOnDatas,
const std::vector<DequantizationOperations::Convert>& converts,
const std::vector<DequantizationOperations>& dequantizations,
const std::vector<ov::Any>& concatAttributes,
const ngraph::element::Type precisionAfterOperation,
const DequantizationOperations& dequantizationAfter);
const std::vector<DequantizationOperations>& dequantizations);
};

std::shared_ptr<Node> makeQuantizationAndDequantization(const std::shared_ptr<Node> input,
@@ -30,10 +30,7 @@ std::shared_ptr<ngraph::Function> RecurrentCellFunction::get(
const RNNType type,
const std::vector<FakeQuantizeOnDataWithConstant>& fqOnDatas,
const std::vector<DequantizationOperations::Convert>& converts,
const std::vector<DequantizationOperations>& dequantizations,
const std::vector<ov::Any>& concatAttributes,
const ngraph::element::Type precisionAfterOperation,
const DequantizationOperations& dequantizationAfter) {
const std::vector<DequantizationOperations>& dequantizations) {
auto X = std::make_shared<opset5::Parameter>(inputPrecision, inputActivationsShapes[0]);
X->set_friendly_name("X");
std::shared_ptr<Node> parent_X = makeQuantizationAndDequantization(X,