Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RPE operation and fusing transformation #20734

Merged
merged 14 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace op {
namespace internal {

///
/// \brief Rotary Positional Embeddings operation
/// Internal operation which may change in the future
/// \ingroup ov_ops_cpp_api
class TRANSFORMATIONS_API RPE : public ov::op::Op {
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
public:
OPENVINO_OP("RPE", "ie_internal_opset", op::Op);

RPE() = default;
RPE(const Output<Node>& data, const Output<Node>& sin, const Output<Node>& cos, const int64_t& axis);
jane-intel marked this conversation as resolved.
Show resolved Hide resolved

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

int64_t get_axis() const {
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
return m_axis;
};
void set_axis(const int64_t& axis) {
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
m_axis = axis;
};

private:
int64_t m_axis{};
};

} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
class TRANSFORMATIONS_API RPE_Fusion;
} // namespace pass
} // namespace ov

/**
* @ingroup ie_transformation_common_api
* @brief Fuses special sub-graph into an internal Rotary Positional Embedding operation
*/
class ov::pass::RPE_Fusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("RPE_Fusion", "0");
RPE_Fusion();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ov_ops/rotary_positional_embeddings.hpp"

#include "itt.hpp"

using namespace std;
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
using namespace ov::op::internal;
jane-intel marked this conversation as resolved.
Show resolved Hide resolved

RPE::RPE(const Output<Node>& data, const Output<Node>& sin, const Output<Node>& cos, const int64_t& axis)
: Op({data, sin, cos}),
m_axis{axis} {
constructor_validate_and_infer_types();
}

void RPE::validate_and_infer_types() {
INTERNAL_OP_SCOPE(internal_RoPE_validate_and_infer_types);
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}

bool RPE::visit_attributes(ov::AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(internal_RoPE_visit_attributes);
visitor.on_attribute("axis", m_axis);
return true;
}

