Skip to content

Commit

Permalink
[TF Quant Models] Prepares GroupConvolution weights (#25641)
Browse files Browse the repository at this point in the history
### Details:
- *using register_new_node in the matchers to allow for seamless matcher
application for the FQ node*
- *added an optional Convert node so that the pattern works on different
forms of quantized weights*


![image](https://github.com/user-attachments/assets/2d485f5b-7609-40ac-b798-16124a9a290a)

### Tickets:
 - *CVS-39818*
  • Loading branch information
jane-intel authored Nov 12, 2024
1 parent b5a4ad1 commit ff4d1e5
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,29 @@

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/group_conv.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
MATCHER_SCOPE(FakeQuantizeReshapeFusion);
const auto fq_node_p = ov::pass::pattern::wrap_type<ov::op::v0::FakeQuantize>(
{ov::pass::pattern::wrap_type<ov::op::v0::Constant>(), // for weights only
pattern::any_input(),
pattern::any_input(),
pattern::any_input(),
pattern::any_input()},
pattern::consumers_count(1));
// for weights only
const auto data_p = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(pattern::has_static_shape());
const auto convert_p = ov::pass::pattern::optional<ov::op::v0::Convert>(data_p, pattern::consumers_count(1));
const auto fq_node_p =
ov::pass::pattern::wrap_type<ov::op::v0::FakeQuantize>({convert_p,
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape())},
pattern::consumers_count(1));
const auto reshape_node_p = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>(
{fq_node_p, pattern::any_input()},
{fq_node_p, ov::pass::pattern::wrap_type<ov::op::v0::Constant>()},
[](const Output<Node>& output) {
// WA: check that all Reshape node consumers are not GroupConvolution operations
const auto& target_inputs = output.get_target_inputs();
Expand All @@ -36,13 +42,11 @@ ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {

ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
if (fq_node->is_dynamic())
return false;
const auto& fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
const auto& reshape_node = pattern_map.at(reshape_node_p).get_node_shared_ptr();
const auto& original_data_rank = fq_node->get_input_shape(0).size();
OutputVector renewed_inputs = {
reshape_node->clone_with_new_inputs({fq_node->input_value(0), reshape_node->input_value(1)})};

OutputVector renewed_inputs = {};
for (auto i = 1; i < 5; ++i) {
Output<Node> limit_input = fq_node->input_value(i);
auto limit_shape = limit_input.get_shape();
Expand All @@ -62,21 +66,41 @@ ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
});
const auto& new_limit_size = shape_size(new_limit_shape);
if (new_limit_size == limit_size) { // we tracked future channel placement
if (new_limit_shape == limit_input.get_shape())
if (new_limit_shape == limit_input.get_shape()) {
renewed_inputs.push_back(limit_input);
else
renewed_inputs.push_back(reshape_node->clone_with_new_inputs(
} else {
auto reshaped_input = reshape_node->clone_with_new_inputs(
{limit_input,
ov::op::v0::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)}));
ov::op::v0::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)});
if (auto constant = ov::util::get_constant_from_source(reshaped_input)) {
reshaped_input = constant;
}
renewed_inputs.push_back(reshaped_input);
}
continue;
}
}
// resulting FQ will become or already is more than per-tensor / per-channel
return false;
}

auto reshaped_input =
reshape_node->clone_with_new_inputs({pattern_map.at(data_p), reshape_node->input_value(1)});
if (auto constant = ov::util::get_constant_from_source(reshaped_input)) {
reshaped_input = constant;
}
if (pattern_map.count(convert_p)) {
const auto& convert_node = pattern_map.at(convert_p).get_node_shared_ptr();
convert_node->input(0).replace_source_output(reshaped_input);
convert_node->validate_and_infer_types();
reshaped_input = convert_node;
}
renewed_inputs.insert(renewed_inputs.begin(), reshaped_input);

