-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TF Quant Models] Prepares GroupConvolution weights #25641
Changes from 2 commits
9d3a3eb
8346adf
bdb2405
158a3ca
4434eda
d37b871
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,23 +9,29 @@ | |
|
||
#include "itt.hpp" | ||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/core/validation_util.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/convert.hpp" | ||
#include "openvino/op/fake_quantize.hpp" | ||
#include "openvino/op/group_conv.hpp" | ||
#include "openvino/op/reshape.hpp" | ||
#include "openvino/pass/pattern/op/optional.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
|
||
ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() { | ||
MATCHER_SCOPE(FakeQuantizeReshapeFusion); | ||
const auto fq_node_p = ov::pass::pattern::wrap_type<ov::op::v0::FakeQuantize>( | ||
{ov::pass::pattern::wrap_type<ov::op::v0::Constant>(), // for weights only | ||
pattern::any_input(), | ||
pattern::any_input(), | ||
pattern::any_input(), | ||
pattern::any_input()}, | ||
pattern::consumers_count(1)); | ||
// for weights only | ||
const auto data_p = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(pattern::has_static_shape()); | ||
const auto convert_p = ov::pass::pattern::optional<ov::op::v0::Convert>(data_p, pattern::consumers_count(1)); | ||
const auto fq_node_p = | ||
ov::pass::pattern::wrap_type<ov::op::v0::FakeQuantize>({convert_p, | ||
pattern::any_input(pattern::has_static_shape()), | ||
pattern::any_input(pattern::has_static_shape()), | ||
pattern::any_input(pattern::has_static_shape()), | ||
pattern::any_input(pattern::has_static_shape())}, | ||
pattern::consumers_count(1)); | ||
const auto reshape_node_p = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>( | ||
{fq_node_p, pattern::any_input()}, | ||
{fq_node_p, ov::pass::pattern::wrap_type<ov::op::v0::Constant>()}, | ||
[](const Output<Node>& output) { | ||
// WA: check that all Reshape node consumers are not GroupConvolution operations | ||
const auto& target_inputs = output.get_target_inputs(); | ||
|
@@ -36,13 +42,11 @@ ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() { | |
|
||
ov::matcher_pass_callback callback = [=](pattern::Matcher& m) { | ||
const auto& pattern_map = m.get_pattern_value_map(); | ||
const auto fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr(); | ||
if (fq_node->is_dynamic()) | ||
return false; | ||
const auto& fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr(); | ||
const auto& reshape_node = pattern_map.at(reshape_node_p).get_node_shared_ptr(); | ||
const auto& original_data_rank = fq_node->get_input_shape(0).size(); | ||
OutputVector renewed_inputs = { | ||
reshape_node->clone_with_new_inputs({fq_node->input_value(0), reshape_node->input_value(1)})}; | ||
|
||
OutputVector renewed_inputs = {}; | ||
for (auto i = 1; i < 5; ++i) { | ||
Output<Node> limit_input = fq_node->input_value(i); | ||
auto limit_shape = limit_input.get_shape(); | ||
|
@@ -62,21 +66,41 @@ ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() { | |
}); | ||
const auto& new_limit_size = shape_size(new_limit_shape); | ||
if (new_limit_size == limit_size) { // we tracked future channel placement | ||
if (new_limit_shape == limit_input.get_shape()) | ||
if (new_limit_shape == limit_input.get_shape()) { | ||
renewed_inputs.push_back(limit_input); | ||
else | ||
renewed_inputs.push_back(reshape_node->clone_with_new_inputs( | ||
} else { | ||
auto reshaped_input = reshape_node->clone_with_new_inputs( | ||
{limit_input, | ||
ov::op::v0::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)})); | ||
ov::op::v0::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)}); | ||
if (auto constant = ov::util::get_constant_from_source(reshaped_input)) { | ||
reshaped_input = constant; | ||
} | ||
renewed_inputs.push_back(reshaped_input); | ||
} | ||
continue; | ||
} | ||
} | ||
// resulting FQ will become or already is more than per-tensor / per-channel | ||
return false; | ||
} | ||
|
||
auto reshaped_input = | ||
reshape_node->clone_with_new_inputs({pattern_map.at(data_p), reshape_node->input_value(1)}); | ||
if (auto constant = ov::util::get_constant_from_source(reshaped_input)) { | ||
reshaped_input = constant; | ||
} | ||
if (pattern_map.count(convert_p)) { | ||
const auto& convert_node = pattern_map.at(convert_p).get_node_shared_ptr(); | ||
convert_node->input(0).replace_source_output(reshaped_input); | ||
convert_node->validate_and_infer_types(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am just wondering: why should we manually call node validation here? It is usually automatically done by Upd. It seems like it is needed to keep a valid graph state during one There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, you ARE right. Since we are in a single Graph Rewrite -- it is safer to keep the graph in validated state, so that all the matchers could assume rank and shape is accurate |
||
reshaped_input = convert_node; | ||
} | ||
renewed_inputs.insert(renewed_inputs.begin(), reshaped_input); | ||
|
||
for (auto& new_input : renewed_inputs) | ||
copy_runtime_info({reshape_node, fq_node}, new_input.get_node_shared_ptr()); | ||
const auto new_fq_node = fq_node->clone_with_new_inputs(renewed_inputs); | ||
register_new_node(new_fq_node); | ||
replace_node(reshape_node, new_fq_node); | ||
new_fq_node->set_friendly_name(reshape_node->get_friendly_name()); | ||
copy_runtime_info({fq_node, reshape_node}, new_fq_node); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor: can we use
ov::op::util::clone_try_fold
instead to make the code a bit shorter? The same at L87 and at L60 in pull_transpose_through_fq.cppThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can do that, sure. The only difference is that clone_try_fold only works for single node folding and get_constant_from_source checks all the necessary no_folding rt info and could fold a chain of nodes. For now, I won't change the code in the PR, however, in case more feedback arises -- I'll clean it up.