shared_ptr<ov::Node> RPE::clone_with_new_inputs(const ov::OutputVector& new_args) const {
INTERNAL_OP_SCOPE(internal_RoPE_clone_with_new_inputs);
return make_shared<RPE>(new_args.at(0), new_args.at(1), new_args.at(2), m_axis);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"

#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/op.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"
#include "transformations/utils/utils.hpp"

ov::pass::RPE_Fusion::RPE_Fusion() {
MATCHER_SCOPE(RPE_Fusion);

jane-intel marked this conversation as resolved.
Show resolved Hide resolved
auto sin = pattern::any_input();
auto cos = pattern::any_input();

// FIXME: should be a single node match
auto source_1 = pattern::any_input();
auto source = pattern::any_input();
// BEGIN: rotate_half

// Variadic Split into two equal parts
auto axis = pattern::any_input();
auto split_length = INT_CONSTANT_WITH_PREDICATE(value.size() == 2 && value[0] == value[1]);
auto vsplit = pattern::wrap_type<op::v1::VariadicSplit>({source, axis, split_length});
vsplit->set_output_size(2);

// Negate
auto minus_1 = FLOAT_CONSTANT_WITH_PREDICATE(value.size() == 1 && value[0] == -1);
auto neg = pattern::wrap_type<op::v1::Multiply>({vsplit->output(1), minus_1});

// Concat two splitted parts in the opposite order, first of them is negated
auto concat = pattern::wrap_type<op::v0::Concat>({neg, vsplit->output(0)}); // make sure axis eq to vsplit eq -1

// END: rotate half

auto mul_sin = pattern::wrap_type<op::v1::Multiply>({concat, sin});
auto mul_cos = pattern::wrap_type<op::v1::Multiply>({source_1, cos});
auto add = pattern::wrap_type<op::v1::Add>({mul_cos, mul_sin});

ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
auto value_map = m.get_pattern_value_map();

auto actual_source = value_map.at(vsplit).get_node_shared_ptr()->input_value(0);
auto potential_source = value_map.at(mul_cos).get_node_shared_ptr()->input_value(0);
auto cos_output = value_map.at(mul_cos).get_node_shared_ptr()->input_value(1);

if (actual_source != potential_source && actual_source != cos_output)
return false; // flawed match
if (actual_source == potential_source && actual_source == cos_output)
return false; // flawed match
if (actual_source != potential_source && actual_source == cos_output)
cos_output = potential_source;

auto input = value_map.at(source);
auto concat_node = ov::as_type_ptr<op::v0::Concat>(value_map.at(concat).get_node_shared_ptr());
if (!concat_node)
return false;
OPENVINO_SUPPRESS_DEPRECATED_START
auto split_axis_node = ov::get_constant_from_source(value_map.at(axis));
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
OPENVINO_SUPPRESS_DEPRECATED_END
if (!split_axis_node)
return false;
auto value = split_axis_node->cast_vector<int64_t>();
if (value.size() != 1)
return false;
auto concat_axis = concat_node->get_concatenation_axis();
auto split_axis = value[0];
if (concat_axis != split_axis) {
if (input.get_partial_shape().is_static()) {
auto rank = input.get_partial_shape().rank().get_length();
concat_axis = concat_axis < 0 ? concat_axis + rank : concat_axis;
split_axis = split_axis < 0 ? split_axis + rank : split_axis;
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
}
if (concat_axis != split_axis)
return false;
}
auto rpe =
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
std::make_shared<ov::op::internal::RPE>(input, value_map.at(sin), cos_output, concat_node->get_axis());
ov::replace_output_update_name(value_map.at(add), rpe->output(0));
return true;
};
auto m = std::make_shared<pattern::Matcher>(add, matcher_name);
register_matcher(m, matcher_pass_callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"

#include <gtest/gtest.h>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/variadic_split.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"
#include "transformations/utils/utils.hpp"

using namespace std;
using namespace ov;
using namespace ov::op;

void name_node_and_output(const shared_ptr<Node>& op, const std::string& name) {
op->set_friendly_name(name);
op->output(0).set_names({name});
}

TEST_F(TransformationTestsF, FuseRPE) {
{
auto data = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(data, "source");
auto sin = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(sin, "sin");
auto cos = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(cos, "cos");
auto axis = v0::Constant::create(element::i64, {}, {-1});
auto split_lengths = v0::Constant::create(element::i64, {2}, {10, 10});
auto split = make_shared<v1::VariadicSplit>(data, axis, split_lengths);

auto minus_one = v0::Constant::create(element::f32, {}, {-1});
auto negate = make_shared<v1::Multiply>(split->output(1), minus_one);

auto concat = make_shared<v0::Concat>(OutputVector{negate, split->output(0)}, -1);

auto mul_sin = make_shared<op::v1::Multiply>(concat, sin);
auto mul_cos = make_shared<op::v1::Multiply>(data, cos);
auto add = make_shared<op::v1::Add>(mul_cos, mul_sin);
name_node_and_output(add, "rpe");

model = std::make_shared<Model>(NodeVector{add}, ParameterVector{data, sin, cos});

manager.register_pass<ov::pass::RPE_Fusion>();
}
{
auto data = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(data, "source");
auto sin = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(sin, "sin");
auto cos = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
name_node_and_output(cos, "cos");
auto rpe = make_shared<ov::op::internal::RPE>(data, sin, cos, -1);
name_node_and_output(rpe, "rpe");
model_ref = std::make_shared<Model>(NodeVector{rpe}, ParameterVector{data, sin, cos});
}
comparator.enable(FunctionsComparator::CmpValues::NAMES);
}

TEST_F(TransformationTestsF, FuseRPESorcesAreMultiOutputed) {
jane-intel marked this conversation as resolved.
Show resolved Hide resolved
{
auto data_ = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
auto sin_ = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
auto cos_ = make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());

auto data = make_shared<v1::Split>(data_, v0::Constant::create(element::i64, {}, {-1}), 2);
auto sin = make_shared<v1::Split>(sin_, v0::Constant::create(element::i64, {}, {-1}), 2)->output(1);
auto cos = make_shared<v1::Split>(cos_, v0::Constant::create(element::i64, {}, {-1}), 2)->output(1);

auto axis = v0::Constant::create(element::i64, {}, {-1});
auto split_lengths = v0::Constant::create(element::i64, {2}, {10, 10});
auto split = make_shared<v1::VariadicSplit>(data->output(0), axis, split_lengths);

auto minus_one = v0::Constant::create(element::f32, {}, {-1});
auto negate = make_shared<v1::Multiply>(split->output(1), minus_one);

auto concat = make_shared<v0::Concat>(OutputVector{negate, split->output(0)}, -1);

auto mul_sin = make_shared<op::v1::Multiply>(concat, sin);
auto mul_cos = make_shared<op::v1::Multiply>(data->output(1), cos);
auto add = make_shared<op::v1::Add>(mul_cos, mul_sin);

model = std::make_shared<Model>(NodeVector{add}, ParameterVector{data_, sin_, cos_});

manager.register_pass<ov::pass::RPE_Fusion>();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "transformations/common_optimizations/common_optimizations.hpp"
#include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp"
#include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp"
#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"
#include "transformations/control_flow/unroll_tensor_iterator.hpp"
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/op_conversions/convert_batch_to_space.hpp"
Expand Down Expand Up @@ -294,6 +295,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis

CPU_REGISTER_PASS_COMMON(manager, ov::pass::AUGRUCellFusion);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::RPE_Fusion);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::WrapInterpolateIntoTransposes);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeSinking);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConvertSequenceToTensorIterator);
Expand Down Expand Up @@ -434,6 +436,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertTopK11ToTopK3);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::HSwishDecomposition);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::RPE_Fusion); // should be disabled until CPU supports this internal op
CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition);

CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition);
Expand Down
Loading