
Commit 416d610

fixed the marking on gpu
itikhono committed Nov 21, 2024
1 parent 1c7a72e commit 416d610
Showing 2 changed files with 15 additions and 10 deletions.
@@ -44,7 +44,7 @@ void set_rt_info(const PatternValueMap& pt_map,
     }
 };
 
-void swap_nodes(const PatternValueMap& pt_map,
+bool swap_nodes(const PatternValueMap& pt_map,
                 const std::shared_ptr<Node>& first,
                 const std::shared_ptr<Node>& second) {
     if (pt_map.count(first) && pt_map.count(second)) {
@@ -59,7 +59,9 @@ void swap_nodes(const PatternValueMap& pt_map,
         }
         first_node->validate_and_infer_types();
         second_node->validate_and_infer_types();
+        return true;
     }
+    return false;
 }
 
 } // namespace
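
The hunk above turns swap_nodes from a void helper into one that reports whether it actually re-wired anything, so the matcher callback can propagate an accurate "graph changed" flag. A minimal sketch of the resulting contract, with the re-wiring body elided (Node and PatternValueMap are the standard OpenVINO pattern-matcher types; header paths are assumed):

#include <memory>
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/matcher.hpp"

using ov::Node;
using ov::pass::pattern::PatternValueMap;

bool swap_nodes_sketch(const PatternValueMap& pt_map,
                       const std::shared_ptr<Node>& first,
                       const std::shared_ptr<Node>& second) {
    // Both pattern nodes must have been matched, otherwise there is nothing to move.
    if (!pt_map.count(first) || !pt_map.count(second))
        return false;
    // ... re-wire the two matched nodes and revalidate their types (elided) ...
    return true;  // the graph was modified
}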
@@ -89,7 +91,7 @@ ov::pass::MarkDequantization::MarkDequantization(const element::TypeVector& prec
         auto input = pt_map.at(input_pattern);
         const auto multiply = m.get_match_root();
 
-        if (transformation_callback(multiply)) {
+        if (!check_precision(input.get_element_type(), precisions) || transformation_callback(multiply)) {
             return false;
         }
 
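check_precision is a helper defined elsewhere in the same source file and not shown in this diff. A hedged sketch of what such a precision gate typically boils down to (the real implementation may differ):

#include <algorithm>
#include "openvino/core/type/element_type.hpp"

// Returns true when `type` is one of the precisions the pass was asked to mark.
bool check_precision_sketch(const ov::element::Type& type, const ov::element::TypeVector& precisions) {
    return std::find(precisions.begin(), precisions.end(), type) != precisions.end();
}

With this guard in place, the callback rejects matches whose dequantization input is not one of the requested low-precision types before doing any marking.
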
@@ -117,9 +119,9 @@ ov::pass::MarkDequantization::MarkDequantization(const element::TypeVector& prec
         set_rt_info(pt_map, enable_constant_folding, converts_to_unmark, precisions);
 
         // Move Reshape/Unsqueeze ops up to fold them in ConstantFolding.
-        swap_nodes(pt_map, zp_convert_pattern, zp_reshape_pattern);
-        swap_nodes(pt_map, scale_convert_pattern, scale_reshape_pattern);
-        return false;
+        auto changed = swap_nodes(pt_map, zp_convert_pattern, zp_reshape_pattern);
+        changed = changed || swap_nodes(pt_map, scale_convert_pattern, scale_reshape_pattern);
+        return changed;
     };
 
     auto m = std::make_shared<Matcher>(multiply_pattern, "MarkDequantization");
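
Returning `changed` instead of a hard-coded `false` matters because a MatcherPass callback's return value is how GraphRewrite and the pass Manager learn whether the model was modified. A hedged, self-contained sketch of that contract; the pass name, the Multiply pattern and the empty rewrite body are illustrative only:

#include <memory>
#include "openvino/op/multiply.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

class SketchPass : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("SketchPass", "0");
    SketchPass() {
        using namespace ov::pass::pattern;
        auto multiply = wrap_type<ov::op::v1::Multiply>();
        ov::matcher_pass_callback callback = [](Matcher& m) {
            bool changed = false;
            // ... rewrite the matched subgraph and flip `changed` when something was moved ...
            return changed;  // true => "graph rewritten", false => "left untouched"
        };
        register_matcher(std::make_shared<Matcher>(multiply, "SketchPass"), callback);
    }
};

With the commit applied, MarkDequantization follows this contract and reports a modification only when at least one Reshape/Unsqueeze was actually moved.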
13 changes: 8 additions & 5 deletions src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
@@ -291,10 +291,12 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
 
         auto is_model_quantized = ov::pass::low_precision::LowPrecision::isFunctionQuantized(func);
         enableInt8 = config.get_property(ov::intel_gpu::enable_lp_transformations) && is_model_quantized;
-        if (enableInt8) {
-            manager.register_pass<ov::pass::MarkDequantization>(
-                std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 });
-        }
+
+        //if (enableInt8) { Why do we need this check? According to the line 378 we did this marking anyway
+        manager.register_pass<ov::pass::MarkDequantization>(
+            std::vector<ov::element::Type>{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 },
+            !device_info.supports_immad);
+        //}
 
         manager.register_pass<ov::pass::InitNodeInfo>();
         manager.register_pass<EinsumDecomposition>();
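
In the GPU pipeline, MarkDequantization is now registered unconditionally (the enableInt8 guard is commented out, with the author's question left inline) and receives a second, device-dependent argument. A hedged usage sketch of the same registration outside the plugin; the wrapper name is hypothetical, the header path is assumed, and the exact semantics of the trailing bool (mirroring !device_info.supports_immad) are defined by the MarkDequantization constructor:

#include <memory>
#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"  // assumed header location

void mark_dequantization_like_gpu(const std::shared_ptr<ov::Model>& model, bool supports_immad) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::MarkDequantization>(
        ov::element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4},
        !supports_immad);  // mirrors the plugin's !device_info.supports_immad
    manager.run_passes(model);
}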
@@ -373,7 +375,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
         // it expects to have the same data type for weights and zero points (apply it only for u8 data type, since other compression
         // types are not supported by oneDNN)
         manager.register_pass<ov::pass::KeepConstsPrecision>(supported_woq_types, !device_info.supports_immad);
-        pass_config->set_callback<ov::pass::KeepConstsPrecision>([&](const std::shared_ptr<const ov::Node> node) {
+        pass_config->set_callback<ov::pass::MarkDequantization,
+                                  ov::pass::KeepConstsPrecision>([&](const std::shared_ptr<const ov::Node> node) {
             return !is_decompression_multiply(node, device_info.supports_immad);
         });
 
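
set_callback accepts a list of pass types, so a single predicate now gates both MarkDequantization and KeepConstsPrecision. A hedged fragment of the idiom; pass_config, device_info and is_decompression_multiply come from the surrounding GPU plugin code and are assumed to be in scope:

// One predicate shared by both passes: when it returns true, transformation_callback(node)
// evaluates to true inside the pass, and the pass skips that match (see the early return in
// the MarkDequantization hunk above).
pass_config->set_callback<ov::pass::MarkDequantization, ov::pass::KeepConstsPrecision>(
    [&](const std::shared_ptr<const ov::Node>& node) {
        return !is_decompression_multiply(node, device_info.supports_immad);
    });

Listing MarkDequantization here means the GPU plugin only marks multiplies that is_decompression_multiply recognizes as weight decompression.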
