Skip to content

Commit

Permalink
[LPT] Refactoring: PoC (openvinotoolkit#5226)
Browse files Browse the repository at this point in the history
* LPT fix for Windows

* LPT fix for Windows

* Remove inference_engine_transformations_EXPORTS

* [nGraph] Register new node in GraphRewrite

* [LPT] nGraph alignment

* [LPT] nGraph alignment: tests

Co-authored-by: Ilya Lavrenov <[email protected]>
  • Loading branch information
2 people authored and akuporos committed Sep 29, 2021
1 parent 258e50e commit 1ec6ff4
Show file tree
Hide file tree
Showing 393 changed files with 9,559 additions and 5,323 deletions.
103 changes: 78 additions & 25 deletions inference-engine/src/cldnn_engine/cldnn_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,12 @@
#include <transformations/low_precision/disable_convert_constant_folding_on_const_path.hpp>
#include <low_precision/pull_reshape_through_dequantization.hpp>
#include <low_precision/pull_transpose_through_dequantization.hpp>
#include <low_precision/transformer.hpp>
#include <low_precision/convolution.hpp>
#include <low_precision/convolution_backprop_data.hpp>
#include <low_precision/group_convolution.hpp>
#include <low_precision/low_precision.hpp>
#include <low_precision/mat_mul.hpp>
#include <low_precision/multiply_to_group_convolution.hpp>
#include <low_precision/strided_slice.hpp>
#include <low_precision/network_helper.hpp>

Expand Down Expand Up @@ -151,10 +154,12 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
OV_ITT_SCOPED_TASK(itt::domains::CLDNNPlugin, "clDNNEngine::TransformNetwork");
auto nGraphFunc = clonedNetwork.getFunction();

using const_node_ptr = const std::shared_ptr<const ngraph::Node>;

bool enableInt8;
{
ngraph::pass::Manager manager;
enableInt8 = config.enableInt8 && ngraph::pass::low_precision::LowPrecisionTransformer::isFunctionQuantized(nGraphFunc);
enableInt8 = config.enableInt8 && ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(nGraphFunc);
if (enableInt8) {
manager.register_pass<ngraph::pass::DisableConvertConstantFoldingOnConstPath>(
std::vector<ngraph::element::Type>{ ngraph::element::i8, ngraph::element::u8, ngraph::element::i4, ngraph::element::u4 });
Expand Down Expand Up @@ -208,8 +213,6 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc

auto pass_config = manager.get_pass_config();

using const_node_ptr = const std::shared_ptr<const ngraph::Node>;

// SpaceToDepth/DepthToSpace node implementation supports only equal input/output tensors with rank <= 5
pass_config->set_callback<ngraph::pass::ConvertSpaceToDepth,
ngraph::pass::ConvertDepthToSpace>(
Expand Down Expand Up @@ -391,28 +394,78 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
if (!config.enable_fp16_for_quantized_models) {
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::f16, ngraph::element::f32 }});
}
auto lptPrerequisites = manager.register_pass<ngraph::pass::GraphRewrite>();
const std::vector<ngraph::element::Type> supportedTypes = { ngraph::element::i8, ngraph::element::u8 };
lptPrerequisites->add_matcher<PullReshapeThroughDequantization>(supportedTypes);
lptPrerequisites->add_matcher<PullTransposeThroughDequantization>(supportedTypes);
lptPrerequisites->add_matcher<ngraph::pass::LinOpSequenceFusion>();
manager.run_passes(nGraphFunc);

auto params = LayerTransformation::Params(true, // updatePrecisions
LayerTransformation::QuantizedTensorAlignment::UpdateLevel, // quantizedTensorAlignmentOnActivations
LayerTransformation::QuantizedTensorAlignment::None, // quantizedTensorAlignmentOnWeights
true); // supportAsymmetricQuantization
LowPrecisionTransformer transformer(LowPrecisionTransformer::getAllTransformations(params)
.add<MatMulTransformation, ngraph::opset1::MatMul>(LayerTransformation::Params(params)
.setSupportAsymmetricQuantization(false)
.setSupport3DTensorOnActivations(false))
.add<ConvolutionBackpropDataTransformation, ngraph::opset1::ConvolutionBackpropData>(LayerTransformation::Params(params)
.setSupportAsymmetricQuantization(false)
.setDeconvolutionSpecificChannelsRatio(true))
// INT8 StridedSlice not supported
.remove<StridedSliceTransformation, ngraph::opset1::StridedSlice>());

transformer.transform(nGraphFunc);
auto supportedPrecisions = std::vector<OperationPrecisionRestriction>({
OperationPrecisionRestriction::create<ngraph::opset1::Convolution>({
{0, {ngraph::element::u8, ngraph::element::i8}},
{1, {ngraph::element::i8}},
}),
OperationPrecisionRestriction::create<ngraph::opset1::ConvolutionBackpropData>({
{0, {ngraph::element::u8, ngraph::element::i8}},
{1, {ngraph::element::i8}}
}),
OperationPrecisionRestriction::create<ngraph::opset1::GroupConvolution>({
{0, {ngraph::element::u8, ngraph::element::i8}},
{1, {ngraph::element::i8}}
}),
OperationPrecisionRestriction::create<ngraph::opset1::StridedSlice>({})
});

auto perTensorQuantization = std::vector<OperationPerTensorQuantizationRestriction>({
OperationPerTensorQuantizationRestriction::create<ngraph::opset1::Convolution>({0}),
OperationPerTensorQuantizationRestriction::create<ngraph::opset1::ConvolutionBackpropData>({0}),
});

ngraph::pass::Manager lptManager;

auto lptPassConfig = lptManager.get_pass_config();
lptPassConfig->disable<ngraph::pass::low_precision::StridedSliceTransformation>();
lptPassConfig->set_callback<ngraph::pass::low_precision::MarkupPrecisions>([](const_node_ptr& node) -> bool {
if (const auto mulitply = std::dynamic_pointer_cast<const ngraph::opset1::Multiply>(node)) {
return !MultiplyToGroupConvolutionTransformation::canBeTransformedToGroupConvolution(mulitply);
}
return false;
});
lptPassConfig->set_callback<ConvolutionBackpropDataTransformation>([](const_node_ptr& node) -> bool {
auto fillStaticChannel = [](const ngraph::PartialShape& shape, size_t& channel) -> bool {
const auto rank = shape.rank();
if (rank.is_dynamic()) {
return false;
}
if (rank.get_length() < 2ul) {
return false;
}
const auto dimension = shape[1];
if (dimension.is_dynamic()) {
return false;
}
channel = dimension.get_length();
return true;
};

size_t inputChannels;
if (!fillStaticChannel(node->get_input_partial_shape(0), inputChannels)) {
return true;
}

size_t outputChannels;
if (!fillStaticChannel(node->get_output_partial_shape(0), outputChannels)) {
return true;
}


if ((inputChannels % 4 != 0) || (outputChannels % 16 != 0)) {
return true;
}

return LayerTransformation::isAsymmetricQuantization(node) || WeightableLayerTransformation::isAsymmetricOnWeights(node);
});
lptPassConfig->set_callback<MatMulTransformation>([](const_node_ptr& node) -> bool {
return MatMulTransformation::is3DTensorOnActivations(node);
});

lptManager.register_pass<LowPrecision>(supportedPrecisions, perTensorQuantization);
lptManager.run_passes(nGraphFunc);
}

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ ie_faster_build(${TARGET_NAME}
ie_add_vs_version_file(NAME ${TARGET_NAME}
FILEDESCRIPTION "Inference Engine LP transformations library")

target_compile_definitions(${TARGET_NAME} PRIVATE inference_engine_transformations_EXPORTS)

target_link_libraries(${TARGET_NAME} PUBLIC inference_engine_transformations
PRIVATE openvino::itt)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ namespace ngraph {
namespace pass {
namespace low_precision {

class TRANSFORMATIONS_API AddTransformation : public EltwiseBaseTransformation {
class LP_TRANSFORMATIONS_API AddTransformation : public EltwiseBaseTransformation {
public:
AddTransformation(const Params& params) : EltwiseBaseTransformation(params) {}
~AddTransformation() override {}
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
NGRAPH_RTTI_DECLARATION;
AddTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <ngraph/pass/pass.hpp>
#include "low_precision/lpt_visibility.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

class LP_TRANSFORMATIONS_API AlignQuantizationIntervals;

} // namespace low_precision
} // namespace pass
} // namespace ngraph

class ngraph::pass::low_precision::AlignQuantizationIntervals : public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>

#include <ngraph/pass/pass.hpp>
#include "low_precision/lpt_visibility.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

class LP_TRANSFORMATIONS_API AlignQuantizationParameters;

} // namespace low_precision
} // namespace pass
} // namespace ngraph

class ngraph::pass::low_precision::AlignQuantizationParameters : public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ namespace ngraph {
namespace pass {
namespace low_precision {

class TRANSFORMATIONS_API AvgPoolTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API AvgPoolTransformation : public LayerTransformation {
public:
AvgPoolTransformation(const Params& params);
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
NGRAPH_RTTI_DECLARATION;
AvgPoolTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include <ngraph/node.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "rt_info/attribute_parameters.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

class LP_TRANSFORMATIONS_API BaseMatcherPass;

} // namespace low_precision
} // namespace pass
} // namespace ngraph

class LP_TRANSFORMATIONS_API ngraph::pass::low_precision::BaseMatcherPass : public ngraph::pass::MatcherPass {
public:
BaseMatcherPass(const AttributeParameters& params = AttributeParameters());
AttributeParameters params;
};
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ namespace ngraph {
namespace pass {
namespace low_precision {

class TRANSFORMATIONS_API ClampTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API ClampTransformation : public LayerTransformation {
public:
ClampTransformation(const Params& params);
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override;
NGRAPH_RTTI_DECLARATION;
ClampTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
#include <ngraph/check.hpp>
#include <ngraph/opsets/opset1.hpp>

#include "transformations_visibility.hpp"
#include "low_precision/lpt_visibility.hpp"
#include "transformations/rt_info/dequantization_attribute.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

// template<typename BaseOp2>
// class TRANSFORMATIONS_API DequantizationOp : public BaseOp2 {
// class LP_TRANSFORMATIONS_API DequantizationOp : public BaseOp2 {
// public:
// template <typename ... Args>
// DequantizationOp(Args&&... args) : BaseOp2(std::forward<Args>(args)...) {
Expand Down Expand Up @@ -63,7 +63,7 @@ void copyRuntimeInfo(const ngraph::Node& from, ngraph::Node& to) {

} // namespace

class TRANSFORMATIONS_API DequantizationConvert : public ngraph::opset1::Convert {
class LP_TRANSFORMATIONS_API DequantizationConvert : public ngraph::opset1::Convert {
public:
DequantizationConvert(const ngraph::Output<Node>& arg, const ngraph::element::Type& destination_type) :
ngraph::opset1::Convert(arg, destination_type) {
Expand All @@ -77,7 +77,7 @@ class TRANSFORMATIONS_API DequantizationConvert : public ngraph::opset1::Convert
}
};

class TRANSFORMATIONS_API DequantizationSubtract : public ngraph::opset1::Subtract {
class LP_TRANSFORMATIONS_API DequantizationSubtract : public ngraph::opset1::Subtract {
public:
DequantizationSubtract(
const ngraph::Output<Node>& arg0,
Expand All @@ -94,7 +94,7 @@ class TRANSFORMATIONS_API DequantizationSubtract : public ngraph::opset1::Subtra
}
};

class TRANSFORMATIONS_API DequantizationMultiply : public ngraph::opset1::Multiply {
class LP_TRANSFORMATIONS_API DequantizationMultiply : public ngraph::opset1::Multiply {
public:
DequantizationMultiply(
const Output<Node>& arg0,
Expand All @@ -116,7 +116,7 @@ class TRANSFORMATIONS_API DequantizationMultiply : public ngraph::opset1::Multip
}
};

class TRANSFORMATIONS_API DequantizationAdd : public ngraph::opset1::Add {
class LP_TRANSFORMATIONS_API DequantizationAdd : public ngraph::opset1::Add {
public:
DequantizationAdd(
const ngraph::Output<Node>& arg0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
#include <tuple>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <low_precision/lpt_visibility.hpp>

namespace ngraph {
namespace pass {
namespace low_precision {

typedef std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> FakeQuantizeDequantizationValues;

class FakeQuantizeDequantization {
class LP_TRANSFORMATIONS_API FakeQuantizeDequantization {
public:
FakeQuantizeDequantization();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <exception>
#include <string>
#include <ngraph/node.hpp>
#include <transformations_visibility.hpp>
#include <low_precision/lpt_visibility.hpp>

/**
* @def THROW_TRANSFORMATION_EXCEPTION_LPT
Expand All @@ -19,7 +19,7 @@ namespace ngraph {
namespace pass {
namespace low_precision {

class TRANSFORMATIONS_API Exception : std::exception {
class LP_TRANSFORMATIONS_API Exception : std::exception {
std::shared_ptr<std::ostringstream> buffer;
mutable std::string buffer_str;
public:
Expand All @@ -42,7 +42,7 @@ class TRANSFORMATIONS_API Exception : std::exception {
#define THROW_TRANSFORMATION_EXCEPTION throw ::ngraph::pass::low_precision::Exception() << __FILE__ << ":" << __LINE__ << " "


class TRANSFORMATIONS_API InferenceEngineLptException : public Exception {
class LP_TRANSFORMATIONS_API InferenceEngineLptException : public Exception {
public:
InferenceEngineLptException(const std::string& filename, const size_t line, const Node& node) {
*this
Expand Down
Loading

0 comments on commit 1ec6ff4

Please sign in to comment.