Skip to content

Commit

Permalink
[GPU] Added test and removed fusion of second matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Dec 4, 2024
1 parent e42850f commit 6dc506b
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,46 +106,19 @@ LoRAHorizontalFusion::LoRAHorizontalFusion() {
old_multiply->clear_control_dependencies();
}

bool fuse_second_matmul = true;
size_t not_concatenable_idx = 0;
const auto& base_dim = variable_b_nodes[0]->get_output_partial_shape(0)[not_concatenable_idx];
for (size_t i = 1; i < variable_b_nodes.size(); ++i) {
const auto& dim = variable_b_nodes[i]->get_output_partial_shape(0)[not_concatenable_idx];
if (dim.is_dynamic() || dim.get_length() != base_dim.get_length()) {
fuse_second_matmul = false;
}
auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {fused_multiply->get_output_partial_shape(0).size() - 1});
auto output_split = std::make_shared<ov::op::v1::Split>(fused_multiply, axis_const, matmul2_nodes.size());
auto split_name = fused_multiply->get_friendly_name() + "_split";
copy_runtime_info(fused_multiply, output_split);
output_split->set_friendly_name(split_name);
for (size_t i = 0; i < matmul2_nodes.size(); ++i) {
matmul2_nodes[i]->input(0).replace_source_output(output_split->output(i));
}

std::shared_ptr<ov::Node> fused_matmul2 = nullptr;
if (fuse_second_matmul) {
auto fused_variable_b = std::make_shared<ov::op::v0::Concat>(variable_b_nodes, 1);
fused_variable_b->set_friendly_name(variable_b_nodes[0]->get_friendly_name() +
"_fused" + std::to_string(variable_b_nodes.size()) + "_ReadValues");
ov::copy_runtime_info(variable_b_nodes, fused_variable_b);

bool transpose_a2 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(matmul2_nodes[0])->get_transpose_a();
bool transpose_b2 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(matmul2_nodes[0])->get_transpose_b();
fused_matmul2 = std::make_shared<ov::op::v0::MatMul>(fused_multiply, fused_variable_b, transpose_a2, transpose_b2);
auto matmul2_name = matmul2_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(matmul2_nodes.size()) + "_MatMuls";
fused_matmul2->set_friendly_name(matmul2_name);
ov::copy_runtime_info(matmul2_nodes, fused_matmul2);
for (const auto& old_matmul2 : matmul2_nodes) {
old_matmul2->clear_control_dependencies();
}
} else {
auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {fused_multiply->get_output_partial_shape(0).size() - 1});
auto output_split = std::make_shared<ov::op::v1::Split>(fused_multiply, axis_const, matmul2_nodes.size());
auto split_name = fused_multiply->get_friendly_name() + "_split";
copy_runtime_info(fused_multiply, output_split);
output_split->set_friendly_name(split_name);
for (size_t i = 0; i < matmul2_nodes.size(); ++i) {
matmul2_nodes[i]->input(0).replace_source_output(output_split->output(i));
}

fused_matmul2 = std::make_shared<ov::op::v0::Concat>(matmul2_nodes, matmul2_nodes[0]->get_output_partial_shape(0).size() - 1);
auto matmul2_name = matmul2_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(matmul2_nodes.size()) + "_MatMuls_output";
fused_matmul2->set_friendly_name(matmul2_name);
}
auto fused_matmul2 = std::make_shared<ov::op::v0::Concat>(matmul2_nodes, matmul2_nodes[0]->get_output_partial_shape(0).size() - 1);
auto matmul2_name = matmul2_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(matmul2_nodes.size()) + "_MatMuls_output";
fused_matmul2->set_friendly_name(matmul2_name);
ov::copy_runtime_info(matmul2_nodes, fused_matmul2);

auto fused_add = std::make_shared<ov::op::v1::Add>(split->get_input_node_shared_ptr(0), fused_matmul2);
auto fused_add_name = add_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(add_nodes.size()) + "_Adds";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/opsets/opset6.hpp"

