Skip to content

Commit

Permalink
Revert "[GPU] Apply is_non_decompression_multiply() callback only for…
Browse files Browse the repository at this point in the history
… compressed models (#21719)"

This reverts commit 032ac89.
  • Loading branch information
sshlyapn authored and Lyamin-Roman committed Dec 19, 2023
1 parent 2779df5 commit 26a31b3
Showing 1 changed file with 21 additions and 27 deletions.
48 changes: 21 additions & 27 deletions src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,9 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
bool enableInt8;
bool unroll_loop = config.get_property(ov::intel_gpu::enable_loop_unrolling);
{
ov::pass::Manager initial_transformations_manager;
initial_transformations_manager.set_per_pass_validation(false);
ov::pass::Manager manager;
auto pass_config = manager.get_pass_config();
manager.set_per_pass_validation(false);

// Temporary solution, global rt info cleanup is needed
for (auto& node : func->get_ops()) {
Expand All @@ -201,8 +202,13 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
}

enableInt8 = config.get_property(ov::intel_gpu::enable_lp_transformations) && ov::pass::low_precision::LowPrecision::isFunctionQuantized(func);
initial_transformations_manager.register_pass<ov::pass::InitNodeInfo>();
initial_transformations_manager.register_pass<EinsumDecomposition>();
if (enableInt8) {
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 });
}

manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<EinsumDecomposition>();

precisions_map fp_convert_precision_map = {
{ov::element::f64, ov::element::f32}
Expand Down Expand Up @@ -251,19 +257,19 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
}

type_to_fuse_map empty_fuse_map = {};
initial_transformations_manager.register_pass<ov::pass::Validate>();
manager.register_pass<ov::pass::Validate>();

// fuse softmax, MVN patterns, so that they will not be marked as precision sensitive in ConvertPrecision
initial_transformations_manager.register_pass<ov::pass::SoftmaxFusion>();
initial_transformations_manager.register_pass<ov::pass::MVNFusion>();
manager.register_pass<ov::pass::SoftmaxFusion>();
manager.register_pass<ov::pass::MVNFusion>();
// decompose MVNs that sre not supported in GPU, so that they will be marked as precision sensitive in ConvertPrecision
initial_transformations_manager.register_pass<ov::pass::MVN6Decomposition>();
manager.register_pass<ov::pass::MVN6Decomposition>();
// Run these broadcast optimizations earlier to ensure that those are executed before NopElimination/ConstantFolding
initial_transformations_manager.register_pass<ov::pass::BroadcastElementwiseFusion>();
initial_transformations_manager.register_pass<ov::pass::BroadcastTransition>();
manager.register_pass<ov::pass::BroadcastElementwiseFusion>();
manager.register_pass<ov::pass::BroadcastTransition>();

initial_transformations_manager.register_pass<ov::pass::KeepConstantsPrecisionAndAddConverts>();
initial_transformations_manager.get_pass_config()->set_callback<ov::pass::KeepConstantsPrecisionAndAddConverts>(
manager.register_pass<ov::pass::KeepConstantsPrecisionAndAddConverts>();
pass_config->set_callback<ov::pass::KeepConstantsPrecisionAndAddConverts>(
[](const_node_ptr& node) -> bool {
auto next_node = node->get_output_target_inputs(0).begin()->get_node();
if (is_type<ov::op::v0::Convert>(next_node)) {
Expand All @@ -272,22 +278,10 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
return !is_type<ov::op::v0::MatMul>(next_node);
});

initial_transformations_manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8,
ov::element::u4,
ov::element::i4}, true);

// Ignore nodes that are not related to FullyConnected and allow ConstantFolding to be applied to them
initial_transformations_manager.get_pass_config()->set_callback<ov::pass::MarkDequantizationSubgraph>(is_non_supported_decompression_op);
initial_transformations_manager.run_passes(func);

ov::pass::Manager manager;
auto pass_config = manager.get_pass_config();

manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8, ov::element::u4, ov::element::i4}, true);
// Need to check if transfomrations work correctly for mixed models with both compression and quantization at the same time.
if (enableInt8) {
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 });
}
// Ignore nodes that are not related to FullyConnected and allow ConstantFolding to be applied to them
pass_config->set_callback<ov::pass::MarkDequantizationSubgraph>(is_non_supported_decompression_op);

manager.register_pass<ov::intel_gpu::MoveConvertAfterGather>();

Expand Down

0 comments on commit 26a31b3

Please sign in to comment.