Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF Quant Models] Prepares GroupConvolution weights #25641

Merged
merged 6 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,29 @@

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/group_conv.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
MATCHER_SCOPE(FakeQuantizeReshapeFusion);
const auto fq_node_p = ov::pass::pattern::wrap_type<ov::op::v0::FakeQuantize>(
{ov::pass::pattern::wrap_type<ov::op::v0::Constant>(), // for weights only
pattern::any_input(),
pattern::any_input(),
pattern::any_input(),
pattern::any_input()},
pattern::consumers_count(1));
// for weights only
const auto data_p = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(pattern::has_static_shape());
const auto convert_p = ov::pass::pattern::optional<ov::op::v0::Convert>(data_p, pattern::consumers_count(1));
const auto fq_node_p =
ov::pass::pattern::wrap_type<ov::op::v0::FakeQuantize>({convert_p,
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape())},
pattern::consumers_count(1));
const auto reshape_node_p = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>(
{fq_node_p, pattern::any_input()},
{fq_node_p, ov::pass::pattern::wrap_type<ov::op::v0::Constant>()},
[](const Output<Node>& output) {
// WA: check that all Reshape node consumers are not GroupConvolution operations
const auto& target_inputs = output.get_target_inputs();
Expand All @@ -36,13 +42,11 @@ ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {

ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
if (fq_node->is_dynamic())
return false;
const auto& fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
const auto& reshape_node = pattern_map.at(reshape_node_p).get_node_shared_ptr();
const auto& original_data_rank = fq_node->get_input_shape(0).size();
OutputVector renewed_inputs = {
reshape_node->clone_with_new_inputs({fq_node->input_value(0), reshape_node->input_value(1)})};

OutputVector renewed_inputs = {};
for (auto i = 1; i < 5; ++i) {
Output<Node> limit_input = fq_node->input_value(i);
auto limit_shape = limit_input.get_shape();
Expand All @@ -62,21 +66,41 @@ ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
});
const auto& new_limit_size = shape_size(new_limit_shape);
if (new_limit_size == limit_size) { // we tracked future channel placement
if (new_limit_shape == limit_input.get_shape())
if (new_limit_shape == limit_input.get_shape()) {
renewed_inputs.push_back(limit_input);
else
renewed_inputs.push_back(reshape_node->clone_with_new_inputs(
} else {
auto reshaped_input = reshape_node->clone_with_new_inputs(
{limit_input,
ov::op::v0::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)}));
ov::op::v0::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)});
if (auto constant = ov::util::get_constant_from_source(reshaped_input)) {
reshaped_input = constant;
}
renewed_inputs.push_back(reshaped_input);
Comment on lines +72 to +78
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: can we use ov::op::util::clone_try_fold instead to make the code a bit shorter? The same at L87 and at L60 in pull_transpose_through_fq.cpp

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do that, sure. The only difference is that clone_try_fold only works for single node folding and get_constant_from_source checks all the necessary no_folding rt info and could fold a chain of nodes. For now, I won't change the code in the PR, however, in case more feedback arises -- I'll clean it up.

}
continue;
}
}
// resulting FQ will become or already is more than per-tensor / per-channel
return false;
}

auto reshaped_input =
reshape_node->clone_with_new_inputs({pattern_map.at(data_p), reshape_node->input_value(1)});
if (auto constant = ov::util::get_constant_from_source(reshaped_input)) {
reshaped_input = constant;
}
if (pattern_map.count(convert_p)) {
const auto& convert_node = pattern_map.at(convert_p).get_node_shared_ptr();
convert_node->input(0).replace_source_output(reshaped_input);
convert_node->validate_and_infer_types();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am just wondering: why should we manually call node validation here? It is usually automatically done by Validate pass, isn't it?

Upd. It seems like it is needed to keep a valid graph state during one GraphRewrite pass. Am I right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you ARE right. Since we are in a single Graph Rewrite -- it is safer to keep the graph in validated state, so that all the matchers could assume rank and shape is accurate

reshaped_input = convert_node;
}
renewed_inputs.insert(renewed_inputs.begin(), reshaped_input);