#include "plugin/transformations/lora_horizontal_fusion.hpp"
#include "intel_gpu/op/placeholder.hpp"
#include "intel_gpu/op/fully_connected_compressed.hpp"

using namespace testing;
using namespace ov::intel_gpu;

namespace ov {
namespace test {
namespace intel_gpu {

// Checks the GPU LoRAHorizontalFusion pass on a fused-QKV-style subgraph:
// three parallel LoRA branches (MatMul1 -> Multiply(alpha) -> MatMul2 -> Add)
// hang off one FullyConnectedCompressed whose output is VariadicSplit into
// {2048, 256, 256}. The first MatMuls and the alpha Multiplies are fused
// horizontally; the second MatMuls are NOT fused (their ReadValue "B" weights
// have differing leading dims: 2048 vs 256 vs 256) — their outputs are only
// concatenated. This matches the commit title "removed fusion of second matmul".
TEST_F(TransformationTestsF, LoRAHorizontalFusion) {
ov::element::Type model_dt = ov::element::f16;
// --- Original (pre-transformation) model -------------------------------
{
auto lora_input = std::make_shared<ov::op::v0::Parameter>(model_dt, ov::PartialShape{-1, -1, 2048});
auto weights = std::make_shared<ov::op::v0::Constant>(ov::element::u8, ov::Shape{2560, 2048});
auto bias = std::make_shared<ov::intel_gpu::op::Placeholder>();
auto scale = std::make_shared<ov::op::v0::Constant>(model_dt, ov::Shape{2560, 1});
auto fc_fused = std::make_shared<ov::intel_gpu::op::FullyConnectedCompressed>(lora_input, weights, bias, scale);

// Base path: the fused FC output (2560 on the last axis) is split into
// the per-branch base activations {2048, 256, 256}.
auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});
auto split_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {2048, 256, 256});
auto split = std::make_shared<ov::op::v1::VariadicSplit>(fc_fused, axis_const, split_const);

// LoRA branch 0: A/alpha/B read from state; note B has first dim 2048
// (differs from branches 1 and 2, which is what blocks MatMul2 fusion).
auto variable_a_0 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({-1, 2048}), model_dt, "var_a_0"});
auto variable_alpha_0 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({1, -1}), model_dt, "var_alpha_0"});
auto variable_b_0 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({2048, -1}), model_dt, "var_b_0"});
auto read_value_a_0 = std::make_shared<ov::op::v6::ReadValue>(variable_a_0);
auto read_value_alpha_0 = std::make_shared<ov::op::v6::ReadValue>(variable_alpha_0);
auto read_value_b_0 = std::make_shared<ov::op::v6::ReadValue>(variable_b_0);
auto matmul1_0 = std::make_shared<ov::op::v0::MatMul>(lora_input, read_value_a_0, false, true);
auto multiply_0 = std::make_shared<ov::op::v1::Multiply>(matmul1_0, read_value_alpha_0);
auto matmul2_0 = std::make_shared<ov::op::v0::MatMul>(multiply_0, read_value_b_0, false, true);
auto add_0 = std::make_shared<ov::op::v1::Add>(split->output(0), matmul2_0);

// LoRA branch 1 (B first dim = 256).
auto variable_a_1 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({-1, 2048}), model_dt, "var_a_1"});
auto variable_alpha_1 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({1, -1}), model_dt, "var_alpha_1"});
auto variable_b_1 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({256, -1}), model_dt, "var_b_1"});
auto read_value_a_1 = std::make_shared<ov::op::v6::ReadValue>(variable_a_1);
auto read_value_alpha_1 = std::make_shared<ov::op::v6::ReadValue>(variable_alpha_1);
auto read_value_b_1 = std::make_shared<ov::op::v6::ReadValue>(variable_b_1);
auto matmul1_1 = std::make_shared<ov::op::v0::MatMul>(lora_input, read_value_a_1, false, true);
auto multiply_1 = std::make_shared<ov::op::v1::Multiply>(matmul1_1, read_value_alpha_1);
auto matmul2_1 = std::make_shared<ov::op::v0::MatMul>(multiply_1, read_value_b_1, false, true);
auto add_1 = std::make_shared<ov::op::v1::Add>(split->output(1), matmul2_1);

