Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU] Apply is_non_decompression_multiply() callback only for compressed models #21719

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
bool enableInt8;
bool unroll_loop = config.get_property(ov::intel_gpu::enable_loop_unrolling);
{
ov::pass::Manager manager;
auto pass_config = manager.get_pass_config();
manager.set_per_pass_validation(false);
ov::pass::Manager initial_transformations_manager;
initial_transformations_manager.set_per_pass_validation(false);

// Temporary solution, global rt info cleanup is needed
for (auto& node : func->get_ops()) {
Expand All @@ -198,13 +197,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
}

enableInt8 = config.get_property(ov::intel_gpu::enable_lp_transformations) && ov::pass::low_precision::LowPrecision::isFunctionQuantized(func);
if (enableInt8) {
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 });
}

manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<EinsumDecomposition>();
initial_transformations_manager.register_pass<ov::pass::InitNodeInfo>();
initial_transformations_manager.register_pass<EinsumDecomposition>();

precisions_map fp_convert_precision_map = {
{ov::element::f64, ov::element::f32}
Expand Down Expand Up @@ -253,19 +247,19 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
}

type_to_fuse_map empty_fuse_map = {};
manager.register_pass<ov::pass::Validate>();
initial_transformations_manager.register_pass<ov::pass::Validate>();

// fuse softmax, MVN patterns, so that they will not be marked as precision sensitive in ConvertPrecision
manager.register_pass<ov::pass::SoftmaxFusion>();
manager.register_pass<ov::pass::MVNFusion>();
initial_transformations_manager.register_pass<ov::pass::SoftmaxFusion>();
initial_transformations_manager.register_pass<ov::pass::MVNFusion>();
// decompose MVNs that are not supported in GPU, so that they will be marked as precision sensitive in ConvertPrecision
manager.register_pass<ov::pass::MVN6Decomposition>();
initial_transformations_manager.register_pass<ov::pass::MVN6Decomposition>();
// Run these broadcast optimizations earlier to ensure that those are executed before NopElimination/ConstantFolding
manager.register_pass<ov::pass::BroadcastElementwiseFusion>();
manager.register_pass<ov::pass::BroadcastTransition>();
initial_transformations_manager.register_pass<ov::pass::BroadcastElementwiseFusion>();
initial_transformations_manager.register_pass<ov::pass::BroadcastTransition>();

manager.register_pass<ov::pass::KeepConstantsPrecisionAndAddConverts>();
pass_config->set_callback<ov::pass::KeepConstantsPrecisionAndAddConverts>(
initial_transformations_manager.register_pass<ov::pass::KeepConstantsPrecisionAndAddConverts>();
initial_transformations_manager.get_pass_config()->set_callback<ov::pass::KeepConstantsPrecisionAndAddConverts>(
[](const_node_ptr& node) -> bool {
auto next_node = node->get_output_target_inputs(0).begin()->get_node();
if (is_type<ov::op::v0::Convert>(next_node)) {
Expand All @@ -274,9 +268,19 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
return !is_type<ov::op::v0::MatMul>(next_node);
});

manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8, ov::element::u4, ov::element::i4}, true);
initial_transformations_manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8, ov::element::u4, ov::element::i4}, true);
// Ignore nodes that are not related to FullyConnected and allow ConstantFolding to be applied to them
pass_config->set_callback<ov::pass::MarkDequantizationSubgraph>(is_non_decompression_multiply);
initial_transformations_manager.get_pass_config()->set_callback<ov::pass::MarkDequantizationSubgraph>(is_non_decompression_multiply);
initial_transformations_manager.run_passes(func);

ov::pass::Manager manager;
auto pass_config = manager.get_pass_config();

// Need to check if transformations work correctly for mixed models with both compression and quantization at the same time.
if (enableInt8) {
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 });
}

const bool keep_precision_sensitive_in_fp32_1 = true;
const bool convert_input_output_precision = false;
Expand Down
Loading