Skip to content

Commit

Permalink
[GPU] Added horizontal fusion for LoRA
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Dec 4, 2024
1 parent c1b344b commit e42850f
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "lora_horizontal_fusion.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

#include "intel_gpu/op/fully_connected_compressed.hpp"

namespace ov {
namespace intel_gpu {

LoRAHorizontalFusion::LoRAHorizontalFusion() {
using namespace ov::pass::pattern;

auto is_target_pattern = [](const std::shared_ptr<Node>& split_node) {
auto is_lora_pattern = [](const std::shared_ptr<Node>& node) {
#define check(node) if (!node) return false;

const auto& add = std::dynamic_pointer_cast<ov::op::v1::Add>(node); check(add)
const auto& matmul2 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(add->get_input_node_shared_ptr(1)); check(matmul2)
const auto& multiply = std::dynamic_pointer_cast<ov::op::v1::Multiply>(matmul2->get_input_node_shared_ptr(0)); check(multiply)
const auto& variable_b = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(matmul2->get_input_node_shared_ptr(1)); check(variable_b)
const auto& matmul1 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(multiply->get_input_node_shared_ptr(0)); check(matmul1)
const auto& variable_alpha = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(multiply->get_input_node_shared_ptr(1)); check(variable_alpha)
const auto& variable_a = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(matmul1->get_input_node_shared_ptr(1)); check(variable_a)

return true;
};

for (const auto& user : split_node->get_users()) {
if (!is_lora_pattern(user)) {
return false;
}
}

return true;
};

auto lora_input = any_input();
auto main_flow_1 = wrap_type<op::FullyConnectedCompressed>({lora_input, any_input(), any_input(), any_input()});
auto main_flow_2 = wrap_type<op::FullyConnectedCompressed>({lora_input, any_input(), any_input(), any_input(), any_input()});
auto main_flow = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{main_flow_1, main_flow_2});

auto axis_const = wrap_type<ov::op::v0::Constant>();
auto split_const = wrap_type<ov::op::v0::Constant>();
auto split = wrap_type<ov::op::v1::VariadicSplit>({main_flow, axis_const, split_const}, ov::pass::pattern::op::as_value_predicate(is_target_pattern));

ov::matcher_pass_callback callback = [=](Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto& split = m.get_match_root();

ov::NodeVector add_nodes;
ov::NodeVector multiply_nodes;
ov::NodeVector variable_a_nodes;
ov::NodeVector variable_b_nodes;
ov::NodeVector variable_alpha_nodes;
ov::NodeVector matmul1_nodes;
ov::NodeVector matmul2_nodes;

for (const auto& add : split->get_users()) {
add_nodes.emplace_back(add);
matmul2_nodes.emplace_back(add->get_input_node_shared_ptr(1));
}
for (const auto& matmul2 : matmul2_nodes) {
multiply_nodes.emplace_back(matmul2->get_input_node_shared_ptr(0));
variable_b_nodes.emplace_back(matmul2->get_input_node_shared_ptr(1));
}
for (const auto& multiply : multiply_nodes) {
matmul1_nodes.emplace_back(multiply->get_input_node_shared_ptr(0));
variable_alpha_nodes.emplace_back(multiply->get_input_node_shared_ptr(1));
}
for (const auto& matmul1 : matmul1_nodes) {
variable_a_nodes.emplace_back(matmul1->get_input_node_shared_ptr(1));
}

auto fused_variable_a = std::make_shared<ov::op::v0::Concat>(variable_a_nodes, 0);
fused_variable_a->set_friendly_name(variable_a_nodes[0]->get_friendly_name() +
"_fused_" + std::to_string(variable_a_nodes.size()) + "_ReadValues");
ov::copy_runtime_info(variable_a_nodes, fused_variable_a);

auto fused_variable_alpha = std::make_shared<ov::op::v0::Concat>(variable_alpha_nodes, 1);
fused_variable_alpha->set_friendly_name(variable_alpha_nodes[0]->get_friendly_name() +
"_fused_" + std::to_string(variable_alpha_nodes.size()) + "_ReadValues");
ov::copy_runtime_info(variable_alpha_nodes, fused_variable_alpha);

bool transpose_a1 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(matmul1_nodes[0])->get_transpose_a();
bool transpose_b1 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(matmul1_nodes[0])->get_transpose_b();
auto fused_matmul1 = std::make_shared<ov::op::v0::MatMul>(pattern_map.at(lora_input), fused_variable_a, transpose_a1, transpose_b1);
auto fused_matmul1_name = matmul1_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(matmul1_nodes.size()) + "_MatMuls";
fused_matmul1->set_friendly_name(fused_matmul1_name);
ov::copy_runtime_info(matmul1_nodes, fused_matmul1);
for (const auto& old_matmul1 : matmul1_nodes) {
old_matmul1->clear_control_dependencies();
}

auto fused_multiply = std::make_shared<ov::op::v1::Multiply>(fused_matmul1, fused_variable_alpha);
auto multiply_name = multiply_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(multiply_nodes.size()) + "_Multiply";
fused_multiply->set_friendly_name(multiply_name);
ov::copy_runtime_info(multiply_nodes, fused_multiply);
for (const auto& old_multiply : multiply_nodes) {
old_multiply->clear_control_dependencies();
}

bool fuse_second_matmul = true;
size_t not_concatenable_idx = 0;
const auto& base_dim = variable_b_nodes[0]->get_output_partial_shape(0)[not_concatenable_idx];
for (size_t i = 1; i < variable_b_nodes.size(); ++i) {
const auto& dim = variable_b_nodes[i]->get_output_partial_shape(0)[not_concatenable_idx];
if (dim.is_dynamic() || dim.get_length() != base_dim.get_length()) {
fuse_second_matmul = false;
}
}

std::shared_ptr<ov::Node> fused_matmul2 = nullptr;
if (fuse_second_matmul) {
auto fused_variable_b = std::make_shared<ov::op::v0::Concat>(variable_b_nodes, 1);
fused_variable_b->set_friendly_name(variable_b_nodes[0]->get_friendly_name() +
"_fused" + std::to_string(variable_b_nodes.size()) + "_ReadValues");
ov::copy_runtime_info(variable_b_nodes, fused_variable_b);

bool transpose_a2 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(matmul2_nodes[0])->get_transpose_a();
bool transpose_b2 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(matmul2_nodes[0])->get_transpose_b();
fused_matmul2 = std::make_shared<ov::op::v0::MatMul>(fused_multiply, fused_variable_b, transpose_a2, transpose_b2);
auto matmul2_name = matmul2_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(matmul2_nodes.size()) + "_MatMuls";
fused_matmul2->set_friendly_name(matmul2_name);
ov::copy_runtime_info(matmul2_nodes, fused_matmul2);
for (const auto& old_matmul2 : matmul2_nodes) {
old_matmul2->clear_control_dependencies();
}
} else {
auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {fused_multiply->get_output_partial_shape(0).size() - 1});
auto output_split = std::make_shared<ov::op::v1::Split>(fused_multiply, axis_const, matmul2_nodes.size());
auto split_name = fused_multiply->get_friendly_name() + "_split";
copy_runtime_info(fused_multiply, output_split);
output_split->set_friendly_name(split_name);
for (size_t i = 0; i < matmul2_nodes.size(); ++i) {
matmul2_nodes[i]->input(0).replace_source_output(output_split->output(i));
}

fused_matmul2 = std::make_shared<ov::op::v0::Concat>(matmul2_nodes, matmul2_nodes[0]->get_output_partial_shape(0).size() - 1);
auto matmul2_name = matmul2_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(matmul2_nodes.size()) + "_MatMuls_output";
fused_matmul2->set_friendly_name(matmul2_name);
}

auto fused_add = std::make_shared<ov::op::v1::Add>(split->get_input_node_shared_ptr(0), fused_matmul2);
auto fused_add_name = add_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(add_nodes.size()) + "_Adds";
fused_add->set_friendly_name(fused_add_name);
ov::copy_runtime_info(add_nodes, fused_add);

for (size_t i = 0; i < add_nodes.size(); ++i) {
const auto& old_add = add_nodes[i];
for (auto u : old_add->get_users()) {
for (size_t idx = 0; idx < u->inputs().size(); ++idx) {
if (u->get_input_node_shared_ptr(idx) == old_add) {
u->input(idx).replace_source_output(split->output(i));
}
}
}
old_add->clear_control_dependencies();
}

split->input(0).replace_source_output(fused_add->output(0));
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(split, "LoRAHorizontalFusion");
this->register_matcher(m, callback);
}

} // namespace intel_gpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"

namespace ov {
namespace intel_gpu {

class LoRAHorizontalFusion: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("LoRAHorizontalFusion", "0");
LoRAHorizontalFusion();
};

} // namespace intel_gpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
#include "plugin/transformations/group_norm_composition.hpp"
#include "plugin/transformations/dynamic_quantize_fully_connected.hpp"
#include "plugin/transformations/optimize_subsequent_reshapes.hpp"
#include "plugin/transformations/lora_horizontal_fusion.hpp"
#include "transformations/common_optimizations/nop_elimination.hpp"
#include "transformations/common_optimizations/rms_fusion.hpp"
#include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"
Expand Down Expand Up @@ -917,8 +918,10 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
GPU_DEBUG_IF(debug_config->disable_horizontal_fc_fusion == 1)
disable_horizontal_fc_fusion = true;

if (!disable_horizontal_fc_fusion)
if (!disable_horizontal_fc_fusion) {
manager.register_pass<ov::intel_gpu::FullyConnectedHorizontalFusion>();
manager.register_pass<ov::intel_gpu::LoRAHorizontalFusion>();
}

// ZP should not be folded for FC. But still, ZP should be folded for Gather.
// Therefore, run MarkDequantizationSubgraph again to fold ZP constant.
Expand Down

0 comments on commit e42850f

Please sign in to comment.