Skip to content

Commit

Permalink
[LPT] Tests infrastructure support
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Sep 28, 2023
1 parent e484452 commit 373e8c2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,29 @@
#include "low_precision/markup_quantization_granularity.hpp"
#include "low_precision/transformation_context.hpp"

// cleanup transformations
#include "low_precision/convert.hpp"
#include "low_precision/eliminate_fake_quantize.hpp"
#include "low_precision/fold_convert.hpp"
#include "low_precision/fold_fake_quantize.hpp"
#include "low_precision/fuse_convert.hpp"
#include "low_precision/fuse_multiply_to_fake_quantize.hpp"
#include "low_precision/fuse_subtract_to_fake_quantize.hpp"
#include "low_precision/multiply_to_group_convolution.hpp"

#include <string>

using namespace testing;
using namespace ov::pass;
using namespace ov::pass::low_precision;

OPENVINO_SUPPRESS_DEPRECATED_START

SimpleLowPrecisionTransformer::SimpleLowPrecisionTransformer(
const std::vector<ov::pass::low_precision::PrecisionsRestriction>& precisionRestrictions,
const std::vector<ov::pass::low_precision::QuantizationGranularityRestriction>& quantizationRestrictions,
const AttributeParameters& params) {
const AttributeParameters& params,
const bool addCleanup) {
auto passConfig = get_pass_config();

// TODO: use one pass manager
Expand All @@ -39,7 +51,20 @@ SimpleLowPrecisionTransformer::SimpleLowPrecisionTransformer(

common = std::make_shared<ov::pass::Manager>(passConfig);
commonGraphRewrite = common->register_pass<ov::pass::GraphRewrite>();

cleanup = common->register_pass<ov::pass::GraphRewrite>();
if (addCleanup) {
ov::pass::low_precision::LayerTransformation::Params params;
cleanup->add_matcher<EliminateFakeQuantizeTransformation>(params);
cleanup->add_matcher<FoldConvertTransformation>(params);
cleanup->add_matcher<FuseConvertTransformation>(params);
cleanup->add_matcher<FuseSubtractToFakeQuantizeTransformation>(params);
cleanup->add_matcher<FuseMultiplyToFakeQuantizeTransformation>(params);

cleanup->add_matcher<MultiplyToGroupConvolutionTransformation>(
params,
PrecisionsRestriction::getPrecisionsByOperationType<opset1::GroupConvolution>(precisionRestrictions));
}
}

void SimpleLowPrecisionTransformer::transform(std::shared_ptr<ov::Model>& model) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class SimpleLowPrecisionTransformer : public ngraph::pass::FunctionPass{
SimpleLowPrecisionTransformer(
const std::vector<ov::pass::low_precision::PrecisionsRestriction>& precisionRestrictions = {},
const std::vector<ov::pass::low_precision::QuantizationGranularityRestriction>& quantizationRestrictions = {},
const AttributeParameters& params = AttributeParameters());
const AttributeParameters& params = AttributeParameters(),
const bool addCleanup = false);

template <class T, class Operation>
void add(const TestTransformationParams& params) {
Expand Down

0 comments on commit 373e8c2

Please sign in to comment.