Skip to content

Commit

Permalink
RPE operation and fusing transformation (openvinotoolkit#20734)
Browse files Browse the repository at this point in the history
* RPE operation and fusing transformation

* Correct includes

* Apply suggestions from code review

Co-authored-by: Pawel Raasz <[email protected]>

* Comments adressed

* Misprints

* Update src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp

Co-authored-by: Pawel Raasz <[email protected]>

* Ivan comments adressed

* Update src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp

Co-authored-by: Pawel Raasz <[email protected]>

* Fix includes and adds comments

---------

Co-authored-by: Pavel Durandin <[email protected]>
Co-authored-by: Pawel Raasz <[email protected]>
  • Loading branch information
3 people authored and akuporos committed Dec 8, 2023
1 parent 5c5b393 commit b607fb7
Show file tree
Hide file tree
Showing 6 changed files with 307 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace op {
namespace internal {

///
/// \brief Rotary Positional Embeddings operation
/// Internal operation which may change in the future
/// \ingroup ov_ops_cpp_api
class TRANSFORMATIONS_API RPE : public ov::op::Op {
public:
OPENVINO_OP("RPE", "ie_internal_opset", op::Op);

RPE() = default;
RPE(const Output<Node>& data, const Output<Node>& sin, const Output<Node>& cos, int64_t axis);

void set_axis(int64_t axis);
int64_t get_axis() const;

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

private:
int64_t m_axis{};
};

} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
class TRANSFORMATIONS_API RPE_Fusion;
} // namespace pass
} // namespace ov

/**
* @ingroup ie_transformation_common_api
* @brief Fuses special sub-graph into an internal Rotary Positional Embedding operation
*/
class ov::pass::RPE_Fusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("RPE_Fusion", "0");
RPE_Fusion();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ov_ops/rotary_positional_embeddings.hpp"

#include "itt.hpp"

