[GPU] Added horizontal fusion for LoRA
Commit ef78518 (1 parent: 90aacf2)
Showing 3 changed files with 199 additions and 1 deletion.
176 changes: 176 additions & 0 deletions
src/plugins/intel_gpu/src/plugin/transformations/lora_horizontal_fusion.cpp
@@ -0,0 +1,176 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "lora_horizontal_fusion.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

#include "intel_gpu/op/fully_connected_compressed.hpp"
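
// Horizontal fusion for LoRA adapters: several parallel LoRA branches
// (MatMul -> Multiply -> MatMul -> Add) consume the outputs of one
// FullyConnectedCompressed node through a VariadicSplit. The pass concatenates
// the per-branch ReadValue states, computes the LoRA correction once for the
// whole tensor, and adds it to the main flow before the split instead of after it.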
namespace ov {
namespace intel_gpu {

LoRAHorizontalFusion::LoRAHorizontalFusion() {
    using namespace ov::pass::pattern;
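
    // The split is a fusion target only if every one of its users is a complete LoRA chain:
    // Add(split_output, MatMul(Multiply(MatMul(lora_input, state_A), state_alpha), state_B)).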
    auto is_target_pattern = [](const std::shared_ptr<Node>& split_node) {
        auto is_lora_pattern = [](const std::shared_ptr<Node>& node) {
#define check(node) if (!node) return false;

            const auto& add = std::dynamic_pointer_cast<ov::op::v1::Add>(node); check(add)
            const auto& matmul2 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(add->get_input_node_shared_ptr(1)); check(matmul2)
            const auto& multiply = std::dynamic_pointer_cast<ov::op::v1::Multiply>(matmul2->get_input_node_shared_ptr(0)); check(multiply)
            const auto& variable_b = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(matmul2->get_input_node_shared_ptr(1)); check(variable_b)
            const auto& matmul1 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(multiply->get_input_node_shared_ptr(0)); check(matmul1)
            const auto& variable_alpha = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(multiply->get_input_node_shared_ptr(1)); check(variable_alpha)
            const auto& variable_a = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(matmul1->get_input_node_shared_ptr(1)); check(variable_a)
#undef check

            return true;
        };

        for (const auto& user : split_node->get_users()) {
            if (!is_lora_pattern(user)) {
                return false;
            }
        }

        return true;
    };
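
    // FullyConnectedCompressed is matched in both its 4-input and 5-input forms
    // (the extra input is presumably the optional weight decompression zero point).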
    auto lora_input = any_input();
    auto main_flow_1 = wrap_type<op::FullyConnectedCompressed>({lora_input, any_input(), any_input(), any_input()});
    auto main_flow_2 = wrap_type<op::FullyConnectedCompressed>({lora_input, any_input(), any_input(), any_input(), any_input()});
    auto main_flow = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{main_flow_1, main_flow_2});

    auto axis_const = wrap_type<ov::op::v0::Constant>();
    auto split_const = wrap_type<ov::op::v0::Constant>();
    auto split = wrap_type<ov::op::v1::VariadicSplit>({main_flow, axis_const, split_const}, ov::pass::pattern::op::as_value_predicate(is_target_pattern));

    ov::matcher_pass_callback callback = [=](Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        const auto& split = m.get_match_root();
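
        // Collect the corresponding node of every LoRA branch hanging off the split,
        // so that the branches can be fused position by position.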
        ov::NodeVector add_nodes;
        ov::NodeVector multiply_nodes;
        ov::NodeVector variable_a_nodes;
        ov::NodeVector variable_b_nodes;
        ov::NodeVector variable_alpha_nodes;
        ov::NodeVector matmul1_nodes;
        ov::NodeVector matmul2_nodes;

        for (const auto& add : split->get_users()) {
            add_nodes.emplace_back(add);
            matmul2_nodes.emplace_back(add->get_input_node_shared_ptr(1));
        }
        for (const auto& matmul2 : matmul2_nodes) {
            multiply_nodes.emplace_back(matmul2->get_input_node_shared_ptr(0));
            variable_b_nodes.emplace_back(matmul2->get_input_node_shared_ptr(1));
        }
        for (const auto& multiply : multiply_nodes) {
            matmul1_nodes.emplace_back(multiply->get_input_node_shared_ptr(0));
            variable_alpha_nodes.emplace_back(multiply->get_input_node_shared_ptr(1));
        }
        for (const auto& matmul1 : matmul1_nodes) {
            variable_a_nodes.emplace_back(matmul1->get_input_node_shared_ptr(1));
        }
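
        // Concatenate the LoRA 'A' states along axis 0 and the alpha states along
        // axis 1, so that a single MatMul and a single Multiply serve all branches.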
        auto fused_variable_a = std::make_shared<ov::op::v0::Concat>(variable_a_nodes, 0);
        fused_variable_a->set_friendly_name(variable_a_nodes[0]->get_friendly_name() +
                                            "_fused_" + std::to_string(variable_a_nodes.size()) + "_ReadValues");
        ov::copy_runtime_info(variable_a_nodes, fused_variable_a);

        auto fused_variable_alpha = std::make_shared<ov::op::v0::Concat>(variable_alpha_nodes, 1);
        fused_variable_alpha->set_friendly_name(variable_alpha_nodes[0]->get_friendly_name() +
                                                "_fused_" + std::to_string(variable_alpha_nodes.size()) + "_ReadValues");
        ov::copy_runtime_info(variable_alpha_nodes, fused_variable_alpha);

        bool transpose_a1 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(matmul1_nodes[0])->get_transpose_a();
        bool transpose_b1 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(matmul1_nodes[0])->get_transpose_b();
        auto fused_matmul1 = std::make_shared<ov::op::v0::MatMul>(pattern_map.at(lora_input), fused_variable_a, transpose_a1, transpose_b1);
        auto fused_matmul1_name = matmul1_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(matmul1_nodes.size()) + "_MatMuls";
        fused_matmul1->set_friendly_name(fused_matmul1_name);
        ov::copy_runtime_info(matmul1_nodes, fused_matmul1);
        for (const auto& old_matmul1 : matmul1_nodes) {
            old_matmul1->clear_control_dependencies();
        }

        auto fused_multiply = std::make_shared<ov::op::v1::Multiply>(fused_matmul1, fused_variable_alpha);
        auto multiply_name = multiply_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(multiply_nodes.size()) + "_Multiply";
        fused_multiply->set_friendly_name(multiply_name);
        ov::copy_runtime_info(multiply_nodes, fused_multiply);
        for (const auto& old_multiply : multiply_nodes) {
            old_multiply->clear_control_dependencies();
        }
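
        // The 'B' states are concatenated along axis 1, which is only valid when every
        // branch has the same static size in dimension 0; otherwise fall back to
        // splitting the fused Multiply output and keeping the second MatMuls separate.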
        bool fuse_second_matmul = true;
        size_t not_concatenable_idx = 0;
        const auto& base_dim = variable_b_nodes[0]->get_output_partial_shape(0)[not_concatenable_idx];
        for (size_t i = 1; i < variable_b_nodes.size(); ++i) {
            const auto& dim = variable_b_nodes[i]->get_output_partial_shape(0)[not_concatenable_idx];
            if (dim.is_dynamic() || base_dim.is_dynamic() || dim.get_length() != base_dim.get_length()) {
                fuse_second_matmul = false;
            }
        }

        std::shared_ptr<ov::Node> fused_matmul2 = nullptr;
        if (fuse_second_matmul) {
            auto fused_variable_b = std::make_shared<ov::op::v0::Concat>(variable_b_nodes, 1);
            fused_variable_b->set_friendly_name(variable_b_nodes[0]->get_friendly_name() +
                                                "_fused_" + std::to_string(variable_b_nodes.size()) + "_ReadValues");
            ov::copy_runtime_info(variable_b_nodes, fused_variable_b);

            bool transpose_a2 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(matmul2_nodes[0])->get_transpose_a();
            bool transpose_b2 = std::dynamic_pointer_cast<ov::op::v0::MatMul>(matmul2_nodes[0])->get_transpose_b();
            fused_matmul2 = std::make_shared<ov::op::v0::MatMul>(fused_multiply, fused_variable_b, transpose_a2, transpose_b2);
            auto matmul2_name = matmul2_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(matmul2_nodes.size()) + "_MatMuls";
            fused_matmul2->set_friendly_name(matmul2_name);
            ov::copy_runtime_info(matmul2_nodes, fused_matmul2);
            for (const auto& old_matmul2 : matmul2_nodes) {
                old_matmul2->clear_control_dependencies();
            }
        } else {
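            // The second MatMuls stay separate: slice the fused Multiply output back
            // into equal per-branch parts, feed each original MatMul its slice, and
            // concatenate the MatMul outputs along the last axis.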
            auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {fused_multiply->get_output_partial_shape(0).size() - 1});
            auto output_split = std::make_shared<ov::op::v1::Split>(fused_multiply, axis_const, matmul2_nodes.size());
            auto split_name = fused_multiply->get_friendly_name() + "_split";
            copy_runtime_info(fused_multiply, output_split);
            output_split->set_friendly_name(split_name);
            for (size_t i = 0; i < matmul2_nodes.size(); ++i) {
                matmul2_nodes[i]->input(0).replace_source_output(output_split->output(i));
            }

            fused_matmul2 = std::make_shared<ov::op::v0::Concat>(matmul2_nodes, matmul2_nodes[0]->get_output_partial_shape(0).size() - 1);
            auto matmul2_name = matmul2_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(matmul2_nodes.size()) + "_MatMuls_output";
            fused_matmul2->set_friendly_name(matmul2_name);
        }
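
        // The fused LoRA correction is added to the FC output once, in front of the
        // split; users of the old per-branch Adds then read the split outputs directly.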
        auto fused_add = std::make_shared<ov::op::v1::Add>(split->get_input_node_shared_ptr(0), fused_matmul2);
        auto fused_add_name = add_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(add_nodes.size()) + "_Adds";
        fused_add->set_friendly_name(fused_add_name);
        ov::copy_runtime_info(add_nodes, fused_add);

        for (size_t i = 0; i < add_nodes.size(); ++i) {
            const auto& old_add = add_nodes[i];
            for (auto u : old_add->get_users()) {
                for (size_t idx = 0; idx < u->inputs().size(); ++idx) {
                    if (u->get_input_node_shared_ptr(idx) == old_add) {
                        u->input(idx).replace_source_output(split->output(i));
                    }
                }
            }
            old_add->clear_control_dependencies();
        }

        split->input(0).replace_source_output(fused_add->output(0));
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(split, "LoRAHorizontalFusion");
    this->register_matcher(m, callback);
}

} // namespace intel_gpu
} // namespace ov

19 changes: 19 additions & 0 deletions
src/plugins/intel_gpu/src/plugin/transformations/lora_horizontal_fusion.hpp
@@ -0,0 +1,19 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"

namespace ov {
namespace intel_gpu {
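
// Horizontally fuses parallel LoRA adapter branches that share one input; the
// matched pattern and the graph rewrite live in lora_horizontal_fusion.cpp.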
class LoRAHorizontalFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("LoRAHorizontalFusion", "0");
    LoRAHorizontalFusion();
};

} // namespace intel_gpu
} // namespace ov
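
The commit lists 3 changed files but only these two are shown, so the place where the pass is registered is not visible here. As a minimal sketch, assuming the standard ov::pass::Manager API, running the new pass on a model would look like this (the helper name is hypothetical, not from the commit):

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"

#include "lora_horizontal_fusion.hpp"

// Hypothetical helper, not part of this commit: applies the fusion to a model.
void apply_lora_horizontal_fusion(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::intel_gpu::LoRAHorizontalFusion>();
    manager.run_passes(model);
}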