[GPU] Apply is_non_decompression_multiply() callback only for compressed models
sshlyapn committed Dec 18, 2023
1 parent da66211 commit 51389ea
Showing 1 changed file with 23 additions and 20 deletions: src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
@@ -187,9 +187,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
     bool enableInt8;
     bool unroll_loop = config.get_property(ov::intel_gpu::enable_loop_unrolling);
     {
-        ov::pass::Manager manager;
-        auto pass_config = manager.get_pass_config();
-        manager.set_per_pass_validation(false);
+        ov::pass::Manager initial_transformations_manager;
+        initial_transformations_manager.set_per_pass_validation(false);
 
         // Temporary solution, global rt info cleanup is needed
         for (auto& node : func->get_ops()) {
@@ -198,13 +197,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
         }
 
         enableInt8 = config.get_property(ov::intel_gpu::enable_lp_transformations) && ov::pass::low_precision::LowPrecision::isFunctionQuantized(func);
-        if (enableInt8) {
-            manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
-                std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 });
-        }
-
-        manager.register_pass<ov::pass::InitNodeInfo>();
-        manager.register_pass<EinsumDecomposition>();
+        initial_transformations_manager.register_pass<ov::pass::InitNodeInfo>();
+        initial_transformations_manager.register_pass<EinsumDecomposition>();
 
         precisions_map fp_convert_precision_map = {
             {ov::element::f64, ov::element::f32}
@@ -253,19 +247,19 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
         }
 
         type_to_fuse_map empty_fuse_map = {};
-        manager.register_pass<ov::pass::Validate>();
+        initial_transformations_manager.register_pass<ov::pass::Validate>();
 
         // fuse softmax, MVN patterns, so that they will not be marked as precision sensitive in ConvertPrecision
-        manager.register_pass<ov::pass::SoftmaxFusion>();
-        manager.register_pass<ov::pass::MVNFusion>();
+        initial_transformations_manager.register_pass<ov::pass::SoftmaxFusion>();
+        initial_transformations_manager.register_pass<ov::pass::MVNFusion>();
         // decompose MVNs that sre not supported in GPU, so that they will be marked as precision sensitive in ConvertPrecision
-        manager.register_pass<ov::pass::MVN6Decomposition>();
+        initial_transformations_manager.register_pass<ov::pass::MVN6Decomposition>();
         // Run these broadcast optimizations earlier to ensure that those are executed before NopElimination/ConstantFolding
-        manager.register_pass<ov::pass::BroadcastElementwiseFusion>();
-        manager.register_pass<ov::pass::BroadcastTransition>();
+        initial_transformations_manager.register_pass<ov::pass::BroadcastElementwiseFusion>();
+        initial_transformations_manager.register_pass<ov::pass::BroadcastTransition>();
 
-        manager.register_pass<ov::pass::KeepConstantsPrecisionAndAddConverts>();
-        pass_config->set_callback<ov::pass::KeepConstantsPrecisionAndAddConverts>(
+        initial_transformations_manager.register_pass<ov::pass::KeepConstantsPrecisionAndAddConverts>();
+        initial_transformations_manager.get_pass_config()->set_callback<ov::pass::KeepConstantsPrecisionAndAddConverts>(
             [](const_node_ptr& node) -> bool {
                 auto next_node = node->get_output_target_inputs(0).begin()->get_node();
                 if (is_type<ov::op::v0::Convert>(next_node)) {
@@ -274,9 +268,18 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
                 return !is_type<ov::op::v0::MatMul>(next_node);
             });
 
-        manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8, ov::element::u4, ov::element::i4}, true);
+        initial_transformations_manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8, ov::element::u4, ov::element::i4}, true);
         // Ignore nodes that are not related to FullyConnected and allow ConstantFolding to be applied to them
-        pass_config->set_callback<ov::pass::MarkDequantizationSubgraph>(is_non_decompression_multiply);
+        initial_transformations_manager.get_pass_config()->set_callback<ov::pass::MarkDequantizationSubgraph>(is_non_decompression_multiply);
+
+        ov::pass::Manager manager;
+        auto pass_config = manager.get_pass_config();
+
+        // Need to check if transfomrations work correctly for mixed models with both compression and quantization at the same time.
+        if (enableInt8) {
+            manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
+                std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 });
+        }
 
         const bool keep_precision_sensitive_in_fp32_1 = true;
         const bool convert_input_output_precision = false;
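In short, the commit moves the u8/u4/i4 weight-decompression marking, together with its is_non_decompression_multiply() callback, onto a dedicated initial_transformations_manager, while the int8 (quantized-model) path re-registers MarkDequantizationSubgraph on the main manager without that callback. The fragment below is a minimal sketch of the resulting structure, not the full pipeline; it assumes the headers and declarations of transformations_pipeline.cpp (func, config, enableInt8, EinsumDecomposition, is_non_decompression_multiply) are in scope and omits every other pass.

// Sketch of the pipeline structure after this commit (illustrative only).
// Dedicated manager for the initial transformations; per-pass validation is disabled.
ov::pass::Manager initial_transformations_manager;
initial_transformations_manager.set_per_pass_validation(false);

initial_transformations_manager.register_pass<ov::pass::InitNodeInfo>();
initial_transformations_manager.register_pass<EinsumDecomposition>();

// Mark u8/u4/i4 weight-decompression subgraphs; the callback ignores Multiply nodes that are
// not FullyConnected weight decompression, so ConstantFolding can still fold them.
initial_transformations_manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
    ov::element::TypeVector{ov::element::u8, ov::element::u4, ov::element::i4}, true);
initial_transformations_manager.get_pass_config()->set_callback<ov::pass::MarkDequantizationSubgraph>(
    is_non_decompression_multiply);

// Main manager: for quantized (int8) models MarkDequantizationSubgraph is registered again,
// but without the decompression-only callback.
ov::pass::Manager manager;
auto pass_config = manager.get_pass_config();
if (enableInt8) {
    manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
        std::vector<ov::element::Type>{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4});
}
// ... remaining passes and run_passes() calls as in the real transformations_pipeline.cpp.

Splitting the managers keeps the callback scoped to the compression-marking step named in the commit title instead of applying it to every model.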