for (auto& new_input : renewed_inputs)
copy_runtime_info({reshape_node, fq_node}, new_input.get_node_shared_ptr());
const auto new_fq_node = fq_node->clone_with_new_inputs(renewed_inputs);
register_new_node(new_fq_node);
replace_node(reshape_node, new_fq_node);
new_fq_node->set_friendly_name(reshape_node->get_friendly_name());
copy_runtime_info({fq_node, reshape_node}, new_fq_node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

ov::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
MATCHER_SCOPE(PullTransposeThroughFQUp);
const auto weights = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
auto m_fq = pattern::wrap_type<ov::op::v0::FakeQuantize>({weights,
const auto convert_p = ov::pass::pattern::optional<ov::op::v0::Convert>(weights, pattern::consumers_count(1));
auto m_fq = pattern::wrap_type<ov::op::v0::FakeQuantize>({convert_p,
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
Expand All @@ -33,25 +35,15 @@ ov::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
auto& pattern_map = m.get_pattern_value_map();
auto transpose = pattern_map[m_transpose].get_node_shared_ptr();
auto fq = pattern_map[m_fq].get_node_shared_ptr();

auto are_inputs_scalars =
shape_size(fq->input_value(1).get_shape()) == 1 && shape_size(fq->input_value(2).get_shape()) == 1 &&
shape_size(fq->input_value(3).get_shape()) == 1 && shape_size(fq->input_value(4).get_shape()) == 1;
if (!are_inputs_scalars) {
auto perm = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map[m_transpose_perm].get_node_shared_ptr());
if (!perm)
return false;
auto perm_val = perm->cast_vector<int64_t>();
if (!(perm_val[0] == 0 && perm_val[1] == 1))
return false;
}

auto input_rank = fq->input(0).get_partial_shape().rank().get_length();

ov::NodeVector new_ops;
ov::OutputVector fq_inputs;
for (size_t i = 0; i < fq->inputs().size(); ++i) {
auto fq_input = fq->input_value(i);
if (i == 0) {
fq_input = pattern_map[weights];
}
auto fq_input_rank = fq_input.get_partial_shape().rank().get_length();
std::vector<int64_t> unsqueeze_axes;
for (int64_t j = 0; j < input_rank - fq_input_rank; ++j) {
Expand All @@ -68,10 +60,17 @@ ov::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
fq_input = constant;
}
ov::copy_runtime_info(transpose, fq_input.get_node_shared_ptr());
if (i == 0 && pattern_map.count(convert_p)) {
const auto& convert_node = pattern_map.at(convert_p).get_node_shared_ptr();
convert_node->input(0).replace_source_output(fq_input);
convert_node->validate_and_infer_types();
fq_input = convert_node;
}
fq_inputs.push_back(fq_input);
}

auto new_fq = fq->clone_with_new_inputs(fq_inputs);
register_new_node(new_fq);
new_ops.push_back(new_fq);
new_fq->set_friendly_name(transpose->get_friendly_name());
ov::copy_runtime_info({fq, transpose}, new_ops);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/core/model.hpp"
#include "openvino/opsets/opset4.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/fq_mul_fusion.hpp"
#include "transformations/common_optimizations/pull_transpose_through_fq.hpp"
#include "transformations/init_node_info.hpp"

using namespace ov;
Expand Down Expand Up @@ -66,13 +69,8 @@ class FQReshapeFusionTests : public ov::test::TestsCommon,
}

std::shared_ptr<ov::Model> get_reference_function(const FQReshapeFusionTestCase& test_case) {
const auto& data = std::make_shared<opset4::Constant>(element::f32, test_case.data_shape, 0);
const auto& reshaped_data = std::make_shared<opset4::Reshape>(
data,
std::make_shared<opset4::Constant>(element::i64,
Shape{test_case.reshape_pattern.size()},
test_case.reshape_pattern),
true);
auto shape = PartialShape(test_case.reshape_pattern).to_shape();
const auto& data = std::make_shared<opset4::Constant>(element::f32, shape, 0);

const auto& p_il = std::make_shared<opset4::Parameter>(element::f32, test_case.il_shape);
Output<Node> il = p_il;
Expand Down Expand Up @@ -104,7 +102,7 @@ class FQReshapeFusionTests : public ov::test::TestsCommon,
opset4::Constant::create(element::i64, {test_case.new_oh_shape.size()}, test_case.new_oh_shape),
true);

auto fq = std::make_shared<opset4::FakeQuantize>(reshaped_data, il, ih, ol, oh, 42);
auto fq = std::make_shared<opset4::FakeQuantize>(data, il, ih, ol, oh, 42);

auto result = std::make_shared<op::v0::Result>(fq);
ParameterVector params = {p_il, p_ih, p_ol, p_oh};
Expand Down Expand Up @@ -213,3 +211,77 @@ TEST_F(TransformationTestsF, FQReshapeGroupConvolution) {
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::pass::FakeQuantizeReshapeFusion>();
}

// Checks that the FQ fusion pipeline (FakeQuantizeMulFusion, FakeQuantizeReshapeFusion,
// PullTransposeThroughFQUp) collapses a quantized-weights chain
//   Constant(u8) -> Convert -> FakeQuantize -> Reshape -> Multiply -> Transpose -> Reshape
// feeding a GroupConvolution into a single FakeQuantize over the already
// reshaped constant weights, with the Multiply scale folded into the FQ output limits.
TEST_F(TransformationTestsF, FQOptimizations) {
    {
        // Quantized weights stored as u8 and dequantized through Convert.
        const auto weights_u8 = std::make_shared<opset4::Constant>(element::u8, Shape{9, 32}, 0);
        const auto weights_f32 = std::make_shared<opset4::Convert>(weights_u8, element::f32);

        // Per-tensor input range, per-channel (32) output range.
        const auto in_low = op::v0::Constant::create(element::f32, Shape{1}, {0});
        const auto in_high = op::v0::Constant::create(element::f32, Shape{1}, {254});
        const auto out_low = op::v0::Constant::create(element::f32, Shape{32}, {-14.22});
        const auto out_high = op::v0::Constant::create(element::f32, Shape{32}, {14.22});

        const auto quantized =
            std::make_shared<opset4::FakeQuantize>(weights_f32, in_low, in_high, out_low, out_high, 255);

        // Spatial reshape followed by a per-channel scale.
        const auto spatial = std::make_shared<opset4::Reshape>(
            quantized,
            op::v0::Constant::create(element::i64, Shape{4}, {3, 3, 32, 1}),
            true);
        const auto scaled = std::make_shared<opset4::Multiply>(
            spatial,
            op::v0::Constant::create(element::f32, Shape{1, 1, 32, 1}, {0.1140}));

        // Bring the weights into GroupConvolution layout: transpose then reshape to 5D.
        const auto permuted = std::make_shared<opset4::Transpose>(
            scaled,
            op::v0::Constant::create(element::i64, Shape{4}, {2, 3, 0, 1}));
        const auto group_weights = std::make_shared<opset4::Reshape>(
            permuted,
            op::v0::Constant::create(element::i64, Shape{5}, {32, 1, 1, 3, 3}),
            true);

        const auto activations = std::make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(4));
        const auto conv = std::make_shared<opset4::GroupConvolution>(activations,
                                                                     group_weights,
                                                                     Strides{1, 1},
                                                                     CoordinateDiff{0, 0},
                                                                     CoordinateDiff{0, 0},
                                                                     Strides{1, 1});

        model = std::make_shared<ov::Model>(OutputVector{conv}, ParameterVector{activations});

        // Register the three matchers as one GraphRewrite so they can iterate to a fixed point.
        auto fq_fusions = manager.register_pass<ov::pass::GraphRewrite>();
        fq_fusions->add_matcher<ov::pass::FakeQuantizeMulFusion>();
        fq_fusions->add_matcher<ov::pass::FakeQuantizeReshapeFusion>();
        fq_fusions->add_matcher<ov::pass::PullTransposeThroughFQUp>();
        fq_fusions->set_name("ov::pass::FakeQuantizeFusions");
    }
    {
        // Reference: weights are already in the final 4D layout; the Multiply scale
        // (0.1140) is folded into the FQ output limits; a single Reshape remains
        // to produce the 5D GroupConvolution weights.
        const auto weights_u8 = std::make_shared<opset4::Constant>(element::u8, Shape{32, 1, 3, 3}, 0);
        const auto weights_f32 = std::make_shared<opset4::Convert>(weights_u8, element::f32);

        const auto in_low = op::v0::Constant::create(element::f32, Shape{1, 1, 1, 1}, {0});
        const auto in_high = op::v0::Constant::create(element::f32, Shape{1, 1, 1, 1}, {254});
        const auto out_low = op::v0::Constant::create(element::f32, Shape{32, 1, 1, 1}, {-14.22 * 0.1140});
        const auto out_high = op::v0::Constant::create(element::f32, Shape{32, 1, 1, 1}, {14.22 * 0.1140});

        const auto quantized =
            std::make_shared<opset4::FakeQuantize>(weights_f32, in_low, in_high, out_low, out_high, 255);

        const auto group_weights = std::make_shared<opset4::Reshape>(
            quantized,
            op::v0::Constant::create(element::i64, Shape{5}, {32, 1, 1, 3, 3}),
            true);

        const auto activations = std::make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(4));
        const auto conv = std::make_shared<opset4::GroupConvolution>(activations,
                                                                     group_weights,
                                                                     Strides{1, 1},
                                                                     CoordinateDiff{0, 0},
                                                                     CoordinateDiff{0, 0},
                                                                     Strides{1, 1});

        model_ref = std::make_shared<ov::Model>(OutputVector{conv}, ParameterVector{activations});
    }
}

0 comments on commit ff4d1e5

Please sign in to comment.