forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
RPE operation and fusing transformation (openvinotoolkit#20734)
* RPE operation and fusing transformation * Correct includes * Apply suggestions from code review Co-authored-by: Pawel Raasz <[email protected]> * Comments adressed * Misprints * Update src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp Co-authored-by: Pawel Raasz <[email protected]> * Ivan comments adressed * Update src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp Co-authored-by: Pawel Raasz <[email protected]> * Fix includes and adds comments --------- Co-authored-by: Pavel Durandin <[email protected]> Co-authored-by: Pawel Raasz <[email protected]>
- Loading branch information
Showing
6 changed files
with
307 additions
and
0 deletions.
There are no files selected for viewing
38 changes: 38 additions & 0 deletions
38
src/common/transformations/include/ov_ops/rotary_positional_embeddings.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/op.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
/// | ||
/// \brief Rotary Positional Embeddings operation | ||
/// Internal operation which may change in the future | ||
/// \ingroup ov_ops_cpp_api | ||
class TRANSFORMATIONS_API RPE : public ov::op::Op { | ||
public: | ||
OPENVINO_OP("RPE", "ie_internal_opset", op::Op); | ||
|
||
RPE() = default; | ||
RPE(const Output<Node>& data, const Output<Node>& sin, const Output<Node>& cos, int64_t axis); | ||
|
||
void set_axis(int64_t axis); | ||
int64_t get_axis() const; | ||
|
||
bool visit_attributes(AttributeVisitor& visitor) override; | ||
void validate_and_infer_types() override; | ||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override; | ||
|
||
private: | ||
int64_t m_axis{}; | ||
}; | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
24 changes: 24 additions & 0 deletions
24
...ations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/graph_rewrite.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
class TRANSFORMATIONS_API RPE_Fusion; | ||
} // namespace pass | ||
} // namespace ov | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief Fuses special sub-graph into an internal Rotary Positional Embedding operation | ||
*/ | ||
class ov::pass::RPE_Fusion : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("RPE_Fusion", "0"); | ||
RPE_Fusion(); | ||
}; |
46 changes: 46 additions & 0 deletions
46
src/common/transformations/src/ov_ops/rotary_positional_embeddings.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "ov_ops/rotary_positional_embeddings.hpp" | ||
|
||
#include "itt.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace internal { | ||
|
||
RPE::RPE(const Output<Node>& data, const Output<Node>& sin, const Output<Node>& cos, int64_t axis) | ||
: Op({data, sin, cos}), | ||
m_axis{axis} { | ||
constructor_validate_and_infer_types(); | ||
} | ||
|
||
void RPE::set_axis(int64_t axis) { | ||
m_axis = axis; | ||
} | ||
|
||
int64_t RPE::get_axis() const { | ||
return m_axis; | ||
} | ||
|
||
void RPE::validate_and_infer_types() { | ||
INTERNAL_OP_SCOPE(internal_RoPE_validate_and_infer_types); | ||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); | ||
} | ||
|
||
bool RPE::visit_attributes(ov::AttributeVisitor& visitor) { | ||
INTERNAL_OP_SCOPE(internal_RoPE_visit_attributes); | ||
visitor.on_attribute("axis", m_axis); | ||
return true; | ||
} | ||
|
||
std::shared_ptr<ov::Node> RPE::clone_with_new_inputs(const ov::OutputVector& new_args) const { | ||
INTERNAL_OP_SCOPE(internal_RoPE_clone_with_new_inputs); | ||
check_new_args_count(this, new_args); | ||
return std::make_shared<RPE>(new_args.at(0), new_args.at(1), new_args.at(2), m_axis); | ||
} | ||
|
||
} // namespace internal | ||
} // namespace op | ||
} // namespace ov |
98 changes: 98 additions & 0 deletions
98
...formations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp" | ||
|
||
#include "itt.hpp" | ||
#include "openvino/core/validation_util.hpp" | ||
#include "openvino/op/add.hpp" | ||
#include "openvino/op/concat.hpp" | ||
#include "openvino/op/multiply.hpp" | ||
#include "openvino/op/op.hpp" | ||
#include "openvino/op/variadic_split.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
#include "ov_ops/rotary_positional_embeddings.hpp" | ||
#include "transformations/utils/utils.hpp" | ||
#include "validation_util.hpp" | ||
|
||
using ov::op::v0::Concat; | ||
using ov::op::v1::Add; | ||
using ov::op::v1::Multiply; | ||
using ov::op::v1::VariadicSplit; | ||
|
||
ov::pass::RPE_Fusion::RPE_Fusion() { | ||
MATCHER_SCOPE(RPE_Fusion); | ||
|
||
auto sin = pattern::any_input(); | ||
auto cos = pattern::any_input(); | ||
|
||
// FIXME: should be a single node match | ||
auto source_1 = pattern::any_input(); | ||
auto source = pattern::any_input(); | ||
// BEGIN: rotate_half | ||
|
||
// Variadic Split into two equal parts | ||
auto axis = pattern::any_input(); | ||
auto split_length = INT_CONSTANT_WITH_PREDICATE(value.size() == 2 && value[0] == value[1]); | ||
auto vsplit = pattern::wrap_type<VariadicSplit>({source, axis, split_length}); | ||
vsplit->set_output_size(2); | ||
|
||
// Negate | ||
auto minus_1 = FLOAT_CONSTANT_WITH_PREDICATE(value.size() == 1 && value[0] == -1); | ||
auto neg = pattern::wrap_type<Multiply>({vsplit->output(1), minus_1}); | ||
|
||
// Concat two splitted parts in the opposite order, first of them is negated | ||
auto concat = pattern::wrap_type<Concat>({neg, vsplit->output(0)}); // make sure axis eq to vsplit eq -1 | ||
|
||
// END: rotate half | ||
|
||
auto mul_sin = pattern::wrap_type<Multiply>({concat, sin}); | ||
auto mul_cos = pattern::wrap_type<Multiply>({source_1, cos}); | ||
auto add = pattern::wrap_type<Add>({mul_cos, mul_sin}); | ||
|
||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { | ||
auto value_map = m.get_pattern_value_map(); | ||
|
||
auto actual_source = value_map.at(vsplit).get_node_shared_ptr()->input_value(0); | ||
auto potential_source = value_map.at(mul_cos).get_node_shared_ptr()->input_value(0); | ||
auto cos_output = value_map.at(mul_cos).get_node_shared_ptr()->input_value(1); | ||
|
||
if (actual_source != potential_source && actual_source != cos_output) | ||
return false; // flawed match | ||
if (actual_source == potential_source && actual_source == cos_output) | ||
return false; // flawed match | ||
if (actual_source != potential_source && actual_source == cos_output) | ||
cos_output = potential_source; | ||
|
||
auto input = value_map.at(source); | ||
auto concat_node = ov::as_type_ptr<Concat>(value_map.at(concat).get_node_shared_ptr()); | ||
if (!concat_node) | ||
return false; | ||
auto split_axis_node = ov::util::get_constant_from_source(value_map.at(axis)); | ||
if (!split_axis_node) | ||
return false; | ||
auto value = split_axis_node->cast_vector<int64_t>(); | ||
if (value.size() != 1) | ||
return false; | ||
auto concat_axis = concat_node->get_concatenation_axis(); | ||
auto split_axis = value[0]; | ||
if (concat_axis != split_axis) { | ||
if (input.get_partial_shape().rank().is_static()) { | ||
auto rank = input.get_partial_shape().rank().get_length(); | ||
concat_axis = ov::util::normalize(concat_axis, rank); | ||
split_axis = ov::util::normalize(split_axis, rank); | ||
} | ||
if (concat_axis != split_axis) | ||
return false; | ||
} | ||
auto rpe = | ||
std::make_shared<ov::op::internal::RPE>(input, value_map.at(sin), cos_output, concat_node->get_axis()); | ||
|
||
for (const auto& label : {vsplit, neg, concat, mul_sin, mul_cos, add}) | ||
ov::copy_runtime_info(value_map.at(label).get_node_shared_ptr(), rpe); | ||
return ov::replace_output_update_name(value_map.at(add), rpe->output(0)); | ||
}; | ||
auto m = std::make_shared<pattern::Matcher>(add, matcher_name); | ||
register_matcher(m, matcher_pass_callback); | ||
} |
97 changes: 97 additions & 0 deletions
97
src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include "common_test_utils/ov_test_utils.hpp" | ||
#include "openvino/op/parameter.hpp" | ||
#include "openvino/op/variadic_split.hpp" | ||
#include "ov_ops/rotary_positional_embeddings.hpp" | ||
#include "transformations/utils/utils.hpp" | ||
|
||
using namespace std; | ||
using namespace ov; | ||
using namespace ov::op; | ||
|
||
void name_node_and_output(const shared_ptr<Node>& op, const std::string& name) { | ||
op->set_friendly_name(name); | ||
op->output(0).set_names({name}); | ||
} | ||
|
||
TEST_F(TransformationTestsF, FuseRPE) { | ||
{ | ||
auto data = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic()); | ||
name_node_and_output(data, "source"); | ||
auto sin = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic()); | ||
name_node_and_output(sin, "sin"); | ||
auto cos = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic()); | ||
name_node_and_output(cos, "cos"); | ||
auto axis = v0::Constant::create(element::i64, {}, {-1}); | ||
auto split_lengths = v0::Constant::create(element::i64, {2}, {10, 10}); | ||
auto split = make_shared<v1::VariadicSplit>(data, axis, split_lengths); | ||
|
||
auto minus_one = v0::Constant::create(element::f32, {}, {-1}); | ||
auto negate = make_shared<v1::Multiply>(split->output(1), minus_one); | ||
|
||
auto concat = make_shared<v0::Concat>(OutputVector{negate, split->output(0)}, -1); | ||
|
||
auto mul_sin = make_shared<op::v1::Multiply>(concat, sin); | ||
auto mul_cos = make_shared<op::v1::Multiply>(data, cos); | ||
auto add = make_shared<op::v1::Add>(mul_cos, mul_sin); | ||
name_node_and_output(add, "rpe"); | ||
|
||
model = std::make_shared<Model>(NodeVector{add}, ParameterVector{data, sin, cos}); | ||
|
||
manager.register_pass<ov::pass::RPE_Fusion>(); | ||
} | ||
{ | ||
auto data = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic()); | ||
name_node_and_output(data, "source"); | ||
auto sin = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic()); | ||
name_node_and_output(sin, "sin"); | ||
auto cos = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic()); | ||
name_node_and_output(cos, "cos"); | ||
auto rpe = make_shared<ov::op::internal::RPE>(data, sin, cos, -1); | ||
name_node_and_output(rpe, "rpe"); | ||
model_ref = std::make_shared<Model>(NodeVector{rpe}, ParameterVector{data, sin, cos}); | ||
} | ||
comparator.enable(FunctionsComparator::CmpValues::NAMES); | ||
} | ||
|
||
TEST_F(TransformationTestsF, FuseRPESorcesAreMultiOutputed) { | ||
/* Transformation matcher searches for a single source as a beginning of the pattern: | ||
VariadicSplit ... | ||
source ____/ | ||
\ | ||
Multiply ... | ||
This test is designed to check that in case we feed VariadicSplit and Multiply from different outputs of the same | ||
node, the transformation won't happen since the source isn't the same | ||
*/ | ||
{ | ||
auto data_ = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic()); | ||
auto sin = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic()); | ||
auto cos = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic()); | ||
|
||
auto data = make_shared<v1::Split>(data_, v0::Constant::create(element::i64, {}, {-1}), 2); | ||
|
||
auto axis = v0::Constant::create(element::i64, {}, {-1}); | ||
auto split_lengths = v0::Constant::create(element::i64, {2}, {10, 10}); | ||
auto split = make_shared<v1::VariadicSplit>(data->output(0), axis, split_lengths); | ||
|
||
auto minus_one = v0::Constant::create(element::f32, {}, {-1}); | ||
auto negate = make_shared<v1::Multiply>(split->output(1), minus_one); | ||
|
||
auto concat = make_shared<v0::Concat>(OutputVector{negate, split->output(0)}, -1); | ||
|
||
auto mul_sin = make_shared<op::v1::Multiply>(concat, sin); | ||
auto mul_cos = make_shared<op::v1::Multiply>(data->output(1), cos); | ||
auto add = make_shared<op::v1::Add>(mul_cos, mul_sin); | ||
|
||
model = std::make_shared<Model>(NodeVector{add}, ParameterVector{data_, sin, cos}); | ||
|
||
manager.register_pass<ov::pass::RPE_Fusion>(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters