Skip to content

Commit

Permalink
[LPT] Disable cleanup #2
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Sep 28, 2023
1 parent e3fd64a commit 2d916d5
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include <memory>

#include "low_precision/layer_transformation.hpp"
#include "low_precision/cleanup_transformation.hpp"

namespace ov {
namespace pass {
Expand All @@ -20,7 +20,7 @@ namespace low_precision {
* [EliminateFakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_EliminateFakeQuantizeTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API EliminateFakeQuantizeTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API EliminateFakeQuantizeTransformation : public CleanupTransformation {
public:
OPENVINO_RTTI("EliminateFakeQuantizeTransformation", "0");
EliminateFakeQuantizeTransformation(const Params& params = Params());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include <memory>

#include "low_precision/layer_transformation.hpp"
#include "low_precision/cleanup_transformation.hpp"

namespace ov {
namespace pass {
Expand All @@ -20,7 +20,7 @@ namespace low_precision {
* [FoldConvertTransformation](@ref openvino_docs_OV_UG_lpt_FoldConvertTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API FoldConvertTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API FoldConvertTransformation : public CleanupTransformation {
public:
OPENVINO_RTTI("FoldConvertTransformation", "0");
FoldConvertTransformation(const Params& params = Params());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

#pragma once


#include "low_precision/layer_transformation.hpp"
#include "low_precision/eltwise_base_transformation.hpp"
#include "low_precision/cleanup_transformation.hpp"

namespace ov {
namespace pass {
Expand All @@ -20,7 +18,7 @@ namespace low_precision {
* [FuseConvertTransformation](@ref openvino_docs_OV_UG_lpt_FuseConvertTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API FuseConvertTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API FuseConvertTransformation : public CleanupTransformation {
public:
OPENVINO_RTTI("FuseConvertTransformation", "0");
FuseConvertTransformation(const Params& params = Params());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#pragma once

#include <memory>
#include "low_precision/layer_transformation.hpp"
#include "low_precision/cleanup_transformation.hpp"
#include "common/precisions_restriction.hpp"

namespace ov {
Expand All @@ -20,7 +20,7 @@ namespace low_precision {
* [MultiplyToGroupConvolutionTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyToGroupConvolutionTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API MultiplyToGroupConvolutionTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API MultiplyToGroupConvolutionTransformation : public CleanupTransformation {
public:
OPENVINO_RTTI("MultiplyToGroupConvolutionTransformation", "0");
MultiplyToGroupConvolutionTransformation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace ov {
namespace pass {
namespace low_precision {

EliminateFakeQuantizeTransformation::EliminateFakeQuantizeTransformation(const Params& params) : LayerTransformation(params) {
EliminateFakeQuantizeTransformation::EliminateFakeQuantizeTransformation(const Params& params) : CleanupTransformation(params) {
MATCHER_SCOPE(FuseMultiplyToFakeQuantizeTransformation);
const auto matcher = pattern::wrap_type<ov::opset1::FakeQuantize>({
pattern::any_input(),
Expand Down Expand Up @@ -112,6 +112,10 @@ bool check_intervals(const std::shared_ptr<ov::opset1::FakeQuantize>& fakeQuanti
} // namespace

bool EliminateFakeQuantizeTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
if (!CleanupTransformation::canBeTransformed(context, operation)) {
return false;
}

const auto fakeQuantize = ov::as_type_ptr<ov::opset1::FakeQuantize>(operation);
OPENVINO_ASSERT(fakeQuantize != nullptr, "unexpected operation type");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace ov {
namespace pass {
namespace low_precision {

FoldConvertTransformation::FoldConvertTransformation(const Params& params) : LayerTransformation(params) {
FoldConvertTransformation::FoldConvertTransformation(const Params& params) : CleanupTransformation(params) {
MATCHER_SCOPE(FoldConvertTransformation);
auto subtract = pattern::wrap_type<ov::opset1::Subtract>();
auto matcher = std::make_shared<ov::pass::pattern::Matcher>(subtract, matcher_name);
Expand Down Expand Up @@ -57,6 +57,7 @@ bool FoldConvertTransformation::transform(TransformationContext& context, ov::pa

bool FoldConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
return
(CleanupTransformation::canBeTransformed(context, operation)) &&
(ov::is_type<ov::opset1::Convert>(operation->get_input_node_ptr(1)) &&
ov::is_type<ov::opset1::Constant>(operation->get_input_node_ptr(1)->get_input_node_ptr(0))) ||
(ov::is_type<ov::opset1::Convert>(operation->get_input_node_ptr(0)) &&
Expand Down
4 changes: 2 additions & 2 deletions src/common/low_precision_transformations/src/fuse_convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace ov {
namespace pass {
namespace low_precision {

FuseConvertTransformation::FuseConvertTransformation(const Params& params) : LayerTransformation(params) {
FuseConvertTransformation::FuseConvertTransformation(const Params& params) : CleanupTransformation(params) {
MATCHER_SCOPE(FuseConvertTransformation);
auto multiply = pattern::wrap_type<ov::opset1::Multiply>({ pattern::wrap_type<ov::opset1::Convert>(), pattern::wrap_type<ov::opset1::Constant>() });
auto subtract = pattern::wrap_type<ov::opset1::Subtract>({ pattern::wrap_type<ov::opset1::Convert>(), pattern::wrap_type<ov::opset1::Constant>() });
Expand Down Expand Up @@ -115,7 +115,7 @@ bool FuseConvertTransformation::transform(TransformationContext& context, ov::pa
}

bool FuseConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
if (!getAttribute<DisableCleanupAttribute>(op).empty()) {
if (!CleanupTransformation::canBeTransformed(context, op)) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace low_precision {

MultiplyToGroupConvolutionTransformation::MultiplyToGroupConvolutionTransformation(
const Params& params,
const PrecisionsRestriction::PrecisionsByPorts& restrictions) : LayerTransformation(params), restrictions(restrictions), groupSize(1ul) {
const PrecisionsRestriction::PrecisionsByPorts& restrictions) : CleanupTransformation(params), restrictions(restrictions), groupSize(1ul) {
MATCHER_SCOPE(MultiplyToGroupConvolutionTransformation);
auto matcher = pattern::wrap_type<ov::opset1::Multiply>();

Expand Down Expand Up @@ -143,6 +143,10 @@ bool MultiplyToGroupConvolutionTransformation::transform(TransformationContext&
}

bool MultiplyToGroupConvolutionTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
if (!CleanupTransformation::canBeTransformed(context, operation)) {
return false;
}

const PartialShape outPShape = operation->get_output_partial_shape(0);
const auto rank = outPShape.rank();
if (rank.is_dynamic()) {
Expand Down

0 comments on commit 2d916d5

Please sign in to comment.