for (auto& new_input : renewed_inputs)
copy_runtime_info({reshape_node, fq_node}, new_input.get_node_shared_ptr());
const auto new_fq_node = fq_node->clone_with_new_inputs(renewed_inputs);
register_new_node(new_fq_node);
replace_node(reshape_node, new_fq_node);
new_fq_node->set_friendly_name(reshape_node->get_friendly_name());
copy_runtime_info({fq_node, reshape_node}, new_fq_node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

ov::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
MATCHER_SCOPE(PullTransposeThroughFQUp);
const auto weights = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
auto m_fq = pattern::wrap_type<ov::op::v0::FakeQuantize>({weights,
const auto convert_p = ov::pass::pattern::optional<ov::op::v0::Convert>(weights, pattern::consumers_count(1));
auto m_fq = pattern::wrap_type<ov::op::v0::FakeQuantize>({convert_p,
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
Expand All @@ -34,25 +36,15 @@ ov::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
auto transpose = pattern_map[m_transpose].get_node_shared_ptr();
auto fq = pattern_map[m_fq].get_node_shared_ptr();

auto are_inputs_scalars =
shape_size(fq->input_value(1).get_shape()) == 1 && shape_size(fq->input_value(2).get_shape()) == 1 &&
shape_size(fq->input_value(3).get_shape()) == 1 && shape_size(fq->input_value(4).get_shape()) == 1;
if (!are_inputs_scalars) {
auto perm =
std::dynamic_pointer_cast<ov::op::v0::Constant>(pattern_map[m_transpose_perm].get_node_shared_ptr());
if (!perm)
return false;
auto perm_val = perm->cast_vector<int64_t>();
if (!(perm_val[0] == 0 && perm_val[1] == 1))
return false;
}

auto input_rank = fq->input(0).get_partial_shape().rank().get_length();

ov::NodeVector new_ops;
ov::OutputVector fq_inputs;
for (size_t i = 0; i < fq->inputs().size(); ++i) {
auto fq_input = fq->input_value(i);
if (i == 0) {
fq_input = pattern_map[weights];
}
auto fq_input_rank = fq_input.get_partial_shape().rank().get_length();
std::vector<int64_t> unsqueeze_axes;
for (int64_t j = 0; j < input_rank - fq_input_rank; ++j) {
Expand All @@ -69,10 +61,17 @@ ov::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
fq_input = constant;
}
ov::copy_runtime_info(transpose, fq_input.get_node_shared_ptr());
if (i == 0 && pattern_map.count(convert_p)) {
const auto& convert_node = pattern_map.at(convert_p).get_node_shared_ptr();
convert_node->input(0).replace_source_output(fq_input);
convert_node->validate_and_infer_types();
fq_input = convert_node;
}
fq_inputs.push_back(fq_input);
}

auto new_fq = fq->clone_with_new_inputs(fq_inputs);
register_new_node(new_fq);
new_ops.push_back(new_fq);
new_fq->set_friendly_name(transpose->get_friendly_name());
ov::copy_runtime_info({fq, transpose}, new_ops);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "openvino/core/model.hpp"
#include "openvino/opsets/opset4.hpp"
#include "openvino/pass/manager.hpp"
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
#include "transformations/common_optimizations/fq_mul_fusion.hpp"
#include "transformations/common_optimizations/pull_transpose_through_fq.hpp"
#include "transformations/init_node_info.hpp"

using namespace ov;
Expand Down Expand Up @@ -66,13 +68,8 @@ class FQReshapeFusionTests : public ov::test::TestsCommon,
}

std::shared_ptr<ov::Model> get_reference_function(const FQReshapeFusionTestCase& test_case) {
const auto& data = std::make_shared<opset4::Constant>(element::f32, test_case.data_shape, 0);
const auto& reshaped_data = std::make_shared<opset4::Reshape>(
data,
std::make_shared<opset4::Constant>(element::i64,
Shape{test_case.reshape_pattern.size()},
test_case.reshape_pattern),
true);
auto shape = PartialShape(test_case.reshape_pattern).to_shape();
const auto& data = std::make_shared<opset4::Constant>(element::f32, shape, 0);

const auto& p_il = std::make_shared<opset4::Parameter>(element::f32, test_case.il_shape);
Output<Node> il = p_il;
Expand Down Expand Up @@ -104,7 +101,7 @@ class FQReshapeFusionTests : public ov::test::TestsCommon,
opset4::Constant::create(element::i64, {test_case.new_oh_shape.size()}, test_case.new_oh_shape),
true);

auto fq = std::make_shared<opset4::FakeQuantize>(reshaped_data, il, ih, ol, oh, 42);
auto fq = std::make_shared<opset4::FakeQuantize>(data, il, ih, ol, oh, 42);

auto result = std::make_shared<op::v0::Result>(fq);
ParameterVector params = {p_il, p_ih, p_ol, p_oh};
Expand Down Expand Up @@ -213,3 +210,77 @@ TEST_F(TransformationTestsF, FQReshapeGroupConvolution) {
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::pass::FakeQuantizeReshapeFusion>();
}

TEST_F(TransformationTestsF, FQOptimizations) {
{
const auto& data = std::make_shared<opset4::Constant>(element::u8, Shape{9, 32}, 0);
const auto& convert = std::make_shared<opset4::Convert>(data, element::f32);

const auto& il = op::v0::Constant::create(element::f32, Shape{1}, {0});
const auto& ih = op::v0::Constant::create(element::f32, Shape{1}, {254});
const auto& ol = op::v0::Constant::create(element::f32, Shape{32}, {-14.22});
const auto& oh = op::v0::Constant::create(element::f32, Shape{32}, {14.22});

const auto& fq = std::make_shared<opset4::FakeQuantize>(convert, il, ih, ol, oh, 255);

const auto& reshape =
std::make_shared<opset4::Reshape>(fq,
op::v0::Constant::create(element::i64, Shape{4}, {3, 3, 32, 1}),
true);

const auto& multiply =
std::make_shared<opset4::Multiply>(reshape,
op::v0::Constant::create(element::f32, Shape{1, 1, 32, 1}, {0.1140}));

const auto& transpose =
std::make_shared<opset4::Transpose>(multiply,
op::v0::Constant::create(element::i64, Shape{4}, {2, 3, 0, 1}));

const auto& reshape_to_weight =
std::make_shared<opset4::Reshape>(transpose,
op::v0::Constant::create(element::i64, Shape{5}, {32, 1, 1, 3, 3}),
true);

const auto& input = std::make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(4));
const auto& group_conv = std::make_shared<opset4::GroupConvolution>(input,
reshape_to_weight,
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});

model = std::make_shared<ov::Model>(OutputVector{group_conv}, ParameterVector{input});

auto fq_fusions = manager.register_pass<ov::pass::GraphRewrite>();
fq_fusions->add_matcher<ov::pass::FakeQuantizeMulFusion>();
fq_fusions->add_matcher<ov::pass::FakeQuantizeReshapeFusion>();
fq_fusions->add_matcher<ov::pass::PullTransposeThroughFQUp>();
fq_fusions->set_name("ov::pass::FakeQuantizeFusions");
}
{
const auto& data = std::make_shared<opset4::Constant>(element::u8, Shape{32, 1, 3, 3}, 0);
const auto& convert = std::make_shared<opset4::Convert>(data, element::f32);

const auto& il = op::v0::Constant::create(element::f32, Shape{1, 1, 1, 1}, {0});
const auto& ih = op::v0::Constant::create(element::f32, Shape{1, 1, 1, 1}, {254});
const auto& ol = op::v0::Constant::create(element::f32, Shape{32, 1, 1, 1}, {-14.22 * 0.1140});
const auto& oh = op::v0::Constant::create(element::f32, Shape{32, 1, 1, 1}, {14.22 * 0.1140});

const auto& fq = std::make_shared<opset4::FakeQuantize>(convert, il, ih, ol, oh, 255);

const auto& reshape_to_weight =
std::make_shared<opset4::Reshape>(fq,
op::v0::Constant::create(element::i64, Shape{5}, {32, 1, 1, 3, 3}),
true);

const auto& input = std::make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(4));
const auto& group_conv = std::make_shared<opset4::GroupConvolution>(input,
reshape_to_weight,
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});

model_ref = std::make_shared<ov::Model>(OutputVector{group_conv}, ParameterVector{input});
}
}
Loading