[CPU] LoraFusion
v-Golubev committed Oct 15, 2024
1 parent 222ddc0 commit f01abbe
Showing 5 changed files with 218 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/common/transformations/include/ov_ops/lora_subgraph.hpp
@@ -0,0 +1,28 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "openvino/op/util/sub_graph_base.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace op {
namespace internal {

class TRANSFORMATIONS_API LoraSubgraph : public ov::op::util::SubGraphOp {
public:
OPENVINO_OP("LoraSubgraph", "ie_internal_opset");

LoraSubgraph() = default;
LoraSubgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body);

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
};

} // namespace internal
} // namespace op
} // namespace ov
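
For reference, a minimal sketch of constructing this op directly; the body model, shapes, and names below are illustrative assumptions, not part of this commit. Each element of `args` is bound by index to the corresponding body Parameter (see the constructor in lora_subgraph.cpp further down):

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "ov_ops/lora_subgraph.hpp"

// Body: an ov::Model computing the fused payload; a single input keeps the
// sketch minimal, while real LoRA bodies take five (see lora_subgraph_fusion.cpp below).
auto body_param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, 16});
auto body_scale = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {2.0f});
auto body_out = std::make_shared<ov::op::v1::Multiply>(body_param, body_scale);
auto body = std::make_shared<ov::Model>(ov::OutputVector{body_out}, ov::ParameterVector{body_param});

// Outer graph input; args[i] feeds body->get_parameters()[i].
auto outer_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, 16});
auto lora = std::make_shared<ov::op::internal::LoraSubgraph>(ov::OutputVector{outer_input}, body);
// lora->output(0) gets its shape/type from the body's Result: {?, 16}, f32.
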
29 changes: 29 additions & 0 deletions src/common/transformations/include/transformations/common_optimizations/lora_subgraph_fusion.hpp
@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <vector>

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API LoraSubgraphFusion;

} // namespace pass
} // namespace ov

/**
 * @ingroup ov_transformation_common_api
 * @brief LoraSubgraphFusion detects the LoRA adapter pattern (an optional Transpose,
 * a MatMul with ReadValue weights, a Multiply by a ReadValue scale, a second MatMul
 * with ReadValue weights, an optional Transpose, and an Add with the main-flow
 * MatMul/Convolution result) and wraps it into a single ov::op::internal::LoraSubgraph operation.
 */
class ov::pass::LoraSubgraphFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("LoraSubgraphFusion", "0");
LoraSubgraphFusion();
};
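
Usage follows the standard MatcherPass flow; a minimal sketch, assuming `model` is a `std::shared_ptr<ov::Model>` that contains the LoRA pattern:

#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/lora_subgraph_fusion.hpp"

ov::pass::Manager manager;
manager.register_pass<ov::pass::LoraSubgraphFusion>();
manager.run_passes(model);  // matched LoRA patterns are folded into LoraSubgraph ops
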
69 changes: 69 additions & 0 deletions src/common/transformations/src/ov_ops/lora_subgraph.cpp
@@ -0,0 +1,69 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ov_ops/lora_subgraph.hpp"

#include "openvino/core/partial_shape.hpp"

namespace ov {
namespace op {
namespace internal {

LoraSubgraph::LoraSubgraph(const OutputVector& args, const std::shared_ptr<ov::Model>& body) : SubGraphOp(args) {
SubGraphOp::set_function(body);
for (size_t i = 0; i < body->get_parameters().size(); ++i)
m_input_descriptions[0].push_back(std::make_shared<InvariantInputDescription>(i, i));
for (size_t i = 0; i < body->get_output_size(); ++i)
m_output_descriptions[0].push_back(std::make_shared<BodyOutputDescription>(i, i));
constructor_validate_and_infer_types();
}

std::shared_ptr<Node> LoraSubgraph::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return std::make_shared<LoraSubgraph>(new_args, get_function()->clone());
}

void LoraSubgraph::validate_and_infer_types() {
std::unordered_map<size_t, PartialShape> shape_map;
std::unordered_map<size_t, element::Type> type_map;
auto body = get_function();
OPENVINO_ASSERT(body);
validate_and_infer_type_body(body, m_input_descriptions[0]);

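// Each external output may be fed by several body Results; merge their shapes
// and types, falling back to dynamic when the candidates are incompatible.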
for (const auto& item : get_mapping_outputs_on_body_description(m_output_descriptions[0])) {
auto output_index = item.first;
auto desc = item.second;
auto node_result = body->get_results().at(desc->m_body_value_index)->input_value(0);
auto pshape = PartialShape::dynamic();
if (shape_map.count(output_index)) {
pshape = shape_map.at(output_index);
}
// TODO: should we support fully dynamic shape/element type?
if (PartialShape::merge_into(pshape, node_result.get_partial_shape())) {
shape_map[output_index] = std::move(pshape);
} else {
shape_map[output_index] = PartialShape::dynamic();
}
auto type = element::dynamic;
if (type_map.count(output_index)) {
type = type_map.at(output_index);
}
if (element::Type::merge(type, type, node_result.get_element_type())) {
type_map[output_index] = type;
} else {
type_map[output_index] = element::dynamic;
}
}
for (const auto& item : shape_map) {
auto output_index = item.first;
NODE_VALIDATION_CHECK(this,
type_map.count(output_index) != 0,
"Type map must contain same outputs as shape map");
set_output_type(output_index, type_map.at(output_index), item.second);
}
}

} // namespace internal
} // namespace op
} // namespace ov
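
As a standalone illustration of the merge helpers used above (not part of the commit):

ov::PartialShape dst = ov::PartialShape::dynamic();
// A fully dynamic shape is compatible with any shape: dst becomes {?, 16}, returns true.
bool shapes_ok = ov::PartialShape::merge_into(dst, ov::PartialShape{-1, 16});

ov::element::Type merged = ov::element::dynamic;
// element::dynamic merges with any concrete type: merged becomes f32, returns true.
bool types_ok = ov::element::Type::merge(merged, merged, ov::element::f32);
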
87 changes: 87 additions & 0 deletions src/common/transformations/src/transformations/common_optimizations/lora_subgraph_fusion.cpp
@@ -0,0 +1,87 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/lora_subgraph_fusion.hpp"

#include <memory>
#include <vector>

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/convolution.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/read_value_base.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "ov_ops/lora_subgraph.hpp"
#include "transformations/utils/utils.hpp"

ov::pass::LoraSubgraphFusion::LoraSubgraphFusion() {
MATCHER_SCOPE(LoraSubgraphFusion);
using namespace pass::pattern;
auto input_m = any_input();
auto transpose_const1_m = wrap_type<ov::op::v0::Constant>();
auto transpose1_m = optional<ov::op::v1::Transpose>({input_m, transpose_const1_m}, consumers_count(1));
auto read_value1_m = wrap_type<ov::op::util::ReadValueBase>();
auto matmul1_m = wrap_type<ov::op::v0::MatMul>({transpose1_m, read_value1_m}, consumers_count(1));
auto read_value2_m = wrap_type<ov::op::util::ReadValueBase>();
auto multiply_m = wrap_type<ov::op::v1::Multiply>({matmul1_m, read_value2_m}, consumers_count(1));
auto read_value3_m = wrap_type<ov::op::util::ReadValueBase>();
auto matmul2_m = wrap_type<ov::op::v0::MatMul>({multiply_m, read_value3_m}, consumers_count(1));
auto transpose_const2_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
auto transpose2_m = optional<ov::op::v1::Transpose>({matmul2_m, transpose_const2_m}, consumers_count(1));
auto external_matmul_m = wrap_type<ov::op::v0::MatMul, ov::op::v1::Convolution>({input_m, any_input()});
auto add_m = wrap_type<ov::op::v1::Add>({transpose2_m, external_matmul_m});

ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto& input = pattern_map.at(input_m);
const auto& read_value1 = pattern_map.at(read_value1_m);
const auto& read_value2 = pattern_map.at(read_value2_m);
const auto& read_value3 = pattern_map.at(read_value3_m);
const auto& external_matmul = pattern_map.at(external_matmul_m);

const auto& matmul1_node = pattern_map.at(matmul1_m).get_node_shared_ptr();
const auto& multiply_node = pattern_map.at(multiply_m).get_node_shared_ptr();
const auto& matmul2_node = pattern_map.at(matmul2_m).get_node_shared_ptr();
const auto& add_node = pattern_map.at(add_m).get_node_shared_ptr();

// Collect the consumer Inputs here rather than the producer Outputs:
// the Inputs are rewired to the internal body's Parameters,
// while the producing Outputs stay in the model as the LoraSubgraph's inputs
std::vector<ov::Input<ov::Node>> inputs{
pattern_map.count(transpose1_m) ? pattern_map.at(transpose1_m).get_node()->input(0) : matmul1_node->input(0),
matmul1_node->input(1),
multiply_node->input(1),
matmul2_node->input(1),
add_node->input(0),
};
ov::OutputVector lora_inputs{
input,
read_value1,
read_value2,
read_value3,
external_matmul,
};
const auto& last_out_to_replace = add_node;

ov::ParameterVector subgraph_parameters;
subgraph_parameters.reserve(inputs.size());
for (auto& input : inputs) {
const auto new_parameter = std::make_shared<ov::op::v0::Parameter>(input.get_element_type(), input.get_partial_shape());
subgraph_parameters.push_back(new_parameter);
input.replace_source_output(new_parameter);
}
auto lora_subgraph = std::make_shared<ov::Model>(ov::OutputVector{last_out_to_replace}, subgraph_parameters)->clone();
const auto lora_node = std::make_shared<ov::op::internal::LoraSubgraph>(lora_inputs, lora_subgraph);
return ov::replace_output_update_name(last_out_to_replace, lora_node->output(0));
};

auto m = std::make_shared<Matcher>(add_m, matcher_name);
this->register_matcher(m, callback);
}
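
The pattern above targets the canonical LoRA topology: the layer input (optionally transposed) goes through a low-rank MatMul, a scaling Multiply, and a second low-rank MatMul (optionally transposed), and the result is added to the original layer's MatMul/Convolution output, with all three adapter tensors read from state. Below is a hedged builder sketch of a graph this pass would match; shapes, ranks, transpose flags, and variable IDs are illustrative assumptions, and it reuses the includes of the file above plus the ones listed:

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/op/util/variable.hpp"

std::shared_ptr<ov::Model> build_lora_pattern() {
    using namespace ov;
    auto x = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, 64});

    // Main flow: the original, frozen layer.
    auto w_main = op::v0::Constant::create(element::f32, Shape{64, 64}, std::vector<float>(64 * 64, 0.01f));
    auto main_mm = std::make_shared<op::v0::MatMul>(x, w_main);

    // Adapter tensors are read from states so they can be swapped at runtime; a complete
    // model would also register the variables and pair each ReadValue with an Assign sink.
    auto make_state = [](const PartialShape& shape, const std::string& id) {
        auto var = std::make_shared<op::util::Variable>(op::util::VariableInfo{shape, element::f32, id});
        return std::make_shared<op::v6::ReadValue>(var);
    };
    auto lora_a = make_state(PartialShape{8, 64}, "lora_A");   // low-rank down-projection
    auto alpha  = make_state(PartialShape{1, 1, 8}, "alpha");  // per-rank scaling
    auto lora_b = make_state(PartialShape{64, 8}, "lora_B");   // low-rank up-projection

    auto mm1    = std::make_shared<op::v0::MatMul>(x, lora_a, false, true);       // {?, ?, 8}
    auto scaled = std::make_shared<op::v1::Multiply>(mm1, alpha);
    auto mm2    = std::make_shared<op::v0::MatMul>(scaled, lora_b, false, true);  // {?, ?, 64}
    auto add    = std::make_shared<op::v1::Add>(mm2, main_mm);
    return std::make_shared<Model>(OutputVector{add}, ParameterVector{x});
}

After the pass runs, the two adapter MatMuls, the Multiply, and the Add collapse into a single internal::LoraSubgraph whose five inputs are the layer input, the three ReadValue outputs, and the main-flow MatMul result.
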
5 changes: 5 additions & 0 deletions src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -37,6 +37,7 @@
#include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp"
#include "transformations/common_optimizations/mark_rope_input_to_keep_in_mixed_precision.hpp"
#include "transformations/common_optimizations/rms_fusion.hpp"
#include "transformations/common_optimizations/lora_subgraph_fusion.hpp"
#include "transformations/control_flow/unroll_tensor_iterator.hpp"
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/fp16_compression/mark_floatpoint_range.hpp"
@@ -171,6 +172,8 @@
#endif
#include "openvino/core/validation_util.hpp"

#include "ov_ops/lora_subgraph.hpp"

namespace ov {
namespace intel_cpu {

@@ -679,6 +682,8 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecisions) {
CPU_REGISTER_PASS_COMMON(manager, ov::pass::EnableDecompressionConvertConstantFolding);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompression);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConstantFolding);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::LoraSubgraphFusion);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::Validate);

manager.run_passes(model);
}