// LoRA branch 2 (B first dim = 256).
auto variable_a_2 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({-1, 2048}), model_dt, "var_a_2"});
auto variable_alpha_2 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({1, -1}), model_dt, "var_alpha_2"});
auto variable_b_2 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({256, -1}), model_dt, "var_b_2"});
auto read_value_a_2 = std::make_shared<ov::op::v6::ReadValue>(variable_a_2);
auto read_value_alpha_2 = std::make_shared<ov::op::v6::ReadValue>(variable_alpha_2);
auto read_value_b_2 = std::make_shared<ov::op::v6::ReadValue>(variable_b_2);
auto matmul1_2 = std::make_shared<ov::op::v0::MatMul>(lora_input, read_value_a_2, false, true);
auto multiply_2 = std::make_shared<ov::op::v1::Multiply>(matmul1_2, read_value_alpha_2);
auto matmul2_2 = std::make_shared<ov::op::v0::MatMul>(multiply_2, read_value_b_2, false, true);
auto add_2 = std::make_shared<ov::op::v1::Add>(split->output(2), matmul2_2);

// Per-branch head reshapes (e.g. 2048 -> 32x64, 256 -> 4x64); `0` keeps
// the corresponding input dim (special_zero = true).
auto reshape_pattern0 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 0, 32, 64});
auto reshape_pattern1 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 0, 4, 64});
auto reshape_pattern2 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 0, 4, 64});
auto reshape0 = std::make_shared<ov::op::v1::Reshape>(add_0, reshape_pattern0, true);
auto reshape1 = std::make_shared<ov::op::v1::Reshape>(add_1, reshape_pattern1, true);
auto reshape2 = std::make_shared<ov::op::v1::Reshape>(add_2, reshape_pattern2, true);

auto result0 = std::make_shared<ov::op::v0::Result>(reshape0);
auto result1 = std::make_shared<ov::op::v0::Result>(reshape1);
auto result2 = std::make_shared<ov::op::v0::Result>(reshape2);

model = std::make_shared<ov::Model>(ov::NodeVector{result0, result1, result2}, ov::ParameterVector{lora_input});
manager.register_pass<LoRAHorizontalFusion>();
}

// --- Reference (expected post-transformation) model --------------------
{
auto lora_input = std::make_shared<ov::op::v0::Parameter>(model_dt, ov::PartialShape{-1, -1, 2048});
auto weights = std::make_shared<ov::op::v0::Constant>(ov::element::u8, ov::Shape{2560, 2048});
auto bias = std::make_shared<ov::intel_gpu::op::Placeholder>();
auto scale = std::make_shared<ov::op::v0::Constant>(model_dt, ov::Shape{2560, 1});
auto fc_fused = std::make_shared<ov::intel_gpu::op::FullyConnectedCompressed>(lora_input, weights, bias, scale);

// All three "A" states are concatenated along axis 0 (output rows), so a
// single MatMul1 serves every branch.
auto variable_a_0 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({-1, 2048}), model_dt, "var_a_0"});
auto variable_a_1 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({-1, 2048}), model_dt, "var_a_1"});
auto variable_a_2 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({-1, 2048}), model_dt, "var_a_2"});

auto read_value_a_0 = std::make_shared<ov::op::v6::ReadValue>(variable_a_0);
auto read_value_a_1 = std::make_shared<ov::op::v6::ReadValue>(variable_a_1);
auto read_value_a_2 = std::make_shared<ov::op::v6::ReadValue>(variable_a_2);
auto concat_variable_a = std::make_shared<ov::op::v0::Concat>(NodeVector{read_value_a_0, read_value_a_1, read_value_a_2}, 0);

auto fused_matmul1 = std::make_shared<ov::op::v0::MatMul>(lora_input, concat_variable_a, false, true);

