Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Oct 16, 2023
1 parent eb258cd commit 270839e
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@

ov::pass::FuseU4WeightsAndZeroPoint::FuseU4WeightsAndZeroPoint() {
MATCHER_SCOPE(FuseU4WeightsAndZeroPoint);
auto weights_m = pattern::wrap_type<ov::op::v0::Constant>(pattern::type_matches(ov::element::u4));
auto convert_m = pattern::wrap_type<ov::op::v0::Convert>({weights_m});
auto zero_point_m = pattern::wrap_type<ov::op::v0::Constant>();
auto subtract_m = pattern::wrap_type<ov::op::v1::Subtract>({convert_m, zero_point_m});
auto weights_match = [](ov::Output<ov::Node> output) -> bool {
return pattern::type_matches(ov::element::u4)(output) && pattern::consumers_count(1)(output);
};
auto weights_m = pattern::wrap_type<ov::op::v0::Constant>(weights_match);
auto convert_m = pattern::wrap_type<ov::op::v0::Convert>({weights_m}, pattern::consumers_count(1));
auto zero_point_m = pattern::wrap_type<ov::op::v0::Constant>(pattern::consumers_count(1));
auto subtract_m = pattern::wrap_type<ov::op::v1::Subtract>({convert_m, zero_point_m}, pattern::consumers_count(1));

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
auto& pattern_map = m.get_pattern_value_map();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/fuse_u4_weights_zero_point.hpp"

#include <gtest/gtest.h>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"

using namespace testing;
using namespace ov;

TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPoint) {
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{32, 1, 64};
{
std::vector<int> weights_values(ov::shape_size(weights_shape));
for (size_t i = 0; i < weights_values.size(); ++i) {
weights_values[i] = i % 16;
}
auto weights = ov::op::v0::Constant::create(ov::element::u4, weights_shape, weights_values);
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {8});
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
manager.register_pass<ov::pass::FuseU4WeightsAndZeroPoint>();
}
{
std::vector<int> weights_values(ov::shape_size(weights_shape));
for (size_t i = 0; i < weights_values.size(); ++i) {
weights_values[i] = i % 16 - 8;
}
auto weights = ov::op::v0::Constant::create(ov::element::i4, weights_shape, weights_values);
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(convert, scale);
model_ref = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
}
}

TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointUnaccaptableZeroPoint) {
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{32, 1, 64};

std::vector<int> weights_values(ov::shape_size(weights_shape));
for (size_t i = 0; i < weights_values.size(); ++i) {
weights_values[i] = i % 8;
}
auto weights = ov::op::v0::Constant::create(ov::element::u4, weights_shape, weights_values);
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {4});
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
manager.register_pass<ov::pass::FuseU4WeightsAndZeroPoint>();
}

TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointOutOfI4Range) {
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{32, 1, 64};

std::vector<int> weights_values(ov::shape_size(weights_shape));
for (size_t i = 0; i < weights_values.size(); ++i) {
weights_values[i] = i % 16;
}
auto weights = ov::op::v0::Constant::create(ov::element::u4, weights_shape, weights_values);
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
std::vector<float> zero_point_values(ov::shape_size(decompression_shape));
zero_point_values.back() = 16;
auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, zero_point_values);
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
manager.register_pass<ov::pass::FuseU4WeightsAndZeroPoint>();
}

TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointWrongWeightsPrecision) {
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};

std::vector<int> weights_values(ov::shape_size(weights_shape));
for (size_t i = 0; i < weights_values.size(); ++i) {
weights_values[i] = i % 256;
}
auto weights = ov::op::v0::Constant::create(ov::element::u8, weights_shape, weights_values);
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(decompression_precision, {32, 1, 64}, {8});
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
auto scale = ov::op::v0::Constant::create(decompression_precision, {32, 1, 64}, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
manager.register_pass<ov::pass::FuseU4WeightsAndZeroPoint>();
}

TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointSeveralWeightsConsumers) {
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{32, 1, 64};

auto weights = ov::op::v0::Constant::create(ov::element::u4, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto additional_consumer = std::make_shared<ov::op::v0::Convert>(weights, ov::element::i32);
auto additional_op = std::make_shared<ov::op::v1::Multiply>(additional_consumer, ov::op::v0::Constant::create(ov::element::i32, {}, {-1}));
auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {8});
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply, additional_op}, ParameterVector{});
manager.register_pass<ov::pass::FuseU4WeightsAndZeroPoint>();
}

0 comments on commit 270839e

Please sign in to comment.