forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
218 additions
and
0 deletions.
There are no files selected for viewing
28 changes: 28 additions & 0 deletions
28
src/common/transformations/include/ov_ops/lora_subgraph.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/op.hpp" | ||
#include "openvino/op/util/sub_graph_base.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
class TRANSFORMATIONS_API LoraSubgraph : public ov::op::util::SubGraphOp { | ||
public: | ||
OPENVINO_OP("LoraSubgraph", "ie_internal_opset"); | ||
|
||
LoraSubgraph() = default; | ||
LoraSubgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body); | ||
|
||
void validate_and_infer_types() override; | ||
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; | ||
}; | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
29 changes: 29 additions & 0 deletions
29
...mon/transformations/include/transformations/common_optimizations/lora_subgraph_fusion.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
#include "openvino/pass/matcher_pass.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API LoraSubgraphFusion; | ||
|
||
} // namespace pass | ||
} // namespace ov | ||
|
||
/** | ||
* @ingroup ov_transformation_common_api | ||
* @brief | ||
*/ | ||
class ov::pass::LoraSubgraphFusion : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("LoraSubgraphFusion", "0"); | ||
LoraSubgraphFusion(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "ov_ops/lora_subgraph.hpp" | ||
|
||
#include "openvino/core/partial_shape.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
LoraSubgraph::LoraSubgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body) : SubGraphOp(args) { | ||
SubGraphOp::set_function(body); | ||
for (size_t i = 0; i < body->get_parameters().size(); ++i) | ||
m_input_descriptions[0].push_back(std::make_shared<InvariantInputDescription>(i, i)); | ||
for (size_t i = 0; i < body->get_output_size(); ++i) | ||
m_output_descriptions[0].push_back(std::make_shared<BodyOutputDescription>(i, i)); | ||
constructor_validate_and_infer_types(); | ||
} | ||
|
||
std::shared_ptr<Node> LoraSubgraph::clone_with_new_inputs(const OutputVector& new_args) const { | ||
check_new_args_count(this, new_args); | ||
return std::make_shared<LoraSubgraph>(new_args, get_function()->clone()); | ||
} | ||
|
||
void LoraSubgraph::validate_and_infer_types() { | ||
std::unordered_map<size_t, PartialShape> shape_map; | ||
std::unordered_map<size_t, element::Type> type_map; | ||
auto body = get_function(); | ||
OPENVINO_ASSERT(body); | ||
validate_and_infer_type_body(body, m_input_descriptions[0]); | ||
|
||
for (const auto& item : get_mapping_outputs_on_body_description(m_output_descriptions[0])) { | ||
auto output_index = item.first; | ||
auto desc = item.second; | ||
auto node_result = body->get_results().at(desc->m_body_value_index)->input_value(0); | ||
auto pshape = PartialShape::dynamic(); | ||
if (shape_map.count(output_index)) { | ||
pshape = shape_map.at(output_index); | ||
} | ||
// TODO: should we support fully dynamic shape/element type? | ||
if (PartialShape::merge_into(pshape, node_result.get_partial_shape())) { | ||
shape_map[output_index] = std::move(pshape); | ||
} else { | ||
shape_map[output_index] = PartialShape::dynamic(); | ||
} | ||
auto type = element::dynamic; | ||
if (type_map.count(output_index)) { | ||
type = type_map.at(output_index); | ||
} | ||
if (element::Type::merge(type, type, node_result.get_element_type())) { | ||
type_map[output_index] = type; | ||
} else { | ||
type_map[output_index] = element::dynamic; | ||
} | ||
} | ||
for (const auto& item : shape_map) { | ||
auto output_index = item.first; | ||
NODE_VALIDATION_CHECK(this, | ||
type_map.count(output_index) != 0, | ||
"Type map must contain same outputs as shape map"); | ||
set_output_type(output_index, type_map.at(output_index), item.second); | ||
} | ||
} | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
87 changes: 87 additions & 0 deletions
87
src/common/transformations/src/transformations/common_optimizations/lora_subgraph_fusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/lora_subgraph_fusion.hpp" | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
#include "itt.hpp" | ||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/op/add.hpp" | ||
#include "openvino/op/convolution.hpp" | ||
#include "openvino/op/matmul.hpp" | ||
#include "openvino/op/multiply.hpp" | ||
#include "openvino/op/parameter.hpp" | ||
#include "openvino/op/transpose.hpp" | ||
#include "openvino/op/util/read_value_base.hpp" | ||
#include "openvino/pass/pattern/op/optional.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
#include "ov_ops/lora_subgraph.hpp" | ||
#include "transformations/utils/utils.hpp" | ||
|
||
ov::pass::LoraSubgraphFusion::LoraSubgraphFusion() { | ||
MATCHER_SCOPE(LoraSubgraphFusion); | ||
using namespace pass::pattern; | ||
auto input_m = any_input(); | ||
auto transpose_const1_m = wrap_type<ov::op::v0::Constant>(); | ||
auto transpose1_m = optional<ov::op::v1::Transpose>({input_m, transpose_const1_m}, consumers_count(1)); | ||
auto read_value1_m = wrap_type<ov::op::util::ReadValueBase>(); | ||
auto matmul1_m = wrap_type<ov::op::v0::MatMul>({transpose1_m, read_value1_m}, consumers_count(1)); | ||
auto read_value2_m = wrap_type<ov::op::util::ReadValueBase>(); | ||
auto multiply_m = wrap_type<ov::op::v1::Multiply>({matmul1_m, read_value2_m}, consumers_count(1)); | ||
auto read_value3_m = wrap_type<ov::op::util::ReadValueBase>(); | ||
auto matmul2_m = wrap_type<ov::op::v0::MatMul>({multiply_m, read_value3_m}, consumers_count(1)); | ||
auto transpose_const2_m = wrap_type<ov::op::v0::Constant>(consumers_count(1)); | ||
auto transpose2_m = optional<ov::op::v1::Transpose>({matmul2_m, transpose_const2_m}, consumers_count(1)); | ||
auto external_matmul_m = wrap_type<ov::op::v0::MatMul, ov::op::v1::Convolution>({input_m, any_input()}); | ||
auto add_m = wrap_type<ov::op::v1::Add>({transpose2_m, external_matmul_m}); | ||
|
||
ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](Matcher& m) { | ||
const auto& pattern_map = m.get_pattern_value_map(); | ||
const auto& input = pattern_map.at(input_m); | ||
const auto& read_value1 = pattern_map.at(read_value1_m); | ||
const auto& read_value2 = pattern_map.at(read_value2_m); | ||
const auto& read_value3 = pattern_map.at(read_value3_m); | ||
const auto& external_matmul = pattern_map.at(external_matmul_m); | ||
|
||
const auto& matmul1_node = pattern_map.at(matmul1_m).get_node_shared_ptr(); | ||
const auto& multiply_node = pattern_map.at(multiply_m).get_node_shared_ptr(); | ||
const auto& matmul2_node = pattern_map.at(matmul2_m).get_node_shared_ptr(); | ||
const auto& add_node = pattern_map.at(add_m).get_node_shared_ptr(); | ||
|
||
// Need to collect exactly inputs, not outputs | ||
// Inputs - for internal body | ||
// Outputs - for LoraSubgraph connection with the model | ||
std::vector<ov::Input<ov::Node>> inputs{ | ||
pattern_map.count(transpose1_m) ? pattern_map.at(transpose1_m).get_node()->input(0) : matmul1_node->input(0), | ||
matmul1_node->input(1), | ||
multiply_node->input(1), | ||
matmul2_node->input(1), | ||
add_node->input(0), | ||
}; | ||
ov::OutputVector lora_inputs{ | ||
input, | ||
read_value1, | ||
read_value2, | ||
read_value3, | ||
external_matmul, | ||
}; | ||
const auto& last_out_to_replace = add_node; | ||
|
||
ov::ParameterVector subgraph_parameters; | ||
subgraph_parameters.reserve(inputs.size()); | ||
for (auto& input : inputs) { | ||
const auto new_parameter = std::make_shared<ov::op::v0::Parameter>(input.get_element_type(), input.get_partial_shape()); | ||
subgraph_parameters.push_back(new_parameter); | ||
input.replace_source_output(new_parameter); | ||
} | ||
auto lora_subgraph = std::make_shared<ov::Model>(ov::OutputVector{last_out_to_replace}, subgraph_parameters)->clone(); | ||
const auto lora_node = std::make_shared<ov::op::internal::LoraSubgraph>(lora_inputs, lora_subgraph); | ||
return ov::replace_output_update_name(last_out_to_replace, lora_node->output(0)); | ||
}; | ||
|
||
auto m = std::make_shared<Matcher>(add_m, matcher_name); | ||
this->register_matcher(m, callback); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters