Skip to content

Commit

Permalink
Added decompression related callbacks for Cleanup LPT
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Nov 9, 2023
1 parent 160f38f commit c1b8244
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace low_precision {
/**
* @ingroup ie_transformation_common_api
* @brief FoldConvertTransformation evaluates Convert operation on Subtract constant subgraph.
* Important notice: this transformation ignores DisableConstantFolding runtime attribute.
*
* For more details about the transformation, refer to
* [FoldConvertTransformation](@ref openvino_docs_OV_UG_lpt_FoldConvertTransformation) page
Expand Down
120 changes: 70 additions & 50 deletions src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@
#include "low_precision/add.hpp"
#include "low_precision/convert_subtract_constant.hpp"
#include "low_precision/convolution_backprop_data.hpp"
#include "low_precision/fold_convert.hpp"
#include "low_precision/fuse_convert.hpp"
#include "low_precision/group_convolution.hpp"
#include "low_precision/multiply_to_group_convolution.hpp"
#include "low_precision/recurrent_cell.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/recurrent_cell.hpp"
#include "low_precision/rt_info/bias_attribute.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"

Expand Down Expand Up @@ -130,6 +132,35 @@ namespace intel_cpu {

using const_node_ptr = const std::shared_ptr<const ov::Node>;

bool Transformations::is_decompression_multiply(const_node_ptr& node) const {
auto get_single_consumer = [](const_node_ptr& node) -> std::shared_ptr<ov::Node> {
const auto consumers = node->get_output_target_inputs(0);
if (consumers.size() != 1)
return nullptr;
return consumers.begin()->get_node()->shared_from_this();
};

auto consumer = get_single_consumer(node);
if (!consumer)
return false;

if (ov::is_type<ov::opset1::MatMul>(consumer)) {
return true;
} else if (ov::is_type<ov::opset1::Reshape>(consumer)) {
consumer = get_single_consumer(consumer);
if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
return true;
}
}
if (consumer != nullptr && ov::is_type<ov::opset1::Convert>(consumer)) {
consumer = get_single_consumer(consumer);
if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
return true;
}
}
return false;
}

bool Transformations::fuse_type_to_convert(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
auto convert = ov::as_type_ptr<ov::opset10::Convert>(node);
if (!convert)
Expand Down Expand Up @@ -224,35 +255,9 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
decompression_precisions.push_back(ov::element::nf4);
}
CPU_REGISTER_PASS_X64(decompression_handling_manager, ov::pass::MarkDequantizationSubgraph, decompression_precisions, true);
CPU_SET_CALLBACK_X64(decompression_handling_manager, [](const_node_ptr &node) -> bool {
auto get_single_consumer = [](const_node_ptr &node) -> std::shared_ptr<ov::Node> {
const auto consumers = node->get_output_target_inputs(0);
if (consumers.size() != 1)
return nullptr;
return consumers.begin()->get_node()->shared_from_this();
};

auto consumer = get_single_consumer(node);
if (!consumer)
return true;

if (ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
} else if (ov::is_type<ov::opset1::Reshape>(consumer)) {
consumer = get_single_consumer(consumer);
if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
}
}
if (consumer != nullptr && ov::is_type<ov::opset1::Convert>(consumer)) {
consumer = get_single_consumer(consumer);
if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
}
}
return true;
CPU_SET_CALLBACK_X64(decompression_handling_manager, [&](const_node_ptr &node) -> bool {
return !is_decompression_multiply(node);
}, ov::pass::MarkDequantizationSubgraph);
decompression_handling_manager.register_pass<ov::pass::VisualizeTree>("/home/vgolubev/models/after_decompression.svg");
decompression_handling_manager.run_passes(model);

ov::pass::Manager manager;
Expand Down Expand Up @@ -566,32 +571,47 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vector<ov
}

ov::pass::Manager lptManager;
CPU_REGISTER_PASS_COMMON(lptManager, ov::pass::low_precision::LowPrecision,
CPU_REGISTER_PASS_COMMON(lptManager, LowPrecision,
supportedPrecisions,
quantizationRestrictions,
LayerTransformation::Params(updatePrecision, ov::element::f32, defaultPrecisions));
CPU_SET_CALLBACK_COMMON(lptManager,
[](const_node_ptr& node) -> bool {
if (const auto mulitply = std::dynamic_pointer_cast<const ov::opset1::Multiply>(node)) {
return !MultiplyToGroupConvolutionTransformation::canBeTransformedToGroupConvolution(mulitply);
}
return false;
},
ov::pass::low_precision::MarkupPrecisions);
CPU_SET_CALLBACK_COMMON(lptManager,
[&defaultPrecisions](const_node_ptr& node) -> bool {
return LayerTransformation::isAsymmetricQuantization(node, defaultPrecisions) ||
WeightableLayerTransformation::isAsymmetricOnWeights(node, defaultPrecisions);
},
ov::pass::low_precision::ConvolutionBackpropDataTransformation);

lptManager.get_pass_config()->set_callback<ov::pass::low_precision::AddTransformation>(
[](const_node_ptr& node) -> bool {
return ov::marked_as_bias(node);
});
CPU_SET_CALLBACK_COMMON(lptManager, [](const_node_ptr& node) -> bool {
return ov::is_type<ov::opset1::Multiply>(node) &&
!MultiplyToGroupConvolutionTransformation::canBeTransformedToGroupConvolution(node);
}, MarkupPrecisions);
CPU_SET_CALLBACK_COMMON(lptManager, [&defaultPrecisions](const_node_ptr& node) -> bool {
return LayerTransformation::isAsymmetricQuantization(node, defaultPrecisions) ||
WeightableLayerTransformation::isAsymmetricOnWeights(node, defaultPrecisions);
}, ConvolutionBackpropDataTransformation);
CPU_SET_CALLBACK_COMMON(lptManager, [](const_node_ptr& node) -> bool {
return ov::marked_as_bias(node);
}, AddTransformation);

CPU_SET_CALLBACK_X64(lptManager, [&](const_node_ptr& node) -> bool {
const auto& consumers = node->get_output_target_inputs(0);
if (consumers.size() == 1) {
const auto consumer = consumers.begin()->get_node()->shared_from_this();
return ov::is_type<ov::opset1::Multiply>(consumer) && is_decompression_multiply(consumer);
}
return false;
}, FoldConvertTransformation);

CPU_SET_CALLBACK_X64(lptManager, [&](const_node_ptr& node) -> bool {
if (ov::is_type<ov::opset1::Multiply>(node)) {
return ov::is_type<ov::opset1::Multiply>(node) && is_decompression_multiply(node);
} else if (ov::is_type<ov::opset1::Subtract>(node)) {
const auto& consumers = node->get_output_target_inputs(0);
if (consumers.size() == 1) {
const auto consumer = consumers.begin()->get_node()->shared_from_this();
return ov::is_type<ov::opset1::Multiply>(consumer) && is_decompression_multiply(consumer);
}
}
return false;
}, FuseConvertTransformation);

CPU_DISABLE_PASS_ARM(lptManager, ov::pass::low_precision::RecurrentCellTransformation);
CPU_DISABLE_PASS_COMMON(lptManager, ov::pass::low_precision::MultiplyToGroupConvolutionTransformation);
CPU_DISABLE_PASS_ARM(lptManager, RecurrentCellTransformation);
CPU_DISABLE_PASS_COMMON(lptManager, MultiplyToGroupConvolutionTransformation);

lptManager.run_passes(model);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ class Transformations {

void PostSnippets(void);

static bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions);
bool is_decompression_multiply(const std::shared_ptr<const ov::Node>& node) const;

static bool fuse_type_to_convert(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions);
};

} // namespace intel_cpu
Expand Down

0 comments on commit c1b8244

Please sign in to comment.