// Alpha scales are concatenated along axis 1 so one broadcasted Multiply
// covers all branches.
auto variable_alpha_0 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({1, -1}), model_dt, "var_alpha_0"});
auto variable_alpha_1 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({1, -1}), model_dt, "var_alpha_1"});
auto variable_alpha_2 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({1, -1}), model_dt, "var_alpha_2"});

auto read_value_alpha_0 = std::make_shared<ov::op::v6::ReadValue>(variable_alpha_0);
auto read_value_alpha_1 = std::make_shared<ov::op::v6::ReadValue>(variable_alpha_1);
auto read_value_alpha_2 = std::make_shared<ov::op::v6::ReadValue>(variable_alpha_2);
auto concat_variable_alpha = std::make_shared<ov::op::v0::Concat>(NodeVector{read_value_alpha_0, read_value_alpha_1, read_value_alpha_2}, 1);

auto multiply = std::make_shared<ov::op::v1::Multiply>(fused_matmul1, concat_variable_alpha);

// The fused Multiply output is split back into three equal chunks on the
// last axis — one per branch — because MatMul2 stays per-branch.
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2});
auto split = std::make_shared<ov::op::v1::Split>(multiply, split_axis, 3);

// "B" states keep their mismatched leading dims (2048 vs 256 vs 256),
// which is exactly why MatMul2 is not fused.
auto variable_b_0 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({2048, -1}), model_dt, "var_b_0"});
auto variable_b_1 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({256, -1}), model_dt, "var_b_1"});
auto variable_b_2 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape({256, -1}), model_dt, "var_b_2"});

auto read_value_b_0 = std::make_shared<ov::op::v6::ReadValue>(variable_b_0);
auto read_value_b_1 = std::make_shared<ov::op::v6::ReadValue>(variable_b_1);
auto read_value_b_2 = std::make_shared<ov::op::v6::ReadValue>(variable_b_2);

// Second MatMuls remain separate; only their outputs are concatenated
// (last axis) so a single Add against the un-split FC output works.
auto matmul2_0 = std::make_shared<ov::op::v0::MatMul>(split->output(0), read_value_b_0, false, true);
auto matmul2_1 = std::make_shared<ov::op::v0::MatMul>(split->output(1), read_value_b_1, false, true);
auto matmul2_2 = std::make_shared<ov::op::v0::MatMul>(split->output(2), read_value_b_2, false, true);

auto concat_matmul2 = std::make_shared<ov::op::v0::Concat>(NodeVector{matmul2_0, matmul2_1, matmul2_2}, 2);

auto add = std::make_shared<ov::op::v1::Add>(fc_fused, concat_matmul2);

// The VariadicSplit now sits AFTER the fused Add (it was before the adds
// in the original model).
auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});
auto split_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {2048, 256, 256});
auto var_split = std::make_shared<ov::op::v1::VariadicSplit>(add, axis_const, split_const);

auto reshape_pattern0 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 0, 32, 64});
auto reshape_pattern1 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 0, 4, 64});
auto reshape_pattern2 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 0, 4, 64});
auto reshape0 = std::make_shared<ov::op::v1::Reshape>(var_split->output(0), reshape_pattern0, true);
auto reshape1 = std::make_shared<ov::op::v1::Reshape>(var_split->output(1), reshape_pattern1, true);
auto reshape2 = std::make_shared<ov::op::v1::Reshape>(var_split->output(2), reshape_pattern2, true);

auto result0 = std::make_shared<ov::op::v0::Result>(reshape0);
auto result1 = std::make_shared<ov::op::v0::Result>(reshape1);
auto result2 = std::make_shared<ov::op::v0::Result>(reshape2);

model_ref = std::make_shared<ov::Model>(ov::NodeVector{result0, result1, result2}, ov::ParameterVector{lora_input});
// Compare op attributes and run both models on random inputs to compare
// numerical results, not just graph topology.
comparator.enable(FunctionsComparator::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
}

} // namespace intel_gpu
} // namespace test
} // namespace ov

0 comments on commit 6dc506b

Please sign in to comment.