namespace ov {
namespace op {
namespace internal {

RPE::RPE(const Output<Node>& data, const Output<Node>& sin, const Output<Node>& cos, int64_t axis)
: Op({data, sin, cos}),
m_axis{axis} {
constructor_validate_and_infer_types();
}

void RPE::set_axis(int64_t axis) {
m_axis = axis;
}

int64_t RPE::get_axis() const {
return m_axis;
}

void RPE::validate_and_infer_types() {
INTERNAL_OP_SCOPE(internal_RoPE_validate_and_infer_types);
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}

bool RPE::visit_attributes(ov::AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(internal_RoPE_visit_attributes);
visitor.on_attribute("axis", m_axis);
return true;
}

std::shared_ptr<ov::Node> RPE::clone_with_new_inputs(const ov::OutputVector& new_args) const {
INTERNAL_OP_SCOPE(internal_RoPE_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<RPE>(new_args.at(0), new_args.at(1), new_args.at(2), m_axis);
}

} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"

#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/op.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"
#include "transformations/utils/utils.hpp"
#include "validation_util.hpp"

using ov::op::v0::Concat;
using ov::op::v1::Add;
using ov::op::v1::Multiply;
using ov::op::v1::VariadicSplit;

ov::pass::RPE_Fusion::RPE_Fusion() {
MATCHER_SCOPE(RPE_Fusion);

auto sin = pattern::any_input();
auto cos = pattern::any_input();

// FIXME: should be a single node match
auto source_1 = pattern::any_input();
auto source = pattern::any_input();
// BEGIN: rotate_half

// Variadic Split into two equal parts
auto axis = pattern::any_input();
auto split_length = INT_CONSTANT_WITH_PREDICATE(value.size() == 2 && value[0] == value[1]);
auto vsplit = pattern::wrap_type<VariadicSplit>({source, axis, split_length});
vsplit->set_output_size(2);

// Negate
auto minus_1 = FLOAT_CONSTANT_WITH_PREDICATE(value.size() == 1 && value[0] == -1);
auto neg = pattern::wrap_type<Multiply>({vsplit->output(1), minus_1});

// Concat two splitted parts in the opposite order, first of them is negated
auto concat = pattern::wrap_type<Concat>({neg, vsplit->output(0)}); // make sure axis eq to vsplit eq -1

// END: rotate half

auto mul_sin = pattern::wrap_type<Multiply>({concat, sin});
auto mul_cos = pattern::wrap_type<Multiply>({source_1, cos});
auto add = pattern::wrap_type<Add>({mul_cos, mul_sin});

ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
auto value_map = m.get_pattern_value_map();

auto actual_source = value_map.at(vsplit).get_node_shared_ptr()->input_value(0);
auto potential_source = value_map.at(mul_cos).get_node_shared_ptr()->input_value(0);
auto cos_output = value_map.at(mul_cos).get_node_shared_ptr()->input_value(1);

if (actual_source != potential_source && actual_source != cos_output)
return false; // flawed match
if (actual_source == potential_source && actual_source == cos_output)
return false; // flawed match
if (actual_source != potential_source && actual_source == cos_output)
cos_output = potential_source;

auto input = value_map.at(source);
auto concat_node = ov::as_type_ptr<Concat>(value_map.at(concat).get_node_shared_ptr());
if (!concat_node)
return false;
auto split_axis_node = ov::util::get_constant_from_source(value_map.at(axis));
if (!split_axis_node)
return false;
auto value = split_axis_node->cast_vector<int64_t>();
if (value.size() != 1)
return false;
auto concat_axis = concat_node->get_concatenation_axis();
auto split_axis = value[0];
if (concat_axis != split_axis) {
if (input.get_partial_shape().rank().is_static()) {
auto rank = input.get_partial_shape().rank().get_length();
concat_axis = ov::util::normalize(concat_axis, rank);
split_axis = ov::util::normalize(split_axis, rank);
}
if (concat_axis != split_axis)
return false;
}
auto rpe =
std::make_shared<ov::op::internal::RPE>(input, value_map.at(sin), cos_output, concat_node->get_axis());

for (const auto& label : {vsplit, neg, concat, mul_sin, mul_cos, add})
ov::copy_runtime_info(value_map.at(label).get_node_shared_ptr(), rpe);
return ov::replace_output_update_name(value_map.at(add), rpe->output(0));
};
auto m = std::make_shared<pattern::Matcher>(add, matcher_name);
register_matcher(m, matcher_pass_callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"

#include <gtest/gtest.h>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/variadic_split.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"
#include "transformations/utils/utils.hpp"

using namespace std;
using namespace ov;
using namespace ov::op;

void name_node_and_output(const shared_ptr<Node>& op, const std::string& name) {
op->set_friendly_name(name);
op->output(0).set_names({name});
}

TEST_F(TransformationTestsF, FuseRPE) {
{
auto data = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(data, "source");
auto sin = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(sin, "sin");
auto cos = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(cos, "cos");
auto axis = v0::Constant::create(element::i64, {}, {-1});
auto split_lengths = v0::Constant::create(element::i64, {2}, {10, 10});
auto split = make_shared<v1::VariadicSplit>(data, axis, split_lengths);

auto minus_one = v0::Constant::create(element::f32, {}, {-1});
auto negate = make_shared<v1::Multiply>(split->output(1), minus_one);

auto concat = make_shared<v0::Concat>(OutputVector{negate, split->output(0)}, -1);

auto mul_sin = make_shared<op::v1::Multiply>(concat, sin);
auto mul_cos = make_shared<op::v1::Multiply>(data, cos);
auto add = make_shared<op::v1::Add>(mul_cos, mul_sin);
name_node_and_output(add, "rpe");

model = std::make_shared<Model>(NodeVector{add}, ParameterVector{data, sin, cos});

manager.register_pass<ov::pass::RPE_Fusion>();
}
{
auto data = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(data, "source");
auto sin = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(sin, "sin");
auto cos = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(cos, "cos");
auto rpe = make_shared<ov::op::internal::RPE>(data, sin, cos, -1);
name_node_and_output(rpe, "rpe");
model_ref = std::make_shared<Model>(NodeVector{rpe}, ParameterVector{data, sin, cos});
}
comparator.enable(FunctionsComparator::CmpValues::NAMES);
}

TEST_F(TransformationTestsF, FuseRPESorcesAreMultiOutputed) {
/* Transformation matcher searches for a single source as a beginning of the pattern:
VariadicSplit ...
source ____/
\
Multiply ...
This test is designed to check that in case we feed VariadicSplit and Multiply from different outputs of the same
node, the transformation won't happen since the source isn't the same
*/
{
auto data_ = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
auto sin = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
auto cos = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());

auto data = make_shared<v1::Split>(data_, v0::Constant::create(element::i64, {}, {-1}), 2);

auto axis = v0::Constant::create(element::i64, {}, {-1});
auto split_lengths = v0::Constant::create(element::i64, {2}, {10, 10});
auto split = make_shared<v1::VariadicSplit>(data->output(0), axis, split_lengths);

auto minus_one = v0::Constant::create(element::f32, {}, {-1});
auto negate = make_shared<v1::Multiply>(split->output(1), minus_one);

auto concat = make_shared<v0::Concat>(OutputVector{negate, split->output(0)}, -1);

auto mul_sin = make_shared<op::v1::Multiply>(concat, sin);
auto mul_cos = make_shared<op::v1::Multiply>(data->output(1), cos);
auto add = make_shared<op::v1::Add>(mul_cos, mul_sin);

model = std::make_shared<Model>(NodeVector{add}, ParameterVector{data_, sin, cos});

manager.register_pass<ov::pass::RPE_Fusion>();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "transformations/common_optimizations/common_optimizations.hpp"
#include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp"
#include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp"
#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"
#include "transformations/control_flow/unroll_tensor_iterator.hpp"
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/op_conversions/convert_batch_to_space.hpp"
Expand Down Expand Up @@ -326,6 +327,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis

CPU_REGISTER_PASS_COMMON(manager, ov::pass::AUGRUCellFusion);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::RPE_Fusion);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::WrapInterpolateIntoTransposes);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeSinking);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConvertSequenceToTensorIterator);
Expand Down Expand Up @@ -473,6 +475,8 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertTopK11ToTopK3);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::HSwishDecomposition);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction);
// CVS-126827: should be disabled until CPU supports this internal op
CPU_DISABLE_PASS_COMMON(manager, ov::pass::RPE_Fusion);
CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition);

CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition);
Expand Down

0 comments on commit b607fb7

Please sign in to comment.