From 3c418e197c6f2245d3933b095e1f119ab81cc326 Mon Sep 17 00:00:00 2001
From: Lyamin-Roman
Date: Wed, 24 Jan 2024 19:12:27 +0900
Subject: [PATCH] [GPU] Initial RoPE version for ChatGLM

---
 .../intel_gpu/include/intel_gpu/op/rope.hpp   |  104 ++
 .../intel_gpu/plugin/primitives_list.hpp      |    1 +
 .../include/intel_gpu/primitives/rope.hpp     |   66 +
 .../src/graph/impls/cpu/register.cpp          |    1 +
 .../src/graph/impls/cpu/register.hpp          |    2 +
 .../intel_gpu/src/graph/impls/cpu/rope.cpp    |  415 ++++++
 .../intel_gpu/src/graph/include/rope_inst.h   |   39 +
 src/plugins/intel_gpu/src/graph/rope.cpp      |   77 +
 src/plugins/intel_gpu/src/plugin/graph.cpp    |    1 +
 src/plugins/intel_gpu/src/plugin/ops/rope.cpp |   60 +
 .../src/plugin/transformations/op/rope.cpp    |   78 +
 .../plugin/transformations/rope_fusion.cpp    |  730 +++++++++
 .../plugin/transformations/rope_fusion.hpp    |   83 ++
 .../src/plugin/transformations/utils.hpp      | 1298 ++++++++++++++++
 .../src/plugin/transformations_pipeline.cpp   |    9 +-
 .../intel_gpu/tests/common/gen_pattern.hpp    | 1315 +++++++++++++++++
 .../subgraph_tests/rotary_pos_emb.cpp         |  488 ++++++
 .../unit/transformations/convert_to_rope.cpp  |  564 +++++++
 18 files changed, 5330 insertions(+), 1 deletion(-)
 create mode 100644 src/plugins/intel_gpu/include/intel_gpu/op/rope.hpp
 create mode 100644 src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp
 create mode 100644 src/plugins/intel_gpu/src/graph/impls/cpu/rope.cpp
 create mode 100644 src/plugins/intel_gpu/src/graph/include/rope_inst.h
 create mode 100644 src/plugins/intel_gpu/src/graph/rope.cpp
 create mode 100644 src/plugins/intel_gpu/src/plugin/ops/rope.cpp
 create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/op/rope.cpp
 create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/rope_fusion.cpp
 create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/rope_fusion.hpp
 create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/utils.hpp
 create mode 100644 src/plugins/intel_gpu/tests/common/gen_pattern.hpp
 create mode 100644 src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp
 create mode 100644 src/plugins/intel_gpu/tests/unit/transformations/convert_to_rope.cpp

diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/rope.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/rope.hpp
new file mode 100644
index 00000000000000..bc655e951d240f
--- /dev/null
+++ b/src/plugins/intel_gpu/include/intel_gpu/op/rope.hpp
@@ -0,0 +1,104 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/core/node.hpp"
+#include "openvino/op/op.hpp"
+
+namespace ov {
+namespace intel_gpu {
+namespace op {
+
+/**
+ * The operation performs the rotary positional embedding described in:
+ * ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING by Jianlin Su
+ *
+ * The core computation applies a 2x2 rotation matrix to each pair of input
+ * states x[i0] & x[i1] to produce the rotary-embedded pair of output states
+ * y[i0] and y[i1]:
+ *
+ * Suppose the dimension of the hidden states (of each attention head) is N and
+ * d of them are to be embedded (d <= N); the non-embedded part is copied into
+ * the output unchanged.
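+ *
+ * For example, with d == 4 the rotated pairs are (x[0], x[2]) and (x[1], x[3])
+ * in rotate-half mode, or (x[0], x[1]) and (x[2], x[3]) in interleaved mode;
+ * each pair (x[i0], x[i1]) is rotated by the angle m * xita[i]: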
+ * for i in 0...(d/2)
+ *    if (is_interleaved) {
+ *        // interleaving style of indexing
+ *        i0 = i*2
+ *        i1 = i*2 + 1
+ *    } else {
+ *        // rotate-half style of indexing
+ *        i0 = i
+ *        i1 = i + (d/2)
+ *    }
+ *    y[i0] = x[i0]*cos(m * xita[i]) - x[i1]*sin(m * xita[i])
+ *    y[i1] = x[i1]*cos(m * xita[i]) + x[i0]*sin(m * xita[i])
+ * Note: m is the token position of the current input
+ *
+ * Based on the configuration, additional preprocessing steps may be performed as well:
+ *  - slicing the last dimension of the input tensor
+ *    (when q/k/v are merged and only the q or k part is to be extracted & embedded)
+ *  - transposing the input tensor
+ *    (when q/k comes from a fully-connected layer with layout [batch, seq_len, head_cnt, head_dim]
+ *     but the output of RoPE is required to have layout [batch, head_cnt, seq_length, head_dims])
+ *  - gathering sin/cos from inputs 2 & 3 using the position index tensor passed through input 4
+ *
+ * Inputs:
+ *   1. Input hidden states tensor of type T1 - shape:
+ *      [batch, seq_length, head_cnt, head_dims] when input_trans0213 == false OR
+ *      [batch, head_cnt, seq_length, head_dims] when input_trans0213 == true
+ *   2. Pre-calculated cos(m*xita[n]) tensor of type T2 - shape [1, 1, max_position_embeddings, d].
+ *   3. Pre-calculated sin(m*xita[n]) tensor of type T2 - shape [1, 1, max_position_embeddings, d].
+ *      Input 3 is combined with input 2 when is_interleaved is true.
+ *   4. Position index tensor of type T3 - shape [batch, 1, seq_length, 1 or d] OR [batch, seq_length], optional
+ * Outputs:
+ *   1. New embedding tensor of type T1 and of shape [batch, head_cnt, seq_length, head_dims]
+ * Types:
+ *    T1 - FP32 or BF16
+ *    T2 - FP32
+ *    T3 - I32
+ */
+class RoPE : public ov::op::Op {
+public:
+    OPENVINO_OP("RoPE", "gpu_opset");
+
+    RoPE() = default;
+
+    struct Config {
+        size_t slice_start = 0;          // slice inner-most dimensions of input
+        size_t slice_stop = 0;
+        bool input_trans0213 = false;    // transpose input dims 1 & 2
+        bool is_interleaved = false;     // interleaved mode, implies trans0213 happens after RoPE
+        size_t rotary_ndims = 0;         // dimensions to be embedded (d in the description)
+        bool is_chatglm = false;         // ChatGLM-specific mode, overrides the other settings
+        bool is_qwen = false;            // Qwen-specific mode, overrides the other settings
+        size_t head_cnt = 0;
+        size_t head_size = 0;
+        int gather_position_arg_id = 0;  // arg id of the position tensor;
+                                         // == 3 when gathering from sin/cos inputs by position is required
+    };
+
+    RoPE(const OutputVector& args, const Config& cfg);
+
+    bool visit_attributes(ov::AttributeVisitor& visitor) override;
+
+    void validate_and_infer_types() override;
+
+    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
+
+    const Config& get_config() const {
+        return m_config;
+    }
+
+    Config& get_config() {
+        return m_config;
+    }
+
+private:
+    Config m_config;
+};
+
+}   // namespace op
+}   // namespace intel_gpu
+}   // namespace ov
diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp
index 479f300eecd67b..6293ce59c5a43a 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp
@@ -275,3 +275,4 @@ REGISTER_FACTORY(internal, RMS);
 REGISTER_FACTORY(internal, GatherCompressed);
 REGISTER_FACTORY(internal, KVCache);
 REGISTER_FACTORY(internal, ReadValue);
+REGISTER_FACTORY(internal, RoPE);
diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp 
b/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp new file mode 100644 index 00000000000000..11a243687e7628 --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp @@ -0,0 +1,66 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "primitive.hpp" +#include "intel_gpu/op/rope.hpp" + +namespace cldnn { +using RoPE = ov::intel_gpu::op::RoPE; + +/// @brief Rotary Position Embedding primitive +struct rope : public primitive_base { + CLDNN_DECLARE_PRIMITIVE(rope); + + rope() : primitive_base("", {}) {} + + /// @brief Constructs rope primitive + /// @param id This primitive id + /// @param inputs Inputs primitive id + /// @param config + rope(const primitive_id& id, + const std::vector& inputs, + const RoPE::Config& config, + const padding& output_padding = padding()) + : primitive_base(id, inputs, {output_padding}), + config(config) {} + + /// @brief + RoPE::Config config; + + size_t hash() const override { + size_t seed = primitive::hash(); + seed = hash_combine(seed, config.gather_position_arg_id); + seed = hash_combine(seed, config.head_cnt); + seed = hash_combine(seed, config.head_size); + seed = hash_combine(seed, config.input_trans0213); + seed = hash_combine(seed, config.is_chatglm); + seed = hash_combine(seed, config.is_interleaved); + seed = hash_combine(seed, config.is_qwen); + seed = hash_combine(seed, config.rotary_ndims); + seed = hash_combine(seed, config.slice_start); + seed = hash_combine(seed, config.slice_stop); + return seed; + } + + bool operator==(const primitive& rhs) const override { + if (!compare_common_params(rhs)) + return false; + + auto rhs_casted = downcast(rhs); + + return config.gather_position_arg_id == rhs_casted.config.gather_position_arg_id; //TODO + } + + void save(BinaryOutputBuffer& ob) const override { + primitive_base::save(ob); + ob << config.gather_position_arg_id; //TODO + } + + void load(BinaryInputBuffer& ib) override { + primitive_base::load(ib); + ib >> config.gather_position_arg_id; //TODO + } +}; +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/impls/cpu/register.cpp b/src/plugins/intel_gpu/src/graph/impls/cpu/register.cpp index 2e66e836faa608..bb27c964e9206f 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cpu/register.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cpu/register.cpp @@ -29,6 +29,7 @@ void register_implementations() { REGISTER_CPU(broadcast); REGISTER_CPU(tile); REGISTER_CPU(select); + REGISTER_CPU(rope); } } // namespace cpu diff --git a/src/plugins/intel_gpu/src/graph/impls/cpu/register.hpp b/src/plugins/intel_gpu/src/graph/impls/cpu/register.hpp index fb5748e6c73b6b..5b3fc7bd5b94a6 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cpu/register.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cpu/register.hpp @@ -22,6 +22,7 @@ #include "intel_gpu/primitives/broadcast.hpp" #include "intel_gpu/primitives/tile.hpp" #include "intel_gpu/primitives/select.hpp" +#include "intel_gpu/primitives/rope.hpp" namespace cldnn { namespace cpu { @@ -53,6 +54,7 @@ REGISTER_CPU(reorder); REGISTER_CPU(broadcast); REGISTER_CPU(tile); REGISTER_CPU(select); +REGISTER_CPU(rope); #undef REGISTER_CPU diff --git a/src/plugins/intel_gpu/src/graph/impls/cpu/rope.cpp b/src/plugins/intel_gpu/src/graph/impls/cpu/rope.cpp new file mode 100644 index 00000000000000..3118f55c90e54a --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/cpu/rope.cpp @@ -0,0 +1,415 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 
+// + +#include "register.hpp" +#include "rope_inst.h" +#include "implementation_map.hpp" +#include "intel_gpu/runtime/error_handler.hpp" +#include "intel_gpu/op/rope.hpp" +#include "openvino/core/parallel.hpp" + +namespace cldnn { +namespace cpu { + +using RoPE = ov::intel_gpu::op::RoPE; + +class RoPEExecutor { +public: + void execute(const RoPE::Config& config, const ov::TensorVector& inputs, const ov::TensorVector& outputs); + void selectExecutor(const RoPE::Config& config, ov::element::Type dt); + +private: + struct Executor { + virtual void execute(const RoPE::Config& config, + const ov::TensorVector& inputs, + const ov::TensorVector& outputs) = 0; + }; + + template + struct RoPEExecutorRotateHalf; + template + struct RoPEExecutorInterleaved; + template + struct RoPEExecutorChatGLM; + template + struct RoPEExecutorQwen; + + std::shared_ptr m_executor; +}; + +static ov::Tensor slice(ov::Tensor& tensor, int axis, int start, int end, int step = 1) { + ov::Shape shape = tensor.get_shape(); + ov::Shape new_shape; + + if (end > start) { + new_shape = shape; + new_shape[axis] = (end - start - 1) / step + 1; + } else { + // squeeze if end == start + for (size_t i = 0; i < shape.size(); ++i) { + if (i != static_cast(axis)) { + new_shape.emplace_back(shape[i]); + } + } + } + + auto off = start * tensor.get_strides()[axis]; // strides calc in bytes + auto* data = reinterpret_cast(tensor.data()) + off; + + ov::Tensor new_tensor(tensor.get_element_type(), new_shape, reinterpret_cast(data)); + + return new_tensor; +} + +// static ov::Tensor permute(ov::Tensor& tensor, const std::vector& order) { +// auto& orig_shape = tensor.get_shape(); +// size_t rank = orig_shape.size(); +// assert(order.size() == rank); + +// ov::Shape new_shape; +// for (size_t i = 0; i < rank; i++) { +// new_shape.emplace_back(orig_shape[order[i]]); +// } +// tensor.set_shape(new_shape); +// return tensor; +// // return ov::Tensor(tensor.get_element_type(), new_shape, tensor.data()); +// } + +template +DT& get_data(const ov::Tensor& tensor, const std::initializer_list& index, bool allow_broadcast = false, ov::Strides old_strides = {}) { + const auto& shape = tensor.get_shape(); + // const auto& strides = tensor.get_strides(); + if (old_strides.empty()) { + old_strides = tensor.get_strides(); + } + size_t off = 0; + auto it = index.begin(); + for (size_t i = 0; i < shape.size(); ++i) { + size_t coordinate = (it != index.end()) ? 
(*it++) : 0; + if (allow_broadcast && shape[i] == 1) { + // allow_broadcast only works when the dimension is really 1 + coordinate = 0; + } else { + assert(coordinate < shape[i]); + } + off += old_strides[i] * coordinate; + } + return (reinterpret_cast(reinterpret_cast(tensor.data()) + off))[0]; +} + +template +struct RoPEExecutor::RoPEExecutorRotateHalf : public RoPEExecutor::Executor { + void execute(const RoPE::Config& config, + const ov::TensorVector& inputs, + const ov::TensorVector& outputs) override { + auto t_src = inputs[0]; + auto& t_cos = inputs[1]; + auto& t_sin = inputs[2]; + auto& t_dst = outputs[0]; + const ov::Tensor* gather = nullptr; + + if (config.slice_stop - config.slice_start > 0) { + t_src = slice(t_src, 3, config.slice_start, config.slice_stop); + } + // if (config.input_trans0213) { + // t_src = permute(t_src, {0, 2, 1, 3}); + // } + + if (config.gather_position_arg_id > 0) { + gather = &inputs[config.gather_position_arg_id]; + } + + auto batch_size = t_src.get_shape()[0]; + auto head_cnt = t_src.get_shape()[1]; + auto seq_len = t_src.get_shape()[2]; + auto feature_size = t_src.get_shape()[3]; + + auto rotary_dims = config.rotary_ndims; + auto half_rotary_dims = rotary_dims / 2; + + ov::parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) { + auto cos_pos = p; + if (gather != nullptr) { + if (gather->get_shape().size() == 4) + cos_pos = get_data(*gather, {b, h, p, 0}, true); + else + cos_pos = get_data(*gather, {b, p}, true); + } + T* src = &get_data(t_src, {b, h, p, 0}); + float* cos = &get_data(t_cos, {b, h, cos_pos, 0}, true); + float* sin = &get_data(t_sin, {b, h, cos_pos, 0}, true); + T* dst = &get_data(t_dst, {b, h, p, 0}); + + size_t i = 0; + for (; i < half_rotary_dims; ++i) { + dst[i] = cos[i] * src[i] + sin[i] * (-src[i + half_rotary_dims]); + } + for (; i < rotary_dims; ++i) { + dst[i] = cos[i] * src[i] + sin[i] * (src[i - half_rotary_dims]); + } + for (; i < feature_size; ++i) { + dst[i] = src[i]; + } + }); + } +}; + +template +struct RoPEExecutor::RoPEExecutorInterleaved : public RoPEExecutor::Executor { + void execute(const RoPE::Config& config, + const ov::TensorVector& inputs, + const ov::TensorVector& outputs) override { + auto t_src(inputs[0]); + auto t_sin_cos(inputs[1]); + auto t_dst(outputs[0]); + + auto batch_size = t_src.get_shape()[0]; + auto seq_len = t_src.get_shape()[1]; + auto head_cnt = t_src.get_shape()[2]; + auto head_dims = t_src.get_shape()[3]; + + auto rotary_dims = config.rotary_ndims; + auto half_rotary_dims = rotary_dims / 2; + ov::parallel_for3d(batch_size, seq_len, head_cnt, [&](size_t b, size_t p, size_t h) { + T* x = &get_data(t_src, {b, p, h, 0}); + float* sin = &get_data(t_sin_cos, {b, p, 0}, true); + float* cos = &get_data(t_sin_cos, {b, p, half_rotary_dims}, true); + T* dst = &get_data(t_dst, {b, h, p, 0}); + + size_t i = 0; + for (size_t j = 0; i < rotary_dims; i += 2, j++) { + dst[i] = cos[j] * x[i] - sin[j] * x[i + 1]; + dst[i + 1] = cos[j] * x[i + 1] + sin[j] * x[i]; + } + for (; i < head_dims; i++) { + dst[i] = x[i]; + } + }); + } +}; + +template +struct RoPEExecutor::RoPEExecutorChatGLM : public RoPEExecutor::Executor { + void execute(const RoPE::Config& config, + const ov::TensorVector& inputs, + const ov::TensorVector& outputs) override { + auto t_src(inputs[0]); + auto t_cos_sin(inputs[1]); + auto t_dst(outputs[0]); + + auto old_strides = t_src.get_strides(); + + // [seq_len, batch_size, (hidden_states_q + hidden_states_k + hidden_states_v)] + if (config.slice_stop - config.slice_start 
> 0) { + t_src = slice(t_src, 2, config.slice_start, config.slice_stop); + } + + auto seq_len = t_src.get_shape()[0]; + auto batch_size = t_src.get_shape()[1]; + + auto head_cnt = config.head_cnt; + auto head_size = config.head_size; + + auto rotary_dims = config.rotary_ndims; + + ov::parallel_for3d(seq_len, batch_size, head_cnt, [&](size_t p, size_t b, size_t h) { + T* src = &get_data(t_src, {p, b, h * head_size}, false, old_strides); + // [length, batch_size, ndims//2, 2] + T* cos_sin = &get_data(t_cos_sin, {p, b, 0, 0}, true); + T* dst = &get_data(t_dst, {p, b, h, 0}); + + size_t i = 0; + for (; i < rotary_dims; i += 2) { + auto cosv = cos_sin[i]; + auto sinv = cos_sin[i + 1]; + dst[i] = cosv * src[i] - sinv * src[i + 1]; + dst[i + 1] = sinv * src[i] + cosv * src[i + 1]; + } + for (; i < head_size; i++) { + dst[i] = src[i]; + } + }); + } +}; + +template +struct RoPEExecutor::RoPEExecutorQwen : public RoPEExecutor::Executor { + void execute(const RoPE::Config& config, + const ov::TensorVector& inputs, + const ov::TensorVector& outputs) override { + auto t_src(inputs[0]); // [batch, length, head_cnt*head_size * 3] + auto t_cos(inputs[1]); // [1, present-kv-length, 1, rotary_dims] + auto t_sin(inputs[2]); // [1, present-kv-length, 1, rotary_dims] + auto t_dst(outputs[0]); // [batch, length, head_cnt, head_size]> + + if (config.slice_stop - config.slice_start > 0) { + t_src = slice(t_src, 2, config.slice_start, config.slice_stop); + } + + auto batch_size = t_src.get_shape()[0]; + auto seq_len = t_src.get_shape()[1]; + auto head_cnt = config.head_cnt; + auto head_size = config.head_size; + auto present_kv_len = t_cos.get_shape()[1]; + + auto rotary_dims = t_cos.get_shape()[3]; + auto half_rotary_dims = rotary_dims / 2; + + ov::parallel_for3d(batch_size, seq_len, head_cnt, [&](size_t b, size_t p, size_t h) { + T* src = &get_data(t_src, {b, p, h * head_size}); + float* cos = &get_data(t_cos, {b, present_kv_len - seq_len + p, h, 0}, true); + float* sin = &get_data(t_sin, {b, present_kv_len - seq_len + p, h, 0}, true); + T* dst = &get_data(t_dst, {b, p, h, 0}); + + size_t i = 0; + for (; i < half_rotary_dims; i++) { + dst[i] = cos[i] * src[i] + sin[i] * (-src[i + half_rotary_dims]); + } + for (; i < rotary_dims; i++) { + dst[i] = cos[i] * src[i] + sin[i] * (src[i - half_rotary_dims]); + } + for (; i < head_size; i++) { + dst[i] = src[i]; + } + }); + } +}; + +void RoPEExecutor::selectExecutor(const RoPE::Config& config, ov::element::Type data_type) { + if (config.is_qwen) { + if (data_type == ov::element::f16) { + m_executor = std::make_shared>(); + } else { + m_executor = std::make_shared>(); + } + } else if (config.is_chatglm) { + if (data_type == ov::element::f16) { + m_executor = std::make_shared>(); + } else { + m_executor = std::make_shared>(); + } + } else if (config.is_interleaved) { + OPENVINO_ASSERT(config.input_trans0213 == false); + OPENVINO_ASSERT(config.slice_start == 0); + OPENVINO_ASSERT(config.slice_stop == 0); + OPENVINO_ASSERT(config.gather_position_arg_id == 0); + if (data_type == ov::element::f16) { + m_executor = std::make_shared>(); + } else { + m_executor = std::make_shared>(); + } + } else { + if (data_type == ov::element::f16) { + m_executor = std::make_shared>(); + } else { + m_executor = std::make_shared>(); + } + } +} + +void RoPEExecutor::execute(const RoPE::Config& config, + const ov::TensorVector& inputs, + const ov::TensorVector& outputs) { + OPENVINO_ASSERT(m_executor != nullptr); + m_executor->execute(config, inputs, outputs); +} + +struct rope_impl : public 
typed_primitive_impl { + using parent = typed_primitive_impl; + using parent::parent; + + DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::cpu::rope_impl) + + std::unique_ptr clone() const override { + return make_unique(*this); + } + + rope_impl() : parent("rope_cpu_impl") {} + + // void save(BinaryOutputBuffer& ob) const override { + // parent::save(ob); + // // ob << make_data(); + // } + + // void load(BinaryInputBuffer& ib) override { + // parent::load(ib); + // // ib >> make_data(&, sizeof(ov::op::)); + // } + + event::ptr execute_impl(const std::vector& events, rope_inst& instance) override { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, "rope::execute_impl"); + + for (auto e : events) { + e->wait(); + } + + auto& stream = instance.get_network().get_stream(); + auto ev = stream.create_user_event(false); + + auto params = instance.get_impl_params(); + const auto& primitive = params->typed_desc(); + const auto& config = primitive->config; + + ov::TensorVector input_host_tensors; + ov::TensorVector output_host_tensors; + + std::vector input_mem_ptrs; + for (size_t i = 0; i < instance.dependencies().size(); ++i) { + input_mem_ptrs.push_back(instance.dep_memory_ptr(i)); + } + + for (size_t i = 0; i < input_mem_ptrs.size(); ++i) { + void* mem_ptr = input_mem_ptrs[i]->lock(stream, mem_lock_type::read); + input_host_tensors.push_back(make_tensor(params->input_layouts[i], mem_ptr)); + } + + auto output_mem_ptr = instance.output_memory_ptr(); + + cldnn::mem_lock output_lock(output_mem_ptr, stream); + output_host_tensors.push_back(make_tensor(params->output_layouts[0], output_lock.data())); + + RoPEExecutor rope; + rope.selectExecutor(config, params->get_input_layout().data_type); + rope.execute(config, input_host_tensors, output_host_tensors); + + for (size_t i = 0; i < input_mem_ptrs.size(); ++i) { + input_mem_ptrs[i]->unlock(stream); + } + + ev->set(); + + return ev; + } + + void init_kernels(const kernels_cache& , const kernel_impl_params&) override {} + + void update_dispatch_data(const kernel_impl_params& impl_param) override {} + +public: + static std::unique_ptr create(const rope_node& arg, const kernel_impl_params& impl_param) { + return make_unique(); + } +}; + +namespace detail { + +attach_rope_impl::attach_rope_impl() { + auto formats = { + format::bfyx, + }; + + auto types = { + data_types::f32, + data_types::f16, + }; + + implementation_map::add(impl_types::cpu, shape_types::static_shape, rope_impl::create, types, formats); + implementation_map::add(impl_types::cpu, shape_types::dynamic_shape, rope_impl::create, types, formats); +} + +} // namespace detail +} // namespace cpu +} // namespace cldnn + +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::cpu::rope_impl) diff --git a/src/plugins/intel_gpu/src/graph/include/rope_inst.h b/src/plugins/intel_gpu/src/graph/include/rope_inst.h new file mode 100644 index 00000000000000..669032ac6965f9 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/include/rope_inst.h @@ -0,0 +1,39 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "intel_gpu/primitives/rope.hpp" +#include "primitive_inst.h" + +#include + +namespace cldnn { +template <> +struct typed_program_node : public typed_program_node_base { + using parent = typed_program_node_base; + +public: + using parent::parent; + + program_node& input(size_t idx = 0) const { return get_dependency(idx); } + std::vector get_shape_infer_dependencies() const override { return {}; } +}; + +using rope_node = typed_program_node; + 
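+
+// Shape inference for rope reads no constant inputs (note the empty
+// get_shape_infer_dependencies() above), so output layouts are derived purely
+// from the input layouts and the RoPE::Config (see calc_output_layouts()).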
+template <> +class typed_primitive_inst : public typed_primitive_inst_base { + using parent = typed_primitive_inst_base; + using parent::parent; + +public: + template + static std::vector calc_output_layouts(const rope_node& /*node*/, const kernel_impl_params& impl_param); + static layout calc_output_layout(rope_node const& node, kernel_impl_params const& impl_param); + static std::string to_string(rope_node const& node); +}; + +using rope_inst = typed_primitive_inst; +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/rope.cpp b/src/plugins/intel_gpu/src/graph/rope.cpp new file mode 100644 index 00000000000000..d9aefa90d48293 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/rope.cpp @@ -0,0 +1,77 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "rope_inst.h" + +#include "primitive_type_base.h" +#include "json_object.h" +#include + +namespace cldnn { +GPU_DEFINE_PRIMITIVE_TYPE_ID(rope) + +layout rope_inst::calc_output_layout(rope_node const& node, kernel_impl_params const& impl_param) { + return calc_output_layouts(node, impl_param)[0]; +} + +template +std::vector rope_inst::calc_output_layouts(rope_node const& node, kernel_impl_params const& impl_param) { + auto desc = impl_param.typed_desc(); + + auto input0_layout = impl_param.get_input_layout(0); + auto input_pshape = input0_layout.get(); + auto output_format = input0_layout.format; + + auto output_type = desc->output_data_types[0].value_or(input0_layout.data_type); + if (impl_param.has_fused_primitives()) { + output_type = impl_param.get_fused_output_layout().data_type; + } + + ShapeType output_shape = input_pshape; + + if (desc->config.is_qwen) { + // Qwen specific RoPE + // input [batch_size, cur_length, (hidden_states_q + hidden_states_k + hidden_states_v)] + // output [batch_size, cur_length, head_cnt, head_size] + output_shape = {input_pshape[0], input_pshape[1], ov::Dimension(desc->config.head_cnt), ov::Dimension(desc->config.head_size)}; + } else if (desc->config.is_chatglm) { + // chatGLM specific RoPE + // input [length, batch_size, (hidden_states_q + hidden_states_k + hidden_states_v)] + // output [length, batch_size, head_cnt, hidden_states_k] + output_shape = {input_pshape[0], input_pshape[1], ov::Dimension(desc->config.head_cnt), ov::Dimension(desc->config.head_size)}; + // mb last dim another <--------------------------------------------------------------------------------------------------------------- + } else { + auto input_slice_size = desc->config.slice_stop - desc->config.slice_start; + if (input_slice_size > 0) { + output_shape[3] = input_slice_size; + } + if (desc->config.input_trans0213) { + // transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens before RoPE + std::swap(output_shape[2], output_shape[1]); + } else if (desc->config.is_interleaved) { + // transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens after RoPE + std::swap(output_shape[2], output_shape[1]); + } + } + return { layout{output_shape, output_type, output_format} }; +} + +template std::vector rope_inst::calc_output_layouts(rope_node const& node, const kernel_impl_params& impl_param); + +std::string rope_inst::to_string(rope_node const& node) { + auto desc = node.get_primitive(); + auto node_info = node.desc_to_json(); + + std::stringstream primitive_description; + + json_composite rope_info; + //rope_info.add("", ); + + node_info->add("rope info", rope_info); + node_info->dump(primitive_description); + + return primitive_description.str(); +} + +} // namespace cldnn diff --git 
a/src/plugins/intel_gpu/src/plugin/graph.cpp b/src/plugins/intel_gpu/src/plugin/graph.cpp index 0b1748b36ab76d..c01aeef3651eac 100644 --- a/src/plugins/intel_gpu/src/plugin/graph.cpp +++ b/src/plugins/intel_gpu/src/plugin/graph.cpp @@ -195,6 +195,7 @@ std::shared_ptr Graph::get_runtime_model(std::vector& op) { + validate_inputs_count(op, {3, 4}); + auto inputs = p.GetInputInfo(op); + const auto& config = op->get_config(); + + if (config.input_trans0213) { // Calculate in implementation instead + std::cout << "input_trans0213" << std::endl; + auto& input_pshape = op->get_input_partial_shape(0); + std::vector transposeOrder(input_pshape.size()); + std::iota(transposeOrder.begin(), transposeOrder.end(), 0); + std::swap(*(transposeOrder.begin() + 1), *(transposeOrder.begin() + 2)); + + auto permuteName = op->get_friendly_name() + "_trans0213"; + auto permutePrim = cldnn::permute(permuteName, + cldnn::input_info(inputs[0].pid), + transposeOrder); + p.add_primitive(*op, permutePrim); + inputs[0] = cldnn::input_info(permuteName); + } + + // if (config.is_interleaved) { + // add transpose afer RoPE + // } + + auto rope = cldnn::rope(layer_type_name_ID(op), + inputs, + config); + + p.add_primitive(*op, rope); +} + +REGISTER_FACTORY_IMPL(internal, RoPE); + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/rope.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/rope.cpp new file mode 100644 index 00000000000000..82f6cf667c07d9 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/rope.cpp @@ -0,0 +1,78 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "intel_gpu/op/rope.hpp" + +namespace ov { +namespace intel_gpu { +namespace op { + +RoPE::RoPE(const OutputVector& args, const Config& cfg) : Op(args), m_config(cfg) { + constructor_validate_and_infer_types(); +} + +std::shared_ptr RoPE::clone_with_new_inputs(const ov::OutputVector& new_args) const { + check_new_args_count(this, new_args); + return std::make_shared(new_args, m_config); +} + +void RoPE::validate_and_infer_types() { + auto input_pshape = get_input_partial_shape(0); + auto input_slice_size = m_config.slice_stop - m_config.slice_start; + + if (m_config.is_qwen) { + // Qwen specific RoPE + // input [batch_size, cur_length, (hidden_states_q + hidden_states_k + hidden_states_v)] + // output [batch_size, cur_length, head_cnt, head_size] + set_output_type( + 0, + get_input_element_type(0), + {input_pshape[0], input_pshape[1], ov::Dimension(m_config.head_cnt), ov::Dimension(m_config.head_size)}); + return; + } + + if (m_config.is_chatglm) { + // chatGLM specific RoPE + // input [length, batch_size, (hidden_states_q + hidden_states_k + hidden_states_v)] + // output [length, batch_size, head_cnt, hidden_states_k] + set_output_type( + 0, + get_input_element_type(0), + {input_pshape[0], input_pshape[1], ov::Dimension(m_config.head_cnt), ov::Dimension(m_config.head_size)}); + return; + } + + if (input_slice_size > 0) { + input_pshape[3] = input_slice_size; + } + if (m_config.input_trans0213) { + // transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens before RoPE + std::swap(input_pshape[2], input_pshape[1]); + } else if (m_config.is_interleaved) { + // transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens after RoPE + std::swap(input_pshape[2], input_pshape[1]); + } + + set_output_type(0, get_input_element_type(0), input_pshape); +} + +bool RoPE::visit_attributes(ov::AttributeVisitor& visitor) { + visitor.start_structure("config"); + 
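+    // Every Config field is (de)serialized one-to-one below so the op can
+    // round-trip through the IR; keep this list in sync with RoPE::Config.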
visitor.on_attribute("slice_start", m_config.slice_start); + visitor.on_attribute("slice_stop", m_config.slice_stop); + visitor.on_attribute("input_trans0213", m_config.input_trans0213); + visitor.on_attribute("is_interleaved", m_config.is_interleaved); + visitor.on_attribute("rotary_ndims", m_config.rotary_ndims); + visitor.on_attribute("is_chatglm", m_config.is_chatglm); + visitor.on_attribute("is_qwen", m_config.is_qwen); + visitor.on_attribute("head_cnt", m_config.head_cnt); + visitor.on_attribute("head_size", m_config.head_size); + visitor.on_attribute("gather_position_arg_id", m_config.gather_position_arg_id); + visitor.finish_structure(); + return true; +} + +} // namespace op +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/rope_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/rope_fusion.cpp new file mode 100644 index 00000000000000..d69be4a031a8d6 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/rope_fusion.cpp @@ -0,0 +1,730 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "rope_fusion.hpp" + +#include +#include +#include +#include "openvino/opsets/opset1.hpp" +#include +#include +#include +#include +#include + +#include "ov_ops/type_relaxed.hpp" + +namespace ov { +namespace intel_gpu { + +using namespace transformation_utils; +using namespace ov::pass::pattern; + +RoPEFusionGPTNEOX::RoPEFusionGPTNEOX() { + // rope pattern matching triggers a little design flaw: + // y1 = mul(x, cos) + // y2 = mul(x, sin) + // y = add(y1, y2) + // there is a chance that in 'y1' branch, pattern x is mapped to actual value of cos (mul is commutable) + // this leads to the matching failure of 'y2' branch, because cos didn't appear in that + // branch. 
+ // so here we use a WA, only match the path of rotate_hal(x)*sin and check the x*cos path + // in the callback + + auto x = transformation_utils::makePattern(ov::Rank(4)); + auto x_or_cos1 = transformation_utils::makePattern(ov::Rank(4)); + auto x_or_cos2 = transformation_utils::makePattern(ov::Rank(4)); + auto t_sin = transformation_utils::makePattern(ov::Rank(4)); + + x->set_friendly_name("x"); + + auto half_ndims = transformation_utils::Symbol("half_ndims"); + auto int32_max = std::numeric_limits::max(); + + // rotate half : [-x2, x1] + auto x2 = transformation_utils::GenSlice(x, half_ndims, int32_max, 1, 3); + auto x2neg = transformation_utils::makePattern({x2, -1.0f}, {{"auto_broadcast", "numpy"}}); + auto x1 = transformation_utils::GenSlice(x, 0, half_ndims, 1, 3); + auto x_rotate_half = transformation_utils::makePattern({x2neg, x1}, {{"axis", -1}}); + + auto mul_cos = transformation_utils::makePattern({x_or_cos1, x_or_cos2}, {{"auto_broadcast", "numpy"}}); + auto mul_sin = transformation_utils::makePattern({x_rotate_half, t_sin}, {{"auto_broadcast", "numpy"}}); + + // [x1, x2]*cos + [-x2, x1]*sin + auto result = transformation_utils::makePattern({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}}); + + matcher_pass_callback callback = [=](Matcher& m) { + // std::cout << "CALLBACK RoPEFusionGPTNEOX ENTER" << std::endl; + transformation_utils::PatternValidator validator(m); + if (!validator) { + return false; + } + // std::cout << "CALLBACK RoPEFusionGPTNEOX PASSED" << std::endl; + + const auto& pattern_map = m.get_pattern_value_map(); + auto root = m.get_match_root(); + + // check mul(x, cos) exists + Output v_cos; + if (pattern_map.at(x_or_cos1) == pattern_map.at(x)) { + v_cos = pattern_map.at(x_or_cos2); + } else if (pattern_map.at(x_or_cos2) == pattern_map.at(x)) { + v_cos = pattern_map.at(x_or_cos1); + } else { + // not a RoPE + return false; + } + + op::RoPE::Config config; + OutputVector new_args; + config.rotary_ndims = 2 * validator["half_ndims"]; + + new_args.push_back(pattern_map.at(x)); + new_args.push_back(v_cos); + new_args.push_back(pattern_map.at(t_sin)); + + auto old_node = root; + auto new_node = std::make_shared(new_args, config); + new_node->set_friendly_name(old_node->get_friendly_name()); + ov::replace_node(old_node, new_node); + + // this new node may match following additional matchers + register_new_node(new_node); + + return true; + }; + + auto m = std::make_shared(result, "RoPEFusionGPTNEOX"); + this->register_matcher(m, callback); +} + +RoPEFusionCosSinPreprocess::RoPEFusionCosSinPreprocess() { + auto cos_const = transformation_utils::makePattern({}); // "f32[1,1,2048,24]" + auto sin_const = transformation_utils::makePattern({}); // "f32[1,1,2048,24]" + + auto node_batch_size = transformation_utils::makePattern("i32[1]"); + auto tile_batch = transformation_utils::makePattern("i32[1]"); + auto gather_positions = transformation_utils::makePattern("i32[?,?,?,?]"); + + auto prepare_cos_sin_gptneox = [&](std::shared_ptr const_tab) { + auto slice1 = transformation_utils::makePattern({const_tab, {0}, node_batch_size, {1}}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + return transformation_utils::makePattern({slice1, gather_positions}, {{"axis", 2}}); + }; + + auto seq_len = transformation_utils::makePattern("i32[1]"); + auto gather_positions_2d = transformation_utils::makePattern("i32[?,?]"); + + auto head_dims = transformation_utils::Symbol("head_dims"); + auto prepare_cos_sin_llama 
= [&](std::shared_ptr const_tab) { + auto ScatterUpdate = transformation_utils::makePattern({{0, 0, 0}, 2, seq_len, 0}); + auto slice_Slice = transformation_utils::makePattern({const_tab, {0, 0, 0}, ScatterUpdate, {1, 1, 1}}, + {{"begin_mask", {1, 1, 0}}, + {"end_mask", {1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto squeeze = transformation_utils::makePattern({slice_Slice, {-1, head_dims}}); + auto index_Gather = transformation_utils::makePattern({squeeze, gather_positions_2d, 0}, {{"batch_dims", 0}}); + + // another simplified pattern for gathering at position_ids + auto slice_Slice2 = transformation_utils::makePattern({const_tab, {0}, seq_len, {1}}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto index_Gather2 = transformation_utils::makePattern({slice_Slice2, gather_positions_2d, 0}, {{"batch_dims", 0}}); + + auto unsqueeze = transformation_utils::makePattern({index_Gather | index_Gather2, {1, 1, -1, head_dims}}); + auto unsqueeze2 = transformation_utils::makePattern({index_Gather2, 1}); + + return unsqueeze2 | unsqueeze; + }; + + auto cos_tab = prepare_cos_sin_gptneox(cos_const) | prepare_cos_sin_llama(cos_const); + auto sin_tab = prepare_cos_sin_gptneox(sin_const) | prepare_cos_sin_llama(sin_const); + + auto x = transformation_utils::makePattern(ov::Rank(4)); + auto rope = transformation_utils::makePattern({x, cos_tab, sin_tab}); + + matcher_pass_callback callback = [=](Matcher& m) { + std::cout << "CALLBACK ENTER LLAMA" << std::endl; + transformation_utils::PatternValidator validator(m); + if (!validator) { + return false; + } + std::cout << "CALLBACK PASS LLAMA" << std::endl; + const auto& pattern_map = m.get_pattern_value_map(); + auto root = m.get_match_root(); + auto rope_node = as_type_ptr(pattern_map.at(rope).get_node_shared_ptr()); + if (!rope_node) + return false; + + if (pattern_map.count(cos_const)) { + rope_node->set_argument(1, pattern_map.at(cos_const)); + } + if (pattern_map.count(sin_const)) { + rope_node->set_argument(2, pattern_map.at(sin_const)); + } + + auto& config = rope_node->get_config(); + if (pattern_map.count(gather_positions)) { + auto arg_id = rope_node->get_input_size(); + rope_node->set_argument(arg_id, pattern_map.at(gather_positions)); + config.gather_position_arg_id = arg_id; + } else if (pattern_map.count(gather_positions_2d)) { + auto arg_id = rope_node->get_input_size(); + rope_node->set_argument(arg_id, pattern_map.at(gather_positions_2d)); + config.gather_position_arg_id = arg_id; + } + rope_node->validate_and_infer_types(); + register_new_node(rope_node); + return true; + }; + auto m = std::make_shared(rope, "RoPEFusionCosSinPreprocess"); + this->register_matcher(m, callback); +} + +// only a fraction of head_size is rotary-embedded +RoPEFusionIOSlicing::RoPEFusionIOSlicing() { + auto int32_max = std::numeric_limits::max(); + auto data = transformation_utils::makePattern(ov::Rank(4)); + + auto ndims = transformation_utils::Symbol("ndims"); + auto x = transformation_utils::GenSlice(data, 0, ndims, 1, 3); + auto y = transformation_utils::GenSlice(data, ndims, int32_max, 1, 3); + auto x_emb = transformation_utils::makePattern({x, {}, {}}) | transformation_utils::makePattern({x, {}, {}, {}}); + auto result = transformation_utils::makePattern({x_emb, y}, {{"axis", -1}}); + + matcher_pass_callback callback = [=](Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto root = m.get_match_root(); + + 
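+        // root is the final Concat(x_emb, y); its first input is expected to
+        // be a RoPE node already produced by an earlier matcher pass.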
auto rope_node = as_type_ptr(root->input_value(0).get_node_shared_ptr()); + if (!rope_node) + return false; + + transformation_utils::PatternValidator validator(m); + if (!validator) { + return false; + } + auto ndims = validator["ndims"]; + + auto& config = rope_node->get_config(); + if (config.rotary_ndims != ndims) + return false; + + // remove slice & concat + rope_node->set_argument(0, pattern_map.at(data)); + rope_node->set_friendly_name(root->get_friendly_name()); + ov::replace_node(root, rope_node); + + rope_node->validate_and_infer_types(); + register_new_node(rope_node); + return true; + }; + auto m = std::make_shared(result, "RoPEFusionIOSlicing"); + this->register_matcher(m, callback); +} + +RoPEFusionPreprocess::RoPEFusionPreprocess() { + // gptneox-preprocess of input data + auto input_to_slice = transformation_utils::makePattern(ov::Rank(4)); + auto input_to_trans = transformation_utils::makePattern(ov::Rank(4)); // no need to slice from 3S + + // in some model qkv prejection is combined and + // needs to be sliced before RoPE + auto slice_start = transformation_utils::Symbol("slice_start"); + auto slice_stop = transformation_utils::Symbol("slice_stop"); + auto input_slice = transformation_utils::GenSlice(input_to_slice, slice_start, slice_stop, 1, 3); + + // some model will transpose from [B,L,H,S] to [B,H,L,S] before RoPE + auto x = transformation_utils::makePattern({input_slice | input_to_trans, {0, 2, 1, 3}}); + auto result = transformation_utils::makePattern({x, {}, {}}) | transformation_utils::makePattern({x, {}, {}, {}}); + + matcher_pass_callback callback = [=](Matcher& m) { + transformation_utils::PatternValidator validator(m); + if (!validator) { + return false; + } + + const auto& pattern_map = m.get_pattern_value_map(); + auto root = m.get_match_root(); + auto rope_node = as_type_ptr(root); + if (!rope_node) + return false; + + auto& config = rope_node->get_config(); + + if (pattern_map.count(input_to_slice)) { + config.slice_start = validator["slice_start"]; + config.slice_stop = validator["slice_stop"]; + config.input_trans0213 = true; + rope_node->set_argument(0, pattern_map.at(input_to_slice)); + } else if (pattern_map.count(input_to_trans)) { + config.input_trans0213 = true; + rope_node->set_argument(0, pattern_map.at(input_to_trans)); + } else { + return false; + } + rope_node->validate_and_infer_types(); + register_new_node(rope_node); + return true; + }; + auto m = std::make_shared(result, "RoPEFusionPreprocess"); + this->register_matcher(m, callback); +} + +// remove stridedslice from 0 to int32_max with stride 1 +EliminateStridedSlice::EliminateStridedSlice() { + auto data = ov::pass::pattern::any_input(has_static_rank()); + auto begin = ov::pass::pattern::wrap_type(type_matches(ov::element::i32)); + auto end = ov::pass::pattern::wrap_type(type_matches(ov::element::i32)); + auto stride = ov::pass::pattern::wrap_type(type_matches(ov::element::i32)); + + auto strided_slice = + ov::pass::pattern::wrap_type({data, begin, end, stride}, [](const Output& value) { + auto s1 = as_type_ptr(value.get_node_shared_ptr()); + if (!s1->get_new_axis_mask().empty() || !s1->get_shrink_axis_mask().empty() || + !s1->get_ellipsis_mask().empty()) { + return false; + } + + auto inputs = s1->input_values(); + + auto begin = as_type_ptr(inputs[1].get_node_shared_ptr()); + auto end = as_type_ptr(inputs[2].get_node_shared_ptr()); + auto stride = as_type_ptr(inputs[3].get_node_shared_ptr()); + + if (!begin) + return false; + if (!end) + return false; + if (!stride) + return false; 
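+
+            // From here on, the slice is a no-op only if all strides are 1 and
+            // every begin/end pair whose mask bit is 0 (i.e. whose value is
+            // actually used) spans the full [0, INT32_MAX] range.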
+ + // stride is all 1 + auto v_stride = stride->cast_vector(); + for (auto& v : v_stride) { + if (v != 1) + return false; + } + + auto v_begin = begin->cast_vector(); + auto v_end = end->cast_vector(); + if (v_begin.size() != v_end.size()) { + return false; + } + + auto& begin_mask = s1->get_begin_mask(); + auto& end_mask = s1->get_end_mask(); + auto mask_size = begin_mask.size(); + if (begin_mask.size() != end_mask.size()) { + return false; + } + + auto int32_max = std::numeric_limits::max(); + size_t i = 0; + for (; i < mask_size; i++) { + if (begin_mask[i] != end_mask[i]) + return false; + // all valid [begin, end] are [0, int32_max] + if (begin_mask[i] == 0 && end_mask[i] == 0) { + if (v_begin[i] != 0 || v_end[i] != int32_max) + return false; + } + } + // the non-masked part + for (; i < v_begin.size(); i++) { + if (v_begin[i] != 0 || v_end[i] != int32_max) + return false; + } + return true; + }); + + matcher_pass_callback callback = [=](Matcher& m) { + auto root = m.get_match_root(); + return replace_output_update_name(root->output(0), root->input_value(0)); + }; + + auto m = std::make_shared(strided_slice, "EliminateStridedSlice"); + this->register_matcher(m, callback); +} + +RoPEFusionGPTJ::RoPEFusionGPTJ() { + auto int32_max = std::numeric_limits::max(); + auto ndims = transformation_utils::Symbol("ndims"); + + auto view_Reshape = transformation_utils::makePattern(ov::Rank(4)); + + // view_Reshape : B,L,H,S + auto slice_Slice_965 = transformation_utils::GenSlice(view_Reshape, 0, ndims, 1, 3); + + auto gather_sin_cos = transformation_utils::makePattern("f32"); + + auto varsplit = transformation_utils::makePattern({gather_sin_cos, -1, {ndims / 2, -1}}); + varsplit->set_output_size(2); + auto unsqueeze_sin = transformation_utils::makePattern({varsplit->output(0), {1, -1, 1, 32}}); + auto unsqueeze_cos = transformation_utils::makePattern({varsplit->output(1), {1, -1, 1, 32}}); + // repeate cos/sin table + auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) { + const auto& vec = node.get_vector(); + int32_t v = 0; + for (size_t i = 0; i < vec.size(); i += 2, v++) { + if (vec[i] != v || vec[i + 1] != v) + return false; + } + return true; + }); + auto repeat_interleave_sin = transformation_utils::makePattern({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}}); + auto repeat_interleave_cos = transformation_utils::makePattern({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}}); + + auto t_cos = transformation_utils::makePattern(ov::Rank(4)); + auto t_sin = transformation_utils::makePattern(ov::Rank(4)); + + // x interleave (-x[:,:,:, 1::2], x[:,:,:, 0::2]) + auto slice_Slice_1174 = transformation_utils::GenSlice(slice_Slice_965, 1, int32_max, 2, 3); + + auto neg_Multiply_1177 = transformation_utils::makePattern({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_65524 = transformation_utils::makePattern({neg_Multiply_1177, -1}); + + auto slice_Slice_1168 = transformation_utils::GenSlice(slice_Slice_965, 0, int32_max, 2, 3); + auto Unsqueeze_65525 = transformation_utils::makePattern({slice_Slice_1168, -1}); + auto stack_1182 = transformation_utils::makePattern({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}}); + + auto ShapeOf_169068 = transformation_utils::makePattern({stack_1182}); + auto flatten_Slice_1194 = transformation_utils::GenSlice(ShapeOf_169068, 0, 3, 1, 0); + auto flatten_Concat_1197 = transformation_utils::makePattern({flatten_Slice_1194, {-1}}, {{"axis", 0}}); + auto flatten_Reshape_1198 = 
transformation_utils::makePattern({stack_1182, flatten_Concat_1197}); + + // x*cos [B,L,H,ndims] + auto mul_cos = + transformation_utils::makePattern({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}}); + auto mul_sin = + transformation_utils::makePattern({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}}); + + // *cos + *sin + auto rotary_emb = transformation_utils::makePattern({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}}); + + auto slice_Slice_971 = transformation_utils::GenSlice(view_Reshape, ndims, int32_max, 1, 3); + auto cat_Concat_1211 = transformation_utils::makePattern({rotary_emb, slice_Slice_971}, {{"axis", -1}}); + auto permute_Transpose_1213 = transformation_utils::makePattern({cat_Concat_1211, {0, 2, 1, 3}}); + + auto result = permute_Transpose_1213; + + matcher_pass_callback callback = [=](Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto root = m.get_match_root(); + transformation_utils::PatternValidator validator(m); + if (!validator) { + return false; + } + + op::RoPE::Config config; + OutputVector new_args; + config.rotary_ndims = validator["ndims"]; + + config.is_interleaved = true; + + // input is [B,L,H,S] + new_args.push_back(pattern_map.at(view_Reshape)); + // sin_cos table (gathered with positions) [1, L, 64] + new_args.push_back(pattern_map.at(gather_sin_cos)); + new_args.push_back(pattern_map.at(gather_sin_cos)); + + auto old_node = root; + + auto new_node = std::make_shared(new_args, config); + new_node->set_friendly_name(old_node->get_friendly_name()); + ov::replace_node(old_node, new_node); + return true; + }; + + auto m = std::make_shared(result, "RoPEFusionGPTJ"); + this->register_matcher(m, callback); +} + +RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) { + // std::cout << "RoPEFusionChatGLM ENTER" << std::endl; + auto qkv_linear = transformation_utils::makePattern("[?,?,?]"); // f32[seq_length, batch_size, 4608] + auto seq_length = transformation_utils::makePattern("i32[1]"); + auto cos_sin_cache = transformation_utils::makePattern("[?,?,?,?]"); // [max_pos_embeddings, batch_size, 32, 2] + + auto ndims = transformation_utils::Symbol("ndims"); //64 + auto head_cnt = transformation_utils::Symbol("head_cnt"); //32 + auto head_size = transformation_utils::Symbol("head_size"); //128 + auto total_size_q = transformation_utils::Symbol("total_size_q"); // 4096 + auto total_size_k = transformation_utils::Symbol("total_size_k"); // 256 + auto total_size_v = transformation_utils::Symbol("total_size_v"); // 256 + + auto qkv_proj = transformation_utils::makePattern({ /*qkv_linear*/qkv_linear, -1, {total_size_q, total_size_k, total_size_v}}); + qkv_proj->set_output_size(3); + + // get key [L, B, Hkv, S] + auto cur_key = transformation_utils::makePattern({qkv_proj->output(split_output_id), {0, 0, head_cnt, head_size}}, + {{"special_zero", true}}); + + auto slice_Slice_437 = transformation_utils::makePattern({cur_key, {0, 0, 0, 0}, {0, 0, 0, ndims}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + + // rotate half + auto ListConstruct_452_Concat = + transformation_utils::makePattern({seq_length, {-1}, {head_cnt}, {ndims / 2}, {2}}, {{"axis", 0}}); + auto ListConstruct_379_Concat = + transformation_utils::makePattern({seq_length, {-1}, {1}, {ndims / 2}, {2}}, {{"axis", 0}}); + + auto reshape_Reshape_453 = + transformation_utils::makePattern({slice_Slice_437, 
ListConstruct_452_Concat}, {{"special_zero", false}}); + auto x_even = transformation_utils::makePattern({reshape_Reshape_453, 0, -1}, {{"batch_dims", 0}}); + auto slice_Slice_449 = transformation_utils::makePattern({/*cos_sin_cache*/cos_sin_cache, {0}, seq_length, {1}}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto view_Reshape_460 = + transformation_utils::makePattern({slice_Slice_449, ListConstruct_379_Concat}, {{"special_zero", false}}); + auto cos_tab = transformation_utils::makePattern({view_Reshape_460, 0, -1}, {{"batch_dims", 0}}); + auto x_even_cos = transformation_utils::makePattern({x_even, cos_tab}, {{"auto_broadcast", "numpy"}}); + auto x_odd = transformation_utils::makePattern({reshape_Reshape_453, 1, -1}, {{"batch_dims", 0}}); + auto sin_tab = transformation_utils::makePattern({view_Reshape_460, 1, -1}, {{"batch_dims", 0}}); + auto x_odd_sin = transformation_utils::makePattern({x_odd, sin_tab}, {{"auto_broadcast", "numpy"}}); + auto neg_x_odd_sin = transformation_utils::makePattern({x_odd_sin, -1.000000f}, {{"auto_broadcast", "numpy"}}); + auto sub_Subtract_469 = transformation_utils::makePattern({x_even_cos, neg_x_odd_sin}, {{"auto_broadcast", "numpy"}}); + + auto y_even = transformation_utils::makePattern({sub_Subtract_469, -1}); + auto x_odd_cos = transformation_utils::makePattern({x_odd, cos_tab}, {{"auto_broadcast", "numpy"}}); + auto x_even_sin = transformation_utils::makePattern({x_even, sin_tab}, {{"auto_broadcast", "numpy"}}); + auto add_Add_476 = transformation_utils::makePattern({x_odd_cos, x_even_sin}, {{"auto_broadcast", "numpy"}}); + auto y_odd = transformation_utils::makePattern({add_Add_476, -1}); + + auto stack_481 = transformation_utils::makePattern({y_even, y_odd}, {{"axis", -1}}); + + auto ShapeOf_135133 = transformation_utils::makePattern({stack_481}); + auto flatten_Slice_497 = transformation_utils::makePattern({ShapeOf_135133, {0}, {3}, {1}}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto flatten_Concat_500 = transformation_utils::makePattern({flatten_Slice_497, {-1}}, {{"axis", 0}}); + auto const_target_shape = makeConst({0, 0, head_cnt, ndims}); + // [length, batch, head_cnt, half_rotary_dims, 2] + auto flatten_Reshape_501 = + transformation_utils::makePattern({stack_481, flatten_Concat_500 | const_target_shape}, {{"special_zero", true}}); + auto slice_Slice_443 = + transformation_utils::makePattern({cur_key, {0, 0, 0, ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat_505 = transformation_utils::makePattern({flatten_Reshape_501, slice_Slice_443}, {{"axis", -1}}); + + auto result = cat_Concat_505; + + matcher_pass_callback callback = [=](Matcher& m) { + // std::cout << "ChatGLM callback" << std::endl; + const auto& pattern_map = m.get_pattern_value_map(); + auto root = m.get_match_root(); + transformation_utils::PatternValidator validator(m); + // if (!validator) { + // return false; + // } + // std::cout << "ChatGLM callback PASSED" << std::endl; + + op::RoPE::Config config; + OutputVector new_args; + config.rotary_ndims = 64;//validator["ndims"]; + config.is_chatglm = true; + config.head_cnt = split_output_id == 0 ? 
32 : 2;//validator["head_cnt"]; + config.head_size = 128;//validator["head_size"]; + + if (split_output_id == 0) { + // query : split_output_id == 0 + config.slice_start = 0; + config.slice_stop = 4096;//validator["total_size_q"]; + } else { + // key : split_output_id == 1 + config.slice_start = 4096;//validator["total_size_q"]; + config.slice_stop = config.slice_start + 256;//validator["total_size_k"]; + } + + // std::cout << config.rotary_ndims << " | " + // << config.head_cnt << " | " + // << config.head_size << " | " + // << config.slice_start << " | " + // << config.slice_stop << std::endl; + + new_args.push_back(pattern_map.at(qkv_linear)); + new_args.push_back(pattern_map.at(cos_sin_cache)); + new_args.push_back(pattern_map.at(cos_sin_cache)); + + auto old_node = root; + + auto new_node = std::make_shared(new_args, config); + new_node->set_friendly_name(old_node->get_friendly_name()); + ov::replace_node(old_node, new_node); + return true; + }; + + auto m = std::make_shared(result, "RoPEFusionChatGLM"); + this->register_matcher(m, callback); +} + +RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { + // std::cout << "RoPEFusionQwen ENTER" << std::endl; + // rotary_emb_cos & rotary_emb_sin are sliced by present kv-length (past-kv-length + cur_len) + auto rotary_emb_cos = transformation_utils::makePattern("f32[1,?,1,?]"); // [1,..4096,1,128] + auto rotary_emb_sin = transformation_utils::makePattern("f32[1,?,1,?]"); // [1,..4096,1,128] + auto qkv_proj = transformation_utils::makePattern("f32[?,?,?]"); // f32[?,?,12288] + + auto head_cnt = transformation_utils::Symbol("head_cnt"); + auto head_size = transformation_utils::Symbol("head_size"); + + auto ListUnpack_410_VariadicSplit = + transformation_utils::makePattern({qkv_proj, 2, {head_cnt * head_size, head_cnt * head_size, -1}}); + ListUnpack_410_VariadicSplit->set_output_size(3); + // B,L,H,S + auto view_Reshape_424 = transformation_utils::makePattern( + {ListUnpack_410_VariadicSplit->output(split_output_id), {0, 0, head_cnt, head_size}}, + {{"special_zero", true}}); + auto slice_Slice_543 = + transformation_utils::makePattern({view_Reshape_424, {0, 0, 0, 0}, {0, 0, 0, head_size}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); // tensor_array + + auto hidden_states = any_input();//transformation_utils::makePattern("f32[?,?,?]"); // + auto ShapeOf_485735 = transformation_utils::makePattern({hidden_states}, {}); + auto Multiply_567524 = transformation_utils::makePattern({ShapeOf_485735, {-1}}, {{"auto_broadcast", "numpy"}}); + auto Gather_377635 = transformation_utils::makePattern({Multiply_567524, {1}, 0}, {{"batch_dims", 0}}); + + auto input_ids = any_input();//transformation_utils::makePattern("i32[?,?]"); // [batch, length] + auto ShapeOf_409241 = transformation_utils::makePattern({input_ids}, {}); + auto Gather_311651 = transformation_utils::makePattern({ShapeOf_409241, {1}, 0}, {{"batch_dims", 0}}); + auto neg_Multiply = transformation_utils::makePattern({Gather_311651, {-1}}, {{"auto_broadcast", "numpy"}}); + + auto ScatterUpdate_463814 = makePattern({{0, 0}, {1}, Gather_377635 | neg_Multiply, {0}}); + + auto slice_Slice_446 = + transformation_utils::makePattern({rotary_emb_cos, ScatterUpdate_463814, {0, INT_MAX}, {1, 1}}, + {{"begin_mask", {1, 0}}, + {"end_mask", {1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); // tensor_array + auto mul_Multiply_552 = + 
+        transformation_utils::makePattern({slice_Slice_543, slice_Slice_446},
+                                          {{"auto_broadcast", "numpy"}});  // tensor_array
+
+    auto reshape_opt1 = [&](std::shared_ptr<Node> input_BLHS) {
+        auto ShapeOf_485814 = transformation_utils::makePattern({input_BLHS}, {});
+        auto Gather_377647 = transformation_utils::makePattern({ShapeOf_485814, {1}, 0}, {{"batch_dims", 0}});
+        // batch-size, we don't care
+        auto Gather_377641 = transformation_utils::makePattern("i32[1]");
+        auto ListConstruct_581_Concat =
+            transformation_utils::makePattern({Gather_377641, Gather_377647, {head_cnt}, {2}, {head_size / 2}},
+                                              {{"axis", 0}});
+        auto Gather_391791 = transformation_utils::makePattern({ShapeOf_485814, {0, 1}, 0}, {{"batch_dims", 0}});
+        auto ListConstruct_522_Concat =
+            transformation_utils::makePattern({Gather_391791, {32}, {2}, {64}}, {{"axis", 0}});
+
+        auto reshape_Reshape_577 =
+            transformation_utils::makePattern({input_BLHS, {-1, 2, head_size / 2}}, {{"special_zero", true}});
+        return transformation_utils::makePattern({reshape_Reshape_577, ListConstruct_581_Concat | ListConstruct_522_Concat},
+                                                 {{"special_zero", false}});  // tensor_array
+    };
+
+    auto reshape_opt2 = [&](std::shared_ptr<Node> input_BLHS) {
+        return transformation_utils::makePattern({input_BLHS, {0, 0, 0, 2, head_size / 2}},
+                                                 {{"special_zero", true}});  // tensor_array
+    };
+
+    auto ListUnpack_586_Split =
+        transformation_utils::makePattern({reshape_opt1(slice_Slice_543) | reshape_opt2(slice_Slice_543), -2},
+                                          {{"num_splits", 2}});  // tensor_array
+    ListUnpack_586_Split->set_output_size(2);
+
+    auto ListUnpack_586_Squeeze_0 =
+        transformation_utils::makePattern({ListUnpack_586_Split->output(1), -2});  // tensor_array
+
+    auto Multiply_567527 =
+        transformation_utils::makePattern({ListUnpack_586_Squeeze_0, -1.000000f},
+                                          {{"auto_broadcast", "numpy"}});  // tensor_array
+
+    auto ListUnpack_586_Squeeze =
+        transformation_utils::makePattern({ListUnpack_586_Split->output(0), -2});  // tensor_array
+    auto cat_Concat_593 =
+        transformation_utils::makePattern({Multiply_567527, ListUnpack_586_Squeeze}, {{"axis", -1}});  // tensor_array
+    auto slice_Slice_470 =
+        transformation_utils::makePattern({rotary_emb_sin, ScatterUpdate_463814, {0, INT_MAX}, {1, 1}},
+                                          {{"begin_mask", {1, 0}},
+                                           {"end_mask", {1, 0}},
+                                           {"new_axis_mask", {}},
+                                           {"shrink_axis_mask", {}},
+                                           {"ellipsis_mask", {}}});  // tensor_array
+    auto mul_Multiply_594 =
+        transformation_utils::makePattern({cat_Concat_593, slice_Slice_470},
+                                          {{"auto_broadcast", "numpy"}});  // tensor_array
+    auto add_Add_597 =
+        transformation_utils::makePattern({mul_Multiply_552, mul_Multiply_594},
+                                          {{"auto_broadcast", "numpy"}});  // tensor_array
+
+    auto result = add_Add_597;
+
+    matcher_pass_callback callback = [=](Matcher& m) {
+        const auto& pattern_map = m.get_pattern_value_map();
+        auto root = m.get_match_root();
+        transformation_utils::PatternValidator validator(m);
+        if (!validator) {
+            return false;
+        }
+
+        op::RoPE::Config config;
+        OutputVector new_args;
+        config.is_qwen = true;
+        config.head_cnt = validator["head_cnt"];
+        config.head_size = validator["head_size"];
+        config.rotary_ndims = config.head_size;
+
+        if (split_output_id == 0) {
+            // query : split_output_id == 0
+            config.slice_start = 0;
+            config.slice_stop = config.head_cnt * config.head_size;
+        } else {
+            // key : split_output_id == 1
+            config.slice_start = config.head_cnt * config.head_size;
+            config.slice_stop = config.slice_start + config.head_cnt *
config.head_size; + } + + new_args.push_back(pattern_map.at(qkv_proj)); + new_args.push_back(pattern_map.at(rotary_emb_cos)); + new_args.push_back(pattern_map.at(rotary_emb_sin)); + + auto old_node = root; + auto new_node = std::make_shared(new_args, config); + new_node->set_friendly_name(old_node->get_friendly_name()); + ov::replace_node(old_node, new_node); + return true; + }; + + auto m = std::make_shared(result, "RoPEFusionQwen"); + this->register_matcher(m, callback); +} + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/rope_fusion.hpp b/src/plugins/intel_gpu/src/plugin/transformations/rope_fusion.hpp new file mode 100644 index 00000000000000..08cf6f6422cf08 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/rope_fusion.hpp @@ -0,0 +1,83 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "intel_gpu/op/rope.hpp" +#include "utils.hpp" + +namespace ov { +namespace intel_gpu { + +class RoPEFusionGPTNEOX : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionGPTNEOX", "0"); + RoPEFusionGPTNEOX(); +}; + +class RoPEFusionGPTJ : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionGPTJ", "0"); + RoPEFusionGPTJ(); +}; + +class RoPEFusionChatGLM : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionChatGLM", "0"); + RoPEFusionChatGLM(int split_output_id); +}; + +class RoPEFusionQwen : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionQwen", "0"); + RoPEFusionQwen(int split_output_id); +}; + +class RoPEFusionIOSlicing : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionIOSlicing", "0"); + RoPEFusionIOSlicing(); +}; + +class RoPEFusionPreprocess : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionPreprocess", "0"); + RoPEFusionPreprocess(); +}; + +class RoPEFusionCosSinPreprocess : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionCosSinPreprocess", "0"); + RoPEFusionCosSinPreprocess(); +}; + +class EliminateStridedSlice : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("EliminateStridedSlice", "0"); + EliminateStridedSlice(); +}; + +class RoPEFusion : public ov::pass::GraphRewrite { +public: + OPENVINO_RTTI("RoPEFusion", "0"); + RoPEFusion() { + add_matcher(); + add_matcher(); + // optional heads & tails are fused in separate matcher pass, + // after RoPENode has been created. 
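+        // Note: the ChatGLM and Qwen matchers are registered twice below,
+        // once per VariadicSplit output branch:
+        //   split_output_id == 0 -> query, split_output_id == 1 -> key.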
+ add_matcher(); + add_matcher(); + add_matcher(); + + add_matcher(0); + add_matcher(1); + + add_matcher(0); + add_matcher(1); + } +}; + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/utils.hpp b/src/plugins/intel_gpu/src/plugin/transformations/utils.hpp new file mode 100644 index 00000000000000..7886d235c7b6e9 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/utils.hpp @@ -0,0 +1,1298 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/opsets/opset1.hpp" +#include "openvino/opsets/opset2.hpp" +#include "openvino/opsets/opset3.hpp" +#include "openvino/opsets/opset4.hpp" +#include "openvino/opsets/opset5.hpp" +#include "openvino/opsets/opset6.hpp" +#include "openvino/opsets/opset7.hpp" +#include "openvino/opsets/opset8.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" + +namespace ov { +namespace intel_gpu { +namespace transformation_utils { + +static bool force_matcher_verbose = false; + + +#define _VERBOSE_LOG(...) + +namespace detail { +inline std::vector split_string(const std::string& s, const std::string& delimiter) { + std::vector ret; + size_t pos = 0, pos_next; + std::string token; + while ((pos_next = s.find(delimiter, pos)) != std::string::npos) { + token = s.substr(pos, pos_next - pos); + ret.push_back(token); + pos = pos_next + 1; + } + // return whole string if no delimiter if found + token = s.substr(pos, pos_next); + ret.push_back(token); + return ret; +} + +template +std::string vec2str(const std::vector& vec, int cnt_limit = 9) { + std::stringstream ss; + ss << "{"; + const char* sep = ""; + for (auto& v : vec) { + cnt_limit--; + if (cnt_limit == 0) { + ss << sep << "..."; + break; + } + ss << sep << v; + sep = ","; + } + ss << "}"; + return ss.str(); +} +} // namespace detail + +struct values_info { + values_info(const char* pattern_list = nullptr) { + if (pattern_list == nullptr || pattern_list[0] == 0) { + all_type_pshape.clear(); + return; + } + auto pattern_vector = detail::split_string(pattern_list, " "); + for (auto& pattern : pattern_vector) { + if (pattern[0] == '[') { + all_type_pshape.emplace_back(ov::element::dynamic, ov::PartialShape(pattern)); + } else { + auto sep = pattern.find("["); + if (sep != std::string::npos) { + // ele_type[p_shape] + all_type_pshape.emplace_back(ov::element::Type(pattern.substr(0, sep)), + ov::PartialShape(pattern.substr(sep))); + } else { + // ele_type + all_type_pshape.emplace_back(ov::element::Type(pattern), ov::PartialShape::dynamic()); + } + } + } + } + + size_t size() { + return all_type_pshape.size(); + } + const std::pair& operator[](int index) { + return all_type_pshape[index]; + } + + //------------------------------------------------------------- + bool predicate(const ov::Output& value) const { + if (all_type_pshape.empty()) + return true; + auto index = value.get_index(); + auto& item = all_type_pshape[index]; + if (!item.first.compatible(value.get_element_type()) || !item.second.compatible(value.get_partial_shape())) { + _VERBOSE_LOG("* mismatched vtype between value & pattern : ", + value.get_element_type(), + value.get_partial_shape(), + "vs", + item.first, + item.second); + return 
false; + } + return true; + } + + std::string to_string() { + std::stringstream ss; + const char* sep = ""; + for (auto& t : all_type_pshape) { + ss << sep << t.first << t.second; + sep = ";"; + } + return ss.str(); + } + + std::vector> all_type_pshape; +}; + +// Symbol : a constant that unknown at the pattern's building time +// but collected and validated after pattern was matched +// with some sub-graph values. +class Symbol { +private: + struct Entity { + const char* name = "?"; + char op; + double literal_const_value; + std::shared_ptr lhs; + std::shared_ptr rhs; + // _,+,-,*,/ + // l : literal const + // n : named symbol + double eval(const std::map& value_map) const { + switch (op) { + case 'l': + return literal_const_value; + case 'n': + return value_map.at(this); + case '+': + return lhs->eval(value_map) + rhs->eval(value_map); + case '-': + return lhs->eval(value_map) - rhs->eval(value_map); + case '*': + return lhs->eval(value_map) * rhs->eval(value_map); + case '/': + return lhs->eval(value_map) / rhs->eval(value_map); + case '_': + return -lhs->eval(value_map); + case 'r': + return std::sqrt(lhs->eval(value_map)); + default: + assert(false); + return std::numeric_limits::quiet_NaN(); + } + } + }; + std::shared_ptr entity; + +public: + Symbol() { + entity = std::make_shared(); + entity->op = 'n'; + } + Symbol(const char* name) { + entity = std::make_shared(); + entity->op = 'n'; + entity->name = name; + } + Symbol(const int value) { + entity = std::make_shared(); + entity->op = 'l'; + entity->literal_const_value = value; + } + Symbol(char op, const Symbol& lhs, const Symbol& rhs) { + entity = std::make_shared(); + entity->op = op; + entity->lhs = lhs.entity; + entity->rhs = rhs.entity; + } + double eval(const std::map& value_map) const { + return entity->eval(value_map); + } + bool is_independent_var() const { + return entity->op == 'n'; + } + int is_literal_const() const { + return entity->op == 'l'; + } + char get_op() const { + return entity->op; + } + void* get_id() const { + return entity.get(); + } + const char* get_name() const { + return entity->name; + } + bool operator<(const Symbol& rhs) const { + return get_id() < rhs.get_id(); + } +}; + +inline Symbol operator-(const Symbol& lhs) { + return Symbol('_', lhs, lhs); +} +inline Symbol operator+(const Symbol& lhs, const Symbol& rhs) { + return Symbol('+', lhs, rhs); +} +inline Symbol operator-(const Symbol& lhs, const Symbol& rhs) { + return Symbol('-', lhs, rhs); +} +inline Symbol operator*(const Symbol& lhs, const Symbol& rhs) { + return Symbol('*', lhs, rhs); +} +inline Symbol operator/(const Symbol& lhs, const Symbol& rhs) { + return Symbol('/', lhs, rhs); +} +inline Symbol sqrt(Symbol lhs) { + return Symbol('r', lhs, lhs); +} + +namespace detail { + +// AttrAny is simple wrapper of Any to provide some constructor +// to take advantage of C++ implicit conversion to allow: +// - attribute expressed using initializer_list. 
+// - symbol to be used as attributes +struct AttrAny { + ov::Any any; + + // empty attribute, means empty vector, and error for scalar + AttrAny() {} + + AttrAny(const Symbol& v) : any(v) {} + AttrAny(const ov::element::Type& v) : any(v) {} + AttrAny(const ov::PartialShape& v) : any(v) {} + AttrAny(const ov::Dimension& v) : any(v) {} + AttrAny(bool v) : any(v) {} + AttrAny(int v) : any(v) {} + AttrAny(float v) : any(v) {} + AttrAny(double v) : any(v) {} + AttrAny(long v) : any(static_cast(v)) {} + AttrAny(long long v) : any(static_cast(v)) {} + AttrAny(const char* v) : any(v) {} + AttrAny(const std::string& v) : any(v) {} + + // template ::value>::type = true> + // AttrAny(const T& v) : any(v) {} + + // template ::value>::type = true> + // AttrAny(const std::vector& v) : any(v) {} + + AttrAny(const std::vector& v) : any(v) {} + + // template ::value>::type = true> + // AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values.begin(), values.end())) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values.begin(), values.end())) {} + + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + + std::string as_string() { + if (any.is()) + return any.as(); + return any.as(); + } + bool as_bool() { + if (any.is()) + return any.as(); + return any.as(); + } + double as_double() { + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + return any.as(); + } + int64_t as_int64_t() { + if (any.is()) + return any.as(); + return any.as(); + } + + template + std::vector as_vector() { + if (any.empty()) + return {}; + if (!std::is_same::value) { + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + if (any.is>()) { + auto vec = any.as>(); + return std::vector(vec.begin(), vec.end()); + } + } + if (!std::is_same::value) { + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + if (any.is>()) { + auto vec = any.as>(); + return std::vector(vec.begin(), vec.end()); + } + } + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + return any.as>(); + } + + template + std::vector as_T_vector() { + if (any.empty()) + return {}; + if (any.is()) { + auto to_vec = [](std::initializer_list v) { + return std::vector(v); + }; + return to_vec({any.as()}); + } + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + return any.as>(); + } + + std::vector as_str_vector() { + if (any.empty()) + return {}; + if (any.is>()) { + auto vec = any.as>(); + return std::vector(vec.begin(), vec.end()); + } + return any.as>(); + } + + template + T cast_to() { + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + return any.as(); + } + + template + bool equal_to(const std::vector& rhs) { + if (any.empty() && rhs.empty()) + return true; + auto& vec = any.as>(); + return 
std::equal(vec.begin(), vec.end(), rhs.begin()); + } + + template + bool equal_to(const std::vector& rhs) { + if (any.empty() && rhs.empty()) + return true; + + if (any.is>()) { + auto& vec = any.as>(); + return vec.size() == rhs.size() && std::equal(vec.begin(), vec.end(), rhs.begin()); + } + return equal_to(rhs); + } + + template + typename std::enable_if::value, bool>::type equal_to(const T& rhs) { + return rhs == any.as(); + } + + template + typename std::enable_if::value, bool>::type equal_to(const T& rhs) { + if (any.is()) { + auto& value = any.as(); + return rhs == static_cast(value); + } + return equal_to(rhs); + } +}; + +using AttrMap = std::map; + +class AttrSetter : public ov::AttributeVisitor { +public: + AttrMap& m_attr_map; + std::vector m_missing_attrs; + + AttrSetter(AttrMap& attrs) : m_attr_map(attrs) {} + + const std::vector& get_missing_attrs() { + return m_missing_attrs; + } + + bool should_skip(const std::string& name) { + if (m_attr_map.count(name) == 0) { + // attributes not specified is recorded as missing + m_missing_attrs.push_back(name); + return true; + } + + if (m_attr_map[name].any.is()) { + m_missing_attrs.push_back(name); + return true; + } + + if (m_attr_map[name].any.empty()) { + // input is set to empty, meaning default value is used. + return true; + } + return false; + } + + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_string()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_bool()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& adapter) override { + if (should_skip(name)) + return; + auto& any = m_attr_map[name].any; + if (auto a = ov::as_type>(&adapter)) { + static_cast(*a) = any.as(); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + } else if (auto a = ov::as_type>>(&adapter)) { +#if defined(__APPLE__) || defined(__EMSCRIPTEN__) + static_cast&>(*a) = m_attr_map[name].as_vector(); +#else + a->set(m_attr_map[name].as_vector()); +#endif + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + //} else if (auto a = ov::as_type>(&adapter)) { + // a->set(m_attr_map[name].as_string()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_string()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_T_vector()); + } else { + OPENVINO_THROW("unsupported AttributeAdapter for attribute : ", name); + } + } + + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_double()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_int64_t()); + } + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_vector()); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + 
value.set(m_attr_map[name].as_vector()); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_vector()); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_str_vector()); + } +}; + +class GenericPattern : public ov::pass::pattern::op::Pattern { +public: + OPENVINO_RTTI("GenericPattern"); + + explicit GenericPattern(const OutputVector& args = {}, const detail::AttrMap& attrs = {}) + : ov::pass::pattern::op::Pattern(args) { + set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic()); + m_attrs = attrs; + } + + // this allows code inside pred to access pattern node itself + void set_predicate(ov::pass::pattern::op::ValuePredicate pred) { + m_predicate = pred; + } + + bool match_value(ov::pass::pattern::Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) override { + // strictly requires pattern & graph value to come from output port with same index, + // this is absolute necessary when pattern contains split node connections. + if (pattern_value.get_index() != graph_value.get_index()) + return false; + if (m_predicate(graph_value)) { + auto& pattern_map = matcher->get_pattern_value_map(); + pattern_map[shared_from_this()] = graph_value; + matcher->add_node(graph_value); + return (get_input_size() == 0 + ? true + : matcher->match_arguments(pattern_value.get_node(), graph_value.get_node_shared_ptr())); + } + return false; + } + + detail::AttrMap& get_attrs() { + return m_attrs; + } + +private: + detail::AttrMap m_attrs; +}; + +// A glue/syntax-sugar type which allows more types to be used as input to makePattern() +struct PatternNode { + std::shared_ptr node; + int output_port = -1; + + operator ov::Output() const { + return get_output(); + } + + ov::Output get_output() const { + if (output_port >= 0) + return node->output(output_port); + return node->get_default_output(); + } + + PatternNode(const Output& out) : node(out.get_node_shared_ptr()), output_port(out.get_index()) {} + + PatternNode() { + node = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank()); + } + PatternNode(ov::Rank rank) { + node = ov::pass::pattern::any_input([rank](const Output& value) { + if (!rank.compatible(value.get_partial_shape().rank())) { + _VERBOSE_LOG("*mismatched PatternNode rank ", value, " expecting ", rank); + return false; + } + return true; + }); + } + + PatternNode(values_info vt) { + node = ov::pass::pattern::any_input([vt](const Output& value) { + if (!vt.predicate(value)) { + _VERBOSE_LOG("*mismatched PatternNode ", value); + return false; + } + _VERBOSE_LOG(" matched PatternNode ", value); + return true; + }); + } + PatternNode(const std::shared_ptr& node) : node(node) {} + PatternNode(const std::shared_ptr& node) : node(node) {} + PatternNode(const std::shared_ptr& pattern) + : node(std::dynamic_pointer_cast(pattern)) {} + + // 1D-vector & scalar of symbol + PatternNode(std::initializer_list v) { + // initializer_list of Symbol ls special, need to be recorded + // and eval/check in the callback after whole match is complete, + // where all observed actual constant values are known, first + // we will go over all symbols and collect actual value for individual + // symbol(named symbol), and then we go over all derived symbols and + // evaluate their predicated values and compare against what observed, + // and check if they all match. 
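+        // Illustrative example, mirroring the Qwen matcher in rope_fusion.cpp:
+        //   a pattern constant {0, 0, head_cnt, head_size} observes head_size
+        //   directly, while {0, 0, 0, 2, head_size / 2} records a derived
+        //   expression; if the first match yields head_size == 128, the second
+        //   constant's last element must be 64 or the whole match is rejected.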
+ // node = ConstVector(std::vector(v), nullptr); + node = ov::pass::pattern::wrap_type(); + + auto& rt_info = node->get_rt_info(); + rt_info["symbolic_const_value"] = std::vector(v); + } + PatternNode(const std::vector& v) { + node = ov::pass::pattern::wrap_type(); + auto& rt_info = node->get_rt_info(); + rt_info["symbolic_const_value"] = v; + } + + PatternNode(Symbol v) { + node = ov::pass::pattern::wrap_type(); + auto& rt_info = node->get_rt_info(); + rt_info["symbolic_const_value"] = std::vector({v}); + } + + // scalar constant (treated as wildcard for single-element-constant with any rank) + PatternNode(int v) : node(std::make_shared(element::from(), Shape({}), v)) {} + PatternNode(float v) : node(std::make_shared(element::from(), Shape({}), v)) {} + + PatternNode(std::initializer_list v, values_info vi = nullptr) { + node = ConstVector(std::vector(v), vi); + } + PatternNode(std::initializer_list v, values_info vi = nullptr) { + node = ConstVector(std::vector(v), vi); + } + PatternNode(std::initializer_list v, values_info vi = nullptr) { + node = ConstVector(std::vector(v.begin(), v.end()), vi); + } + PatternNode(std::initializer_list v, values_info vi = nullptr) { + node = ConstVector(std::vector(v.begin(), v.end()), vi); + } + + // 1d const tensor or scalar + template ::value, bool>::type = true> + static std::shared_ptr ConstVector(const std::vector& vec, values_info vi = nullptr) { + if (vi.size() > 0) + return std::make_shared(vi[0].first, vi[0].second.to_shape(), vec); + // initializer_list w/o value_info means to create normal 1D vector + return std::make_shared(element::from(), Shape({vec.size()}), vec); + } +}; + +using SymbolObservationVector = std::vector>; + +template +void add_symbol_observed(SymbolObservationVector& sov, const Symbol& sym, const T& value) { + auto v = static_cast(value); + OPENVINO_ASSERT(static_cast(v) == value); // ensure there is no precison lost in double + sov.push_back(std::make_pair(sym, v)); +} +/* +template +static bool vector_equal_to_any(const std::vector& v0, detail::AttrAny& any) { + auto v1 = any.cast_to_vector(); + if (v0.size() != v1.size()) + return false; + return std::equal(v0.begin(), v0.end(), v1.begin()); +} + +template +static bool scalar_equal_to_any(const T& v0, detail::AttrAny& any) { + if (any.is()) { + return v0 == any.as(); + } else if (any.is()) { + return v0 == any.as(); + } + return v0 == any.as(); +} +*/ +// for arithmetic data type, Attr matcher will success as long as the actuall attributes +// is equal to the casted attributes from pattern w/o requiring exact type match. 
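+// For example, a pattern attribute {"axis", 0} written as a plain int compares
+// equal to a node attribute stored as int64_t: the pattern-side value is cast
+// to int64_t before the comparison, so no exact type match is required.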
+class AttrMatcher : public ov::AttributeVisitor { +public: + AttrMap& m_attr_map; + std::vector m_missing_attrs; + SymbolObservationVector* m_psov; + bool m_all_matched; + + AttrMatcher(AttrMap& attrs, SymbolObservationVector* psov = nullptr) + : m_attr_map(attrs), + m_psov(psov), + m_all_matched(true) {} + + bool matched() { + return m_all_matched; + } + + const std::vector& get_missing_attrs() { + return m_missing_attrs; + } + + bool should_skip(const std::string& name, bool allow_symbol = false) { + if (m_attr_map.count(name) == 0) { + m_missing_attrs.push_back(name); + return true; + } + + if (!allow_symbol) { + OPENVINO_ASSERT(!m_attr_map[name].any.is(), "Symbol is not allowed."); + } + return false; + } + + void add_match_result(const std::string& name, bool is_matched) { + if (!is_matched) { + _VERBOSE_LOG(" attribute '", name, "' mismatch."); + } + m_all_matched = m_all_matched && is_matched; + } + + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + add_match_result(name, value.get() == m_attr_map[name].as_string()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + // only integer is allowed to be of symbol type + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name, true)) + return; + auto& any = m_attr_map[name].any; + if (any.is()) { + if (m_psov) { + // collect symbol reference and do comparison later + add_symbol_observed(*m_psov, any.as(), value.get()); + } + return; + } + add_match_result(name, m_attr_map[name].cast_to() == value.get()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name, true)) + return; + auto& any = m_attr_map[name].any; + if (any.is()) { + if (m_psov) { + // collect symbol reference and do comparison later + add_symbol_observed(*m_psov, any.as(), value.get()); + } + return; + } + add_match_result(name, m_attr_map[name].cast_to() == value.get()); + } + + void on_adapter(const std::string& name, ov::ValueAccessor& adapter) override { + if (should_skip(name)) + return; + OPENVINO_ASSERT(m_attr_map.count(name) > 0); + auto& any = m_attr_map[name].any; + bool is_matched = true; + if (auto a = ov::as_type>(&adapter)) { + is_matched = (static_cast(*a) == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); 
+ } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else if (auto a = ov::as_type>>(&adapter)) { +#if defined(__APPLE__) || defined(__EMSCRIPTEN__) + is_matched = m_attr_map[name].equal_to(static_cast&>(*a)); +#else + is_matched = m_attr_map[name].equal_to(a->get()); +#endif + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else { + OPENVINO_THROW("AttrMatcher met unsupported AttributeAdapter ", name); + } + add_match_result(name, is_matched); + } +}; +} // namespace detail + +//================================================================================================== + +inline std::shared_ptr GenInput(values_info vt = nullptr) { + return ov::pass::pattern::any_input([vt](const Output& value) { + if (!vt.predicate(value)) { + _VERBOSE_LOG("*mismatched GenInput ", value); + return false; + } + _VERBOSE_LOG(" matched GenInput ", value); + return true; + }); +} + +inline std::shared_ptr makePattern() { + detail::PatternNode g; + return g.node; +} + +inline std::shared_ptr makePattern(ov::Rank rank) { + detail::PatternNode g(rank); + return g.node; +} + +inline std::shared_ptr makePattern(values_info vt) { + detail::PatternNode g(vt); + return g.node; +} + +// unknown const +inline std::shared_ptr makeConst(const ov::element::Type& type, + const ov::PartialShape& pshape, + std::function pred) { + return ov::pass::pattern::wrap_type([type, pshape, pred](const Output& value) { + auto cnode = ov::as_type_ptr(value.get_node_shared_ptr()); + if (!cnode) + return false; + + if (!type.compatible(value.get_element_type()) || !pshape.compatible(value.get_partial_shape())) { + return false; + } + if (pred && !pred(*cnode)) { + return false; + } + return true; + }); +} + +template +std::shared_ptr makeConst(const ov::element::Type& type, + const ov::Shape& shape, + std::initializer_list values) { + return std::make_shared(type, shape, std::vector(values)); +} + +inline std::shared_ptr makeConst(const std::vector& v) { + auto node = ov::pass::pattern::wrap_type(); + auto& rt_info = node->get_rt_info(); + rt_info["symbolic_const_value"] = v; + return node; +} + +template +std::shared_ptr makeConst(const ov::element::Type& type, const ov::Shape& shape, const std::vector& values) { + return std::make_shared(type, shape, values); +} + +template +std::shared_ptr makePattern(const std::vector& inputs, + detail::AttrMap attrmap = {}, + values_info vt = nullptr, + const char* friendly_name = nullptr) { + auto* p_type_info = &(T::get_type_info_static()); + OutputVector args; + for (auto& in : inputs) + args.push_back(in.get_output()); + + // pattern nodes are better for pattern matching because + // - it can be generic/incomplete, so normal OP node is not working properly + // - it has predicate to correctly decide which branch to take (in Or pattern) + auto pattern_node = std::make_shared(args, attrmap); + + if (friendly_name) { + pattern_node->set_friendly_name(friendly_name); + } else { + std::stringstream ss; + ss << p_type_info->get_version() << "::" << p_type_info->name; + ss << "("; + const char* sep = ""; + for (auto& i : args) 
{ + ss << sep << i.get_node()->get_name(); + sep = ","; + } + ss << ")"; + pattern_node->set_friendly_name(ss.str()); + } + + auto* pnode = pattern_node.get(); + pnode->set_predicate([p_type_info, vt, pnode, friendly_name, attrmap](const Output& value) { + (void)friendly_name; + auto value_node = value.get_node_shared_ptr(); + if (!value_node->get_type_info().is_castable(*p_type_info)) { + _VERBOSE_LOG("*mismatched makePattern OP type: ", pnode->get_friendly_name(), "vs", value); + return false; + } + + if (!vt.predicate(value)) { + _VERBOSE_LOG("*mismatched makePattern value info: ", pnode->get_friendly_name(), "vs", value); + return false; + } + + auto& attr_map = pnode->get_attrs(); + if (!attr_map.empty()) { + detail::AttrMatcher visitor(attr_map); + value_node->visit_attributes(visitor); + if (!visitor.matched()) { + _VERBOSE_LOG("*mismatched attributes : ", + pnode->get_friendly_name(), + " vs ", + value_node->get_friendly_name()); + return false; + } + } + + _VERBOSE_LOG(" matched makePattern ", pnode->get_friendly_name(), " == ", value); + return true; + }); + + return pattern_node; +} + +template +std::shared_ptr makeOP(const std::vector& inputs, + detail::AttrMap attrmap = {}, + const char* friendly_name = nullptr) { + std::shared_ptr node = std::make_shared(); + + OutputVector args; + for (auto& in : inputs) + args.push_back(in.get_output()); + node->set_arguments(args); + + detail::AttrSetter visitor(attrmap); + node->visit_attributes(visitor); + + auto missing_attrs = visitor.get_missing_attrs(); + + // when some attribute is missing or is symbol, the returned + // node is suitable for pattern matching only. + OPENVINO_ASSERT(missing_attrs.size() == 0, + "missing ", + missing_attrs.size(), + " attributes : ", + missing_attrs[0], + "..."); + + if (friendly_name) + node->set_friendly_name(friendly_name); + node->constructor_validate_and_infer_types(); + return node; +} + +template +std::shared_ptr GenConst_tril(values_info vt) { + return ov::pass::pattern::wrap_type([vt](const Output& value) { + auto s1 = as_type_ptr(value.get_node_shared_ptr()); + if (!s1) { + _VERBOSE_LOG("*mismatched GenConst_tril op type: opset1::Constant vs", value); + return false; + } + + if (!vt.predicate(value)) { + _VERBOSE_LOG("*mismatched GenConst_tril values_info:", value); + return false; + } + + // ignore higher dimensions, require lowerst 2D to be lower triangular + auto shape = s1->get_output_shape(0); + auto rank = shape.size(); + if (rank < 2) { + _VERBOSE_LOG("*mismatched GenConst_tril rank < 2 (rank=", rank, ")"); + return false; + } + if (shape[rank - 1] != shape[rank - 2]) { + _VERBOSE_LOG("*mismatched GenConst_tril shape[-1] != shape[-2] : ", + shape[rank - 1], + " != ", + shape[rank - 2]); + return false; + } + // NxN const matrix + auto N = shape[rank - 1]; + std::vector output_vector = s1->cast_vector(); + // check if it's unit lower triangular matrix + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < N; j++) { + if (static_cast(output_vector[i * N + j]) != static_cast(j <= i)) + return false; + } + } + return true; + }); +} + +inline std::shared_ptr operator|(const Output& lhs, const Output& rhs) { + return std::make_shared(OutputVector{lhs, rhs}); +} + +inline std::shared_ptr operator|(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { + return std::make_shared( + OutputVector{lhs->get_default_output(), rhs->get_default_output()}); +} + +inline std::shared_ptr GenSlice(detail::PatternNode data, Symbol start, Symbol stop, Symbol step, size_t axis) { + auto opt1 = 
makePattern({data, {start}, {stop}, {step}, {static_cast(axis)}}); + + std::vector vbegin(axis + 1, Symbol(0)); + std::vector vend(axis + 1, Symbol(0)); + std::vector vstride(axis + 1, Symbol(1)); + + vbegin[axis] = start; + vend[axis] = stop; + vstride[axis] = step; + + detail::PatternNode begin(vbegin); + detail::PatternNode end(vend); + detail::PatternNode stride(vstride); + + std::vector begin_mask(axis + 1, 1); + std::vector end_mask(axis + 1, 1); + std::vector new_axis_mask; + std::vector shrink_axis_mask; + std::vector ellipsis_mask; + + begin_mask[axis] = 0; + end_mask[axis] = 0; + + auto opt2 = makePattern({data, begin, end, stride}, + {{"begin_mask", begin_mask}, + {"end_mask", end_mask}, + {"new_axis_mask", new_axis_mask}, + {"shrink_axis_mask", shrink_axis_mask}, + {"ellipsis_mask", ellipsis_mask}}); + return opt1 | opt2; +} + +//================================================================================================== +class PatternValidator { +public: + PatternValidator(ov::pass::pattern::Matcher& m, bool force_verbose = false) { + auto saved_force_matcher_verbose = force_matcher_verbose; + force_matcher_verbose = force_verbose; + m_is_valid = validate(m); + force_matcher_verbose = saved_force_matcher_verbose; + } + + double& operator[](const char* symbol_name) { + return m_symbol_values[symbol_name]; + } + + operator bool() { + if (!m_is_valid) { + _VERBOSE_LOG("PatternValidator failed."); + } + return m_is_valid; + } + + bool validate(ov::pass::pattern::Matcher& m) { + detail::SymbolObservationVector sov; + + auto& pvmap = m.get_pattern_value_map(); + for (auto& pv : pvmap) { + auto pnode = pv.first; + auto value_node = pv.second.get_node_shared_ptr(); + auto& rt_info = pnode->get_rt_info(); + + if (auto pattern_node = std::dynamic_pointer_cast(pnode)) { + // pattern_node has no attribute and it has been matched in its predicate + if (rt_info.count("symbolic_const_value")) { + // symbolic constant node, a symbol reference is observed + auto& symbols = rt_info["symbolic_const_value"].as>(); + auto constop = std::dynamic_pointer_cast(value_node); + if (!constop) { + _VERBOSE_LOG("symbolic_const_value unexpected OP: ", value_node->get_friendly_name()); + return false; + } + auto ele_cnt = shape_size(constop->get_shape()); + auto ele_type = constop->get_element_type(); + + if (ele_cnt != symbols.size()) { + _VERBOSE_LOG("symbolic_const_value expect ", + symbols.size(), + " but got ", + ele_cnt, + " from ", + value_node->get_friendly_name()); + return false; + } + + if (ele_type == ov::element::i32 || ele_type == ov::element::f32 || ele_type == ov::element::i64) { + auto observed = constop->cast_vector(); + for (size_t i = 0; i < symbols.size(); i++) + detail::add_symbol_observed(sov, symbols[i], observed[i]); + } else { + _VERBOSE_LOG("Unexpect element type ", ele_type, " from ", value_node->get_friendly_name()); + return false; + } + } + continue; + } + if (auto pconst_node = std::dynamic_pointer_cast(pnode)) { + // const_node needs to match type/shape/value + auto vconst_node = std::dynamic_pointer_cast(value_node); + if (!vconst_node) { + _VERBOSE_LOG("expecting Constant op, but got ", value_node); + return false; + } + if (pconst_node->get_output_element_type(0) != vconst_node->get_output_element_type(0)) { + _VERBOSE_LOG("expecting Constant of type ", + pconst_node->get_output_element_type(0), + " but got ", + vconst_node); + return false; + } + // for constant node matched in pattern, a scalar constant is considered to + // be compatible with any shape with 1 
element, like {}, {1,1}, {1,1,...} + const auto& expected_shape = pconst_node->get_output_shape(0); + if (expected_shape.size() == 0) { + if (shape_size(vconst_node->get_output_shape(0)) != 1) { + _VERBOSE_LOG("expecting a single element const, but got ", vconst_node); + return false; + } + } else { + if (expected_shape != vconst_node->get_output_shape(0)) { + _VERBOSE_LOG("expecting Constant of shape ", expected_shape, " but got ", vconst_node); + return false; + } + } + auto byte_size = + shape_size(vconst_node->get_output_shape(0)) * vconst_node->get_output_element_type(0).size(); + if (std::memcmp(pconst_node->get_data_ptr(), vconst_node->get_data_ptr(), byte_size) != 0) { + _VERBOSE_LOG("Constant value mismatch on ", pconst_node, " vs ", vconst_node); + return false; + } + continue; + } + + // compare attributes between them + // assume that there is no Symbol in the attributes, we need to fetch each attributes + // from + if (rt_info.count("__attrs__") == 0) { + _VERBOSE_LOG(" attr compare failed: __attrs__ not found for ", pnode->get_friendly_name()); + return false; + } + + // attr not specified is treated as not-care and ignored + // attr with symbol + + detail::AttrMap& attr_map = rt_info["__attrs__"].as(); + detail::AttrMatcher visitor(attr_map, &sov); + value_node->visit_attributes(visitor); + if (!visitor.matched()) { + _VERBOSE_LOG(" attr compare failed: ", + pnode->get_friendly_name(), + " vs ", + value_node->get_friendly_name()); + return false; + } + } + + // check symbol consistency & return independent symbols + // assign independent symbols & check literals + std::map symbol_value_map; + for (auto& ref : sov) { + auto& sym = ref.first; + auto& value = ref.second; + + if (sym.is_independent_var()) { + auto id = sym.get_id(); + if (symbol_value_map.count(id)) { + if (symbol_value_map[id] != value) { + _VERBOSE_LOG(" in-consistency between multiple references of same symbol : ", + symbol_value_map[id], + " != ", + value); + return false; + } + } else { + symbol_value_map[id] = value; + m_symbol_values[sym.get_name()] = value; + _VERBOSE_LOG("Independent Symbol: ", sym.get_name(), " = ", value); + } + } + + if (sym.is_literal_const()) { + auto literal = sym.eval(symbol_value_map); + if (literal != value) { + _VERBOSE_LOG(" mismatch between literal symbol & value : ", literal, " != ", value); + return false; + } + // no need to put literal into value map to eval them. 
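+                // e.g. the literal 2 in a symbolic shape {-1, 2, head_size / 2}
+                // must be observed as exactly 2 in the matched Constant.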
+ } + } + + // derive/eval dependent symbol's value and check against observed + for (auto& ref : sov) { + auto& sym = ref.first; + if (!sym.is_literal_const() && !sym.is_independent_var()) { + auto derived = sym.eval(symbol_value_map); + auto value = ref.second; + bool is_match; + + if (std::trunc(value) == value) { + // observed integer + is_match = (derived == value); + } else { + auto abs_diff = std::abs(derived - value); + auto avg = 0.5f * std::abs(derived + value); + if (avg != 0) { + is_match = abs_diff < avg * 1e-7; // relative error less than threshold + } else { + is_match = (derived == value); + } + } + if (!is_match) { + _VERBOSE_LOG(" mismatch between derived & value : ", + std::setprecision(std::numeric_limits::max_digits10), + derived, + " != ", + std::setprecision(std::numeric_limits::max_digits10), + value); + return false; + } + } + } + return true; + } + +private: + std::map m_symbol_values; + bool m_is_valid; +}; + +} // namespace transformation_utils +} // namespace intel_gpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 0c57b56671349c..ddcfc26965337c 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -124,6 +124,7 @@ #include "plugin/transformations/kv_cache_fusion.hpp" #include "plugin/transformations/fc_convert_fusion.hpp" #include "plugin/transformations/clamp_fp16_output.hpp" +#include "plugin/transformations/rope_fusion.hpp" #include "transformations/low_precision/mark_dequantization_subgraph.hpp" #include "low_precision/pull_reshape_through_dequantization.hpp" @@ -703,12 +704,18 @@ void TransformationsPipeline::apply(std::shared_ptr func) { manager.register_pass(); manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + // This is supposed to be the last pass to ensure that we don't have name collisions until // GPU plugin stops using friendly names for program creation manager.register_pass(true); - manager.run_passes(func); } + // ov::pass::Serialize("serialized_ir/openvino_model.xml", "openvino_model.bin").run_on_model(func); + // { + // pass::VisualizeTree("image.svg").run_on_model(func); + // } } } // namespace intel_gpu } // namespace ov diff --git a/src/plugins/intel_gpu/tests/common/gen_pattern.hpp b/src/plugins/intel_gpu/tests/common/gen_pattern.hpp new file mode 100644 index 00000000000000..a7b29c7f6c78e1 --- /dev/null +++ b/src/plugins/intel_gpu/tests/common/gen_pattern.hpp @@ -0,0 +1,1315 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/opsets/opset1.hpp" +#include "openvino/opsets/opset2.hpp" +#include "openvino/opsets/opset3.hpp" +#include "openvino/opsets/opset4.hpp" +#include "openvino/opsets/opset5.hpp" +#include "openvino/opsets/opset6.hpp" +#include "openvino/opsets/opset7.hpp" +#include "openvino/opsets/opset8.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" + +namespace ov { +namespace gen_pattern { + +static bool force_matcher_verbose = false; + +#ifdef CPU_DEBUG_CAPS + +template +static inline void 
_verbose_log(Args&&... args) { + std::stringstream ss; + int dummy[] = {(ss << std::forward(args) << " ", 0)...}; + (void)(dummy); + ss << std::endl; + std::cout << ss.str(); +} + +static int matcher_verbose_enabled() { + static const int enabled = std::getenv("GENP_VERBOSE") ? (atoi(std::getenv("GENP_VERBOSE"))) : 0; + return enabled; +} + +# define _VERBOSE_LOG(...) \ + if (matcher_verbose_enabled() || force_matcher_verbose) \ + _verbose_log(__VA_ARGS__) +#else +# define _VERBOSE_LOG(...) +#endif + +namespace detail { +inline std::vector split_string(const std::string& s, const std::string& delimiter) { + std::vector ret; + size_t pos = 0, pos_next; + std::string token; + while ((pos_next = s.find(delimiter, pos)) != std::string::npos) { + token = s.substr(pos, pos_next - pos); + ret.push_back(token); + pos = pos_next + 1; + } + // return whole string if no delimiter if found + token = s.substr(pos, pos_next); + ret.push_back(token); + return ret; +} + +template +std::string vec2str(const std::vector& vec, int cnt_limit = 9) { + std::stringstream ss; + ss << "{"; + const char* sep = ""; + for (auto& v : vec) { + cnt_limit--; + if (cnt_limit == 0) { + ss << sep << "..."; + break; + } + ss << sep << v; + sep = ","; + } + ss << "}"; + return ss.str(); +} +} // namespace detail + +struct values_info { + values_info(const char* pattern_list = nullptr) { + if (pattern_list == nullptr || pattern_list[0] == 0) { + all_type_pshape.clear(); + return; + } + auto pattern_vector = detail::split_string(pattern_list, " "); + for (auto& pattern : pattern_vector) { + if (pattern[0] == '[') { + all_type_pshape.emplace_back(ov::element::dynamic, ov::PartialShape(pattern)); + } else { + auto sep = pattern.find("["); + if (sep != std::string::npos) { + // ele_type[p_shape] + all_type_pshape.emplace_back(ov::element::Type(pattern.substr(0, sep)), + ov::PartialShape(pattern.substr(sep))); + } else { + // ele_type + all_type_pshape.emplace_back(ov::element::Type(pattern), ov::PartialShape::dynamic()); + } + } + } + } + + size_t size() { + return all_type_pshape.size(); + } + const std::pair& operator[](int index) { + return all_type_pshape[index]; + } + + //------------------------------------------------------------- + bool predicate(const ov::Output& value) const { + if (all_type_pshape.empty()) + return true; + auto index = value.get_index(); + auto& item = all_type_pshape[index]; + if (!item.first.compatible(value.get_element_type()) || !item.second.compatible(value.get_partial_shape())) { + _VERBOSE_LOG("* mismatched vtype between value & pattern : ", + value.get_element_type(), + value.get_partial_shape(), + "vs", + item.first, + item.second); + return false; + } + return true; + } + + std::string to_string() { + std::stringstream ss; + const char* sep = ""; + for (auto& t : all_type_pshape) { + ss << sep << t.first << t.second; + sep = ";"; + } + return ss.str(); + } + + std::vector> all_type_pshape; +}; + +// Symbol : a constant that unknown at the pattern's building time +// but collected and validated after pattern was matched +// with some sub-graph values. 
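+// Illustrative example: symbols compose arithmetically, so a test pattern can
+// encode relations such as
+//   Symbol head_cnt("head_cnt"), head_size("head_size");
+//   auto q_len = head_cnt * head_size;  // derived, checked against constants
+// after a match, the validator resolves the named symbols from observed
+// constant values and verifies every derived expression against them.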
+class Symbol { +private: + struct Entity { + const char* name = "?"; + char op; + double literal_const_value; + std::shared_ptr lhs; + std::shared_ptr rhs; + // _,+,-,*,/ + // l : literal const + // n : named symbol + double eval(const std::map& value_map) const { + switch (op) { + case 'l': + return literal_const_value; + case 'n': + return value_map.at(this); + case '+': + return lhs->eval(value_map) + rhs->eval(value_map); + case '-': + return lhs->eval(value_map) - rhs->eval(value_map); + case '*': + return lhs->eval(value_map) * rhs->eval(value_map); + case '/': + return lhs->eval(value_map) / rhs->eval(value_map); + case '_': + return -lhs->eval(value_map); + case 'r': + return std::sqrt(lhs->eval(value_map)); + default: + assert(false); + return std::numeric_limits::quiet_NaN(); + } + } + }; + std::shared_ptr entity; + +public: + Symbol() { + entity = std::make_shared(); + entity->op = 'n'; + } + Symbol(const char* name) { + entity = std::make_shared(); + entity->op = 'n'; + entity->name = name; + } + Symbol(const int value) { + entity = std::make_shared(); + entity->op = 'l'; + entity->literal_const_value = value; + } + Symbol(char op, const Symbol& lhs, const Symbol& rhs) { + entity = std::make_shared(); + entity->op = op; + entity->lhs = lhs.entity; + entity->rhs = rhs.entity; + } + double eval(const std::map& value_map) const { + return entity->eval(value_map); + } + bool is_independent_var() const { + return entity->op == 'n'; + } + int is_literal_const() const { + return entity->op == 'l'; + } + char get_op() const { + return entity->op; + } + void* get_id() const { + return entity.get(); + } + const char* get_name() const { + return entity->name; + } + bool operator<(const Symbol& rhs) const { + return get_id() < rhs.get_id(); + } +}; + +inline Symbol operator-(const Symbol& lhs) { + return Symbol('_', lhs, lhs); +} +inline Symbol operator+(const Symbol& lhs, const Symbol& rhs) { + return Symbol('+', lhs, rhs); +} +inline Symbol operator-(const Symbol& lhs, const Symbol& rhs) { + return Symbol('-', lhs, rhs); +} +inline Symbol operator*(const Symbol& lhs, const Symbol& rhs) { + return Symbol('*', lhs, rhs); +} +inline Symbol operator/(const Symbol& lhs, const Symbol& rhs) { + return Symbol('/', lhs, rhs); +} +inline Symbol sqrt(Symbol lhs) { + return Symbol('r', lhs, lhs); +} + +namespace detail { + +// AttrAny is simple wrapper of Any to provide some constructor +// to take advantage of C++ implicit conversion to allow: +// - attribute expressed using initializer_list. 
+// - symbol to be used as attributes +struct AttrAny { + ov::Any any; + + // empty attribute, means empty vector, and error for scalar + AttrAny() {} + + AttrAny(const Symbol& v) : any(v) {} + AttrAny(const ov::element::Type& v) : any(v) {} + AttrAny(const ov::PartialShape& v) : any(v) {} + AttrAny(const ov::Dimension& v) : any(v) {} + AttrAny(bool v) : any(v) {} + AttrAny(int v) : any(v) {} + AttrAny(float v) : any(v) {} + AttrAny(double v) : any(v) {} + AttrAny(long v) : any(static_cast(v)) {} + AttrAny(long long v) : any(static_cast(v)) {} + AttrAny(const char* v) : any(v) {} + AttrAny(const std::string& v) : any(v) {} + + // template ::value>::type = true> + // AttrAny(const T& v) : any(v) {} + + // template ::value>::type = true> + // AttrAny(const std::vector& v) : any(v) {} + + AttrAny(const std::vector& v) : any(v) {} + + // template ::value>::type = true> + // AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values.begin(), values.end())) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values.begin(), values.end())) {} + + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + AttrAny(std::initializer_list values) : any(std::vector(values)) {} + + std::string as_string() { + if (any.is()) + return any.as(); + return any.as(); + } + bool as_bool() { + if (any.is()) + return any.as(); + return any.as(); + } + double as_double() { + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + return any.as(); + } + int64_t as_int64_t() { + if (any.is()) + return any.as(); + return any.as(); + } + + template + std::vector as_vector() { + if (any.empty()) + return {}; + if (!std::is_same::value) { + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + if (any.is>()) { + auto vec = any.as>(); + return std::vector(vec.begin(), vec.end()); + } + } + if (!std::is_same::value) { + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + if (any.is>()) { + auto vec = any.as>(); + return std::vector(vec.begin(), vec.end()); + } + } + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + return any.as>(); + } + + template + std::vector as_T_vector() { + if (any.empty()) + return {}; + if (any.is()) { + auto to_vec = [](std::initializer_list v) { + return std::vector(v); + }; + return to_vec({any.as()}); + } + if (any.is>()) { + auto ivec = any.as>(); + return std::vector(ivec.begin(), ivec.end()); + } + return any.as>(); + } + + std::vector as_str_vector() { + if (any.empty()) + return {}; + if (any.is>()) { + auto vec = any.as>(); + return std::vector(vec.begin(), vec.end()); + } + return any.as>(); + } + + template + T cast_to() { + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + if (any.is()) + return any.as(); + return any.as(); + } + + template + bool equal_to(const std::vector& rhs) { + if (any.empty() && rhs.empty()) + return true; + auto& vec = any.as>(); + return 
std::equal(vec.begin(), vec.end(), rhs.begin()); + } + + template + bool equal_to(const std::vector& rhs) { + if (any.empty() && rhs.empty()) + return true; + + if (any.is>()) { + auto& vec = any.as>(); + return vec.size() == rhs.size() && std::equal(vec.begin(), vec.end(), rhs.begin()); + } + return equal_to(rhs); + } + + template + typename std::enable_if::value, bool>::type equal_to(const T& rhs) { + return rhs == any.as(); + } + + template + typename std::enable_if::value, bool>::type equal_to(const T& rhs) { + if (any.is()) { + auto& value = any.as(); + return rhs == static_cast(value); + } + return equal_to(rhs); + } +}; + +using AttrMap = std::map; + +class AttrSetter : public ov::AttributeVisitor { +public: + AttrMap& m_attr_map; + std::vector m_missing_attrs; + + AttrSetter(AttrMap& attrs) : m_attr_map(attrs) {} + + const std::vector& get_missing_attrs() { + return m_missing_attrs; + } + + bool should_skip(const std::string& name) { + if (m_attr_map.count(name) == 0) { + // attributes not specified is recorded as missing + m_missing_attrs.push_back(name); + return true; + } + + if (m_attr_map[name].any.is()) { + m_missing_attrs.push_back(name); + return true; + } + + if (m_attr_map[name].any.empty()) { + // input is set to empty, meaning default value is used. + return true; + } + return false; + } + + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_string()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_bool()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& adapter) override { + if (should_skip(name)) + return; + auto& any = m_attr_map[name].any; + if (auto a = ov::as_type>(&adapter)) { + static_cast(*a) = any.as(); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + } else if (auto a = ov::as_type>>(&adapter)) { +#if defined(__APPLE__) || defined(__EMSCRIPTEN__) + static_cast&>(*a) = m_attr_map[name].as_vector(); +#else + a->set(m_attr_map[name].as_vector()); +#endif + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + //} else if (auto a = ov::as_type>(&adapter)) { + // a->set(m_attr_map[name].as_string()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_string()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_vector()); + } else if (auto a = ov::as_type>(&adapter)) { + a->set(m_attr_map[name].as_T_vector()); + } else { + OPENVINO_THROW("unsupported AttributeAdapter for attribute : ", name); + } + } + + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_double()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_int64_t()); + } + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_vector()); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + 
value.set(m_attr_map[name].as_vector()); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_vector()); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + value.set(m_attr_map[name].as_str_vector()); + } +}; + +class GenericPattern : public ov::pass::pattern::op::Pattern { +public: + OPENVINO_RTTI("GenericPattern"); + + explicit GenericPattern(const OutputVector& args = {}, const detail::AttrMap& attrs = {}) + : ov::pass::pattern::op::Pattern(args) { + set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic()); + m_attrs = attrs; + } + + // this allows code inside pred to access pattern node itself + void set_predicate(ov::pass::pattern::op::ValuePredicate pred) { + m_predicate = pred; + } + + bool match_value(ov::pass::pattern::Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) override { + // strictly requires pattern & graph value to come from output port with same index, + // this is absolute necessary when pattern contains split node connections. + if (pattern_value.get_index() != graph_value.get_index()) + return false; + if (m_predicate(graph_value)) { + auto& pattern_map = matcher->get_pattern_value_map(); + pattern_map[shared_from_this()] = graph_value; + matcher->add_node(graph_value); + return (get_input_size() == 0 + ? true + : matcher->match_arguments(pattern_value.get_node(), graph_value.get_node_shared_ptr())); + } + return false; + } + + detail::AttrMap& get_attrs() { + return m_attrs; + } + +private: + detail::AttrMap m_attrs; +}; + +// A glue/syntax-sugar type which allows more types to be used as input to makePattern() +struct PatternNode { + std::shared_ptr node; + int output_port = -1; + + operator ov::Output() const { + return get_output(); + } + + ov::Output get_output() const { + if (output_port >= 0) + return node->output(output_port); + return node->get_default_output(); + } + + PatternNode(const Output& out) : node(out.get_node_shared_ptr()), output_port(out.get_index()) {} + + PatternNode() { + node = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank()); + } + PatternNode(ov::Rank rank) { + node = ov::pass::pattern::any_input([rank](const Output& value) { + if (!rank.compatible(value.get_partial_shape().rank())) { + _VERBOSE_LOG("*mismatched PatternNode rank ", value, " expecting ", rank); + return false; + } + return true; + }); + } + + PatternNode(values_info vt) { + node = ov::pass::pattern::any_input([vt](const Output& value) { + if (!vt.predicate(value)) { + _VERBOSE_LOG("*mismatched PatternNode ", value); + return false; + } + _VERBOSE_LOG(" matched PatternNode ", value); + return true; + }); + } + PatternNode(const std::shared_ptr& node) : node(node) {} + PatternNode(const std::shared_ptr& node) : node(node) {} + PatternNode(const std::shared_ptr& pattern) + : node(std::dynamic_pointer_cast(pattern)) {} + + // 1D-vector & scalar of symbol + PatternNode(std::initializer_list v) { + // initializer_list of Symbol ls special, need to be recorded + // and eval/check in the callback after whole match is complete, + // where all observed actual constant values are known, first + // we will go over all symbols and collect actual value for individual + // symbol(named symbol), and then we go over all derived symbols and + // evaluate their predicated values and compare against what observed, + // and check if they all match. 
+            // node = ConstVector(std::vector<Symbol>(v), nullptr);
+            node = ov::pass::pattern::wrap_type<opset1::Constant>();
+
+            auto& rt_info = node->get_rt_info();
+            rt_info["symbolic_const_value"] = std::vector<Symbol>(v);
+        }
+        PatternNode(const std::vector<Symbol>& v) {
+            node = ov::pass::pattern::wrap_type<opset1::Constant>();
+            auto& rt_info = node->get_rt_info();
+            rt_info["symbolic_const_value"] = v;
+        }
+
+        PatternNode(Symbol v) {
+            node = ov::pass::pattern::wrap_type<opset1::Constant>();
+            auto& rt_info = node->get_rt_info();
+            rt_info["symbolic_const_value"] = std::vector<Symbol>({v});
+        }
+
+        // scalar constant (treated as wildcard for single-element-constant with any rank)
+        PatternNode(int v) : node(std::make_shared<opset1::Constant>(element::from<int>(), Shape({}), v)) {}
+        PatternNode(float v) : node(std::make_shared<opset1::Constant>(element::from<float>(), Shape({}), v)) {}
+
+        PatternNode(std::initializer_list<int> v, values_info vi = nullptr) {
+            node = ConstVector(std::vector<int>(v), vi);
+        }
+        PatternNode(std::initializer_list<float> v, values_info vi = nullptr) {
+            node = ConstVector(std::vector<float>(v), vi);
+        }
+        PatternNode(std::initializer_list<long> v, values_info vi = nullptr) {
+            node = ConstVector(std::vector<int64_t>(v.begin(), v.end()), vi);
+        }
+        PatternNode(std::initializer_list<double> v, values_info vi = nullptr) {
+            node = ConstVector(std::vector<float>(v.begin(), v.end()), vi);
+        }
+
+        // 1d const tensor or scalar
+        template <typename T, typename std::enable_if<std::is_arithmetic<T>::value, bool>::type = true>
+        static std::shared_ptr<Node> ConstVector(const std::vector<T>& vec, values_info vi = nullptr) {
+            if (vi.size() > 0)
+                return std::make_shared<opset1::Constant>(vi[0].first, vi[0].second.to_shape(), vec);
+            // initializer_list w/o value_info means to create a normal 1D vector
+            return std::make_shared<opset1::Constant>(element::from<T>(), Shape({vec.size()}), vec);
+        }
+    };
+
+    using SymbolObservationVector = std::vector<std::pair<Symbol, double>>;
+
+    template <typename T>
+    void add_symbol_observed(SymbolObservationVector& sov, const Symbol& sym, const T& value) {
+        auto v = static_cast<double>(value);
+        OPENVINO_ASSERT(static_cast<T>(v) == value);  // ensure there is no precision loss in double
+        sov.push_back(std::make_pair(sym, v));
+    }
+    /*
+    template <typename T>
+    static bool vector_equal_to_any(const std::vector<T>& v0, detail::AttrAny& any) {
+        auto v1 = any.cast_to_vector<T>();
+        if (v0.size() != v1.size())
+            return false;
+        return std::equal(v0.begin(), v0.end(), v1.begin());
+    }
+
+    template <typename T>
+    static bool scalar_equal_to_any(const T& v0, detail::AttrAny& any) {
+        if (any.is<int>()) {
+            return v0 == any.as<int>();
+        } else if (any.is<float>()) {
+            return v0 == any.as<float>();
+        }
+        return v0 == any.as<T>();
+    }
+    */
+// for arithmetic data types, the Attr matcher will succeed as long as the actual attribute
+// value equals the attribute casted from the pattern, w/o requiring an exact type match.
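+// NOTE (editor): hedged sketch of how symbolic constants participate in matching
+// (hypothetical names, mirroring the usage in rope_fusion.cpp from this patch):
+//   auto head_cnt = Symbol("head_cnt");
+//   auto ndims = Symbol("ndims");
+//   auto view = makePattern<opset1::Reshape>({input, {0, 0, head_cnt, ndims}},
+//                                            {{"special_zero", true}});
+// After the Matcher succeeds, PatternValidator (end of this header) resolves the
+// observed constant, e.g. {0, 0, 32, 128}, into validator["head_cnt"] == 32 and
+// validator["ndims"] == 128, rejecting the match if two references to the same
+// Symbol observed different values.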
+class AttrMatcher : public ov::AttributeVisitor { +public: + AttrMap& m_attr_map; + std::vector m_missing_attrs; + SymbolObservationVector* m_psov; + bool m_all_matched; + + AttrMatcher(AttrMap& attrs, SymbolObservationVector* psov = nullptr) + : m_attr_map(attrs), + m_psov(psov), + m_all_matched(true) {} + + bool matched() { + return m_all_matched; + } + + const std::vector& get_missing_attrs() { + return m_missing_attrs; + } + + bool should_skip(const std::string& name, bool allow_symbol = false) { + if (m_attr_map.count(name) == 0) { + m_missing_attrs.push_back(name); + return true; + } + + if (!allow_symbol) { + OPENVINO_ASSERT(!m_attr_map[name].any.is(), "Symbol is not allowed."); + } + return false; + } + + void add_match_result(const std::string& name, bool is_matched) { + if (!is_matched) { + _VERBOSE_LOG(" attribute '", name, "' mismatch."); + } + m_all_matched = m_all_matched && is_matched; + } + + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + add_match_result(name, value.get() == m_attr_map[name].as_string()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + void on_adapter(const std::string& name, ov::ValueAccessor>& value) override { + if (should_skip(name)) + return; + add_match_result(name, m_attr_map[name].equal_to(value.get())); + } + + // only integer is allowed to be of symbol type + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name, true)) + return; + auto& any = m_attr_map[name].any; + if (any.is()) { + if (m_psov) { + // collect symbol reference and do comparison later + add_symbol_observed(*m_psov, any.as(), value.get()); + } + return; + } + add_match_result(name, m_attr_map[name].cast_to() == value.get()); + } + void on_adapter(const std::string& name, ov::ValueAccessor& value) override { + if (should_skip(name, true)) + return; + auto& any = m_attr_map[name].any; + if (any.is()) { + if (m_psov) { + // collect symbol reference and do comparison later + add_symbol_observed(*m_psov, any.as(), value.get()); + } + return; + } + add_match_result(name, m_attr_map[name].cast_to() == value.get()); + } + + void on_adapter(const std::string& name, ov::ValueAccessor& adapter) override { + if (should_skip(name)) + return; + OPENVINO_ASSERT(m_attr_map.count(name) > 0); + auto& any = m_attr_map[name].any; + bool is_matched = true; + if (auto a = ov::as_type>(&adapter)) { + is_matched = (static_cast(*a) == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); 
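+        // NOTE (editor): minimal standalone use of AttrMatcher, a hedged sketch with
+        // hypothetical names:
+        //   detail::AttrMap attrs{{"auto_broadcast", "numpy"}};
+        //   detail::AttrMatcher checker(attrs);
+        //   add_node->visit_attributes(checker);
+        //   bool ok = checker.matched();  // true iff Add's auto_broadcast is "numpy"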
+ } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else if (auto a = ov::as_type>>(&adapter)) { +#if defined(__APPLE__) || defined(__EMSCRIPTEN__) + is_matched = m_attr_map[name].equal_to(static_cast&>(*a)); +#else + is_matched = m_attr_map[name].equal_to(a->get()); +#endif + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = (a->get() == any.as()); + } else if (auto a = ov::as_type>(&adapter)) { + is_matched = m_attr_map[name].equal_to(a->get()); + } else { + OPENVINO_THROW("AttrMatcher met unsupported AttributeAdapter ", name); + } + add_match_result(name, is_matched); + } +}; +} // namespace detail + +//================================================================================================== + +inline std::shared_ptr GenInput(values_info vt = nullptr) { + return ov::pass::pattern::any_input([vt](const Output& value) { + if (!vt.predicate(value)) { + _VERBOSE_LOG("*mismatched GenInput ", value); + return false; + } + _VERBOSE_LOG(" matched GenInput ", value); + return true; + }); +} + +inline std::shared_ptr makePattern() { + detail::PatternNode g; + return g.node; +} + +inline std::shared_ptr makePattern(ov::Rank rank) { + detail::PatternNode g(rank); + return g.node; +} + +inline std::shared_ptr makePattern(values_info vt) { + detail::PatternNode g(vt); + return g.node; +} + +// unknown const +inline std::shared_ptr makeConst(const ov::element::Type& type, + const ov::PartialShape& pshape, + std::function pred) { + return ov::pass::pattern::wrap_type([type, pshape, pred](const Output& value) { + auto cnode = ov::as_type_ptr(value.get_node_shared_ptr()); + if (!cnode) + return false; + + if (!type.compatible(value.get_element_type()) || !pshape.compatible(value.get_partial_shape())) { + return false; + } + if (pred && !pred(*cnode)) { + return false; + } + return true; + }); +} + +template +std::shared_ptr makeConst(const ov::element::Type& type, + const ov::Shape& shape, + std::initializer_list values) { + return std::make_shared(type, shape, std::vector(values)); +} + +inline std::shared_ptr makeConst(const std::vector& v) { + auto node = ov::pass::pattern::wrap_type(); + auto& rt_info = node->get_rt_info(); + rt_info["symbolic_const_value"] = v; + return node; +} + +template +std::shared_ptr makeConst(const ov::element::Type& type, const ov::Shape& shape, const std::vector& values) { + return std::make_shared(type, shape, values); +} + +template +std::shared_ptr makePattern(const std::vector& inputs, + detail::AttrMap attrmap = {}, + values_info vt = nullptr, + const char* friendly_name = nullptr) { + auto* p_type_info = &(T::get_type_info_static()); + OutputVector args; + for (auto& in : inputs) + args.push_back(in.get_output()); + + // pattern nodes are better for pattern matching because + // - it can be generic/incomplete, so normal OP node is not working properly + // - it has predicate to correctly decide which branch to take (in Or pattern) + auto pattern_node = std::make_shared(args, attrmap); + + if (friendly_name) { + pattern_node->set_friendly_name(friendly_name); + } else { + std::stringstream ss; + ss << p_type_info->get_version() << "::" << p_type_info->name; + ss << "("; + const char* sep = ""; + for (auto& i : args) 
{ + ss << sep << i.get_node()->get_name(); + sep = ","; + } + ss << ")"; + pattern_node->set_friendly_name(ss.str()); + } + + auto* pnode = pattern_node.get(); + pnode->set_predicate([p_type_info, vt, pnode, friendly_name, attrmap](const Output& value) { + (void)friendly_name; + auto value_node = value.get_node_shared_ptr(); + if (!value_node->get_type_info().is_castable(*p_type_info)) { + _VERBOSE_LOG("*mismatched makePattern OP type: ", pnode->get_friendly_name(), "vs", value); + return false; + } + + if (!vt.predicate(value)) { + _VERBOSE_LOG("*mismatched makePattern value info: ", pnode->get_friendly_name(), "vs", value); + return false; + } + + auto& attr_map = pnode->get_attrs(); + if (!attr_map.empty()) { + detail::AttrMatcher visitor(attr_map); + value_node->visit_attributes(visitor); + if (!visitor.matched()) { + _VERBOSE_LOG("*mismatched attributes : ", + pnode->get_friendly_name(), + " vs ", + value_node->get_friendly_name()); + return false; + } + } + + _VERBOSE_LOG(" matched makePattern ", pnode->get_friendly_name(), " == ", value); + return true; + }); + + return pattern_node; +} + +template +std::shared_ptr makeOP(const std::vector& inputs, + detail::AttrMap attrmap = {}, + const char* friendly_name = nullptr) { + std::shared_ptr node = std::make_shared(); + + OutputVector args; + for (auto& in : inputs) + args.push_back(in.get_output()); + node->set_arguments(args); + + detail::AttrSetter visitor(attrmap); + node->visit_attributes(visitor); + + auto missing_attrs = visitor.get_missing_attrs(); + + // when some attribute is missing or is symbol, the returned + // node is suitable for pattern matching only. + OPENVINO_ASSERT(missing_attrs.size() == 0, + "missing ", + missing_attrs.size(), + " attributes : ", + missing_attrs[0], + "..."); + + if (friendly_name) + node->set_friendly_name(friendly_name); + node->constructor_validate_and_infer_types(); + return node; +} + +template +std::shared_ptr GenConst_tril(values_info vt) { + return ov::pass::pattern::wrap_type([vt](const Output& value) { + auto s1 = as_type_ptr(value.get_node_shared_ptr()); + if (!s1) { + _VERBOSE_LOG("*mismatched GenConst_tril op type: opset1::Constant vs", value); + return false; + } + + if (!vt.predicate(value)) { + _VERBOSE_LOG("*mismatched GenConst_tril values_info:", value); + return false; + } + + // ignore higher dimensions, require lowerst 2D to be lower triangular + auto shape = s1->get_output_shape(0); + auto rank = shape.size(); + if (rank < 2) { + _VERBOSE_LOG("*mismatched GenConst_tril rank < 2 (rank=", rank, ")"); + return false; + } + if (shape[rank - 1] != shape[rank - 2]) { + _VERBOSE_LOG("*mismatched GenConst_tril shape[-1] != shape[-2] : ", + shape[rank - 1], + " != ", + shape[rank - 2]); + return false; + } + // NxN const matrix + auto N = shape[rank - 1]; + std::vector output_vector = s1->cast_vector(); + // check if it's unit lower triangular matrix + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < N; j++) { + if (static_cast(output_vector[i * N + j]) != static_cast(j <= i)) + return false; + } + } + return true; + }); +} + +inline std::shared_ptr operator|(const Output& lhs, const Output& rhs) { + return std::make_shared(OutputVector{lhs, rhs}); +} + +inline std::shared_ptr operator|(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { + return std::make_shared( + OutputVector{lhs->get_default_output(), rhs->get_default_output()}); +} + +inline std::shared_ptr GenSlice(detail::PatternNode data, Symbol start, Symbol stop, Symbol step, size_t axis) { + auto opt1 = 
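+    // NOTE (editor): makePattern<T> builds a GenericPattern matched by type, value
+    // info and attributes, while makeOP<T> builds a real node whose attributes are
+    // filled in by AttrSetter. Hedged sketch, mirroring the tests in this patch:
+    //   auto t   = makeOP<opset1::Transpose>({x, {0, 2, 1, 3}});
+    //   auto mul = makeOP<opset1::Multiply>({t, cos}, {{"auto_broadcast", "numpy"}});
+    // GenSlice returns an Or of the two graph forms of the same logical slice: a
+    // Slice-style op (opt1) or, evidently, a StridedSlice with per-axis masks (opt2).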
makePattern({data, {start}, {stop}, {step}, {static_cast(axis)}}); + + std::vector vbegin(axis + 1, Symbol(0)); + std::vector vend(axis + 1, Symbol(0)); + std::vector vstride(axis + 1, Symbol(1)); + + vbegin[axis] = start; + vend[axis] = stop; + vstride[axis] = step; + + detail::PatternNode begin(vbegin); + detail::PatternNode end(vend); + detail::PatternNode stride(vstride); + + std::vector begin_mask(axis + 1, 1); + std::vector end_mask(axis + 1, 1); + std::vector new_axis_mask; + std::vector shrink_axis_mask; + std::vector ellipsis_mask; + + begin_mask[axis] = 0; + end_mask[axis] = 0; + + auto opt2 = makePattern({data, begin, end, stride}, + {{"begin_mask", begin_mask}, + {"end_mask", end_mask}, + {"new_axis_mask", new_axis_mask}, + {"shrink_axis_mask", shrink_axis_mask}, + {"ellipsis_mask", ellipsis_mask}}); + return opt1 | opt2; +} + +//================================================================================================== +class PatternValidator { +public: + PatternValidator(ov::pass::pattern::Matcher& m, bool force_verbose = false) { + auto saved_force_matcher_verbose = force_matcher_verbose; + force_matcher_verbose = force_verbose; + m_is_valid = validate(m); + force_matcher_verbose = saved_force_matcher_verbose; + } + + double& operator[](const char* symbol_name) { + return m_symbol_values[symbol_name]; + } + + operator bool() { + if (!m_is_valid) { + _VERBOSE_LOG("PatternValidator failed."); + } + return m_is_valid; + } + + bool validate(ov::pass::pattern::Matcher& m) { + detail::SymbolObservationVector sov; + + auto& pvmap = m.get_pattern_value_map(); + for (auto& pv : pvmap) { + auto pnode = pv.first; + auto value_node = pv.second.get_node_shared_ptr(); + auto& rt_info = pnode->get_rt_info(); + + if (auto pattern_node = std::dynamic_pointer_cast(pnode)) { + // pattern_node has no attribute and it has been matched in its predicate + if (rt_info.count("symbolic_const_value")) { + // symbolic constant node, a symbol reference is observed + auto& symbols = rt_info["symbolic_const_value"].as>(); + auto constop = std::dynamic_pointer_cast(value_node); + if (!constop) { + _VERBOSE_LOG("symbolic_const_value unexpected OP: ", value_node->get_friendly_name()); + return false; + } + auto ele_cnt = shape_size(constop->get_shape()); + auto ele_type = constop->get_element_type(); + + if (ele_cnt != symbols.size()) { + _VERBOSE_LOG("symbolic_const_value expect ", + symbols.size(), + " but got ", + ele_cnt, + " from ", + value_node->get_friendly_name()); + return false; + } + + if (ele_type == ov::element::i32 || ele_type == ov::element::f32 || ele_type == ov::element::i64) { + auto observed = constop->cast_vector(); + for (size_t i = 0; i < symbols.size(); i++) + detail::add_symbol_observed(sov, symbols[i], observed[i]); + } else { + _VERBOSE_LOG("Unexpect element type ", ele_type, " from ", value_node->get_friendly_name()); + return false; + } + } + continue; + } + if (auto pconst_node = std::dynamic_pointer_cast(pnode)) { + // const_node needs to match type/shape/value + auto vconst_node = std::dynamic_pointer_cast(value_node); + if (!vconst_node) { + _VERBOSE_LOG("expecting Constant op, but got ", value_node); + return false; + } + if (pconst_node->get_output_element_type(0) != vconst_node->get_output_element_type(0)) { + _VERBOSE_LOG("expecting Constant of type ", + pconst_node->get_output_element_type(0), + " but got ", + vconst_node); + return false; + } + // for constant node matched in pattern, a scalar constant is considered to + // be compatible with any shape with 1 
element, like {}, {1,1}, {1,1,...}
+                const auto& expected_shape = pconst_node->get_output_shape(0);
+                if (expected_shape.size() == 0) {
+                    if (shape_size(vconst_node->get_output_shape(0)) != 1) {
+                        _VERBOSE_LOG("expecting a single element const, but got ", vconst_node);
+                        return false;
+                    }
+                } else {
+                    if (expected_shape != vconst_node->get_output_shape(0)) {
+                        _VERBOSE_LOG("expecting Constant of shape ", expected_shape, " but got ", vconst_node);
+                        return false;
+                    }
+                }
+                auto byte_size =
+                    shape_size(vconst_node->get_output_shape(0)) * vconst_node->get_output_element_type(0).size();
+                if (std::memcmp(pconst_node->get_data_ptr(), vconst_node->get_data_ptr(), byte_size) != 0) {
+                    _VERBOSE_LOG("Constant value mismatch on ", pconst_node, " vs ", vconst_node);
+                    return false;
+                }
+                continue;
+            }
+
+            // compare attributes between pattern & graph node: fetch each attribute of
+            // the matched node through its AttributeVisitor and compare it against the
+            // pattern's "__attrs__" map.
+            if (rt_info.count("__attrs__") == 0) {
+                _VERBOSE_LOG(" attr compare failed: __attrs__ not found for ", pnode->get_friendly_name());
+                return false;
+            }
+
+            // attrs not specified in the pattern are treated as don't-care and ignored;
+            // attrs holding a Symbol are recorded as observations in sov.
+            detail::AttrMap& attr_map = rt_info["__attrs__"].as<detail::AttrMap>();
+            detail::AttrMatcher visitor(attr_map, &sov);
+            value_node->visit_attributes(visitor);
+            if (!visitor.matched()) {
+                _VERBOSE_LOG(" attr compare failed: ",
+                             pnode->get_friendly_name(),
+                             " vs ",
+                             value_node->get_friendly_name());
+                return false;
+            }
+        }
+
+        // check symbol consistency & return independent symbols:
+        // assign independent symbols & check literals
+        std::map<void*, double> symbol_value_map;
+        for (auto& ref : sov) {
+            auto& sym = ref.first;
+            auto& value = ref.second;
+
+            if (sym.is_independent_var()) {
+                auto id = sym.get_id();
+                if (symbol_value_map.count(id)) {
+                    if (symbol_value_map[id] != value) {
+                        _VERBOSE_LOG(" inconsistency between multiple references of same symbol : ",
+                                     symbol_value_map[id],
+                                     " != ",
+                                     value);
+                        return false;
+                    }
+                } else {
+                    symbol_value_map[id] = value;
+                    m_symbol_values[sym.get_name()] = value;
+                    _VERBOSE_LOG("Independent Symbol: ", sym.get_name(), " = ", value);
+                }
+            }
+
+            if (sym.is_literal_const()) {
+                auto literal = sym.eval(symbol_value_map);
+                if (literal != value) {
+                    _VERBOSE_LOG(" mismatch between literal symbol & value : ", literal, " != ", value);
+                    return false;
+                }
+                // no need to put literal into value map to eval them.
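+                // NOTE (editor): typical call-site of PatternValidator inside a matcher
+                // callback, a hedged sketch mirroring rope_fusion.cpp from this patch:
+                //   PatternValidator validator(m);
+                //   if (!validator)
+                //       return false;
+                //   config.rotary_ndims = static_cast<size_t>(validator["ndims"]);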
+ } + } + + // derive/eval dependent symbol's value and check against observed + for (auto& ref : sov) { + auto& sym = ref.first; + if (!sym.is_literal_const() && !sym.is_independent_var()) { + auto derived = sym.eval(symbol_value_map); + auto value = ref.second; + bool is_match; + + if (std::trunc(value) == value) { + // observed integer + is_match = (derived == value); + } else { + auto abs_diff = std::abs(derived - value); + auto avg = 0.5f * std::abs(derived + value); + if (avg != 0) { + is_match = abs_diff < avg * 1e-7; // relative error less than threshold + } else { + is_match = (derived == value); + } + } + if (!is_match) { + _VERBOSE_LOG(" mismatch between derived & value : ", + std::setprecision(std::numeric_limits::max_digits10), + derived, + " != ", + std::setprecision(std::numeric_limits::max_digits10), + value); + return false; + } + } + } + return true; + } + +private: + std::map m_symbol_values; + bool m_is_valid; +}; + +} // namespace gen_pattern +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp new file mode 100644 index 00000000000000..fee1fc5bdbe736 --- /dev/null +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp @@ -0,0 +1,488 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +// #include + +#include +#include +#include +#include +#include +#include +#include + +#include "gen_pattern.hpp" +#include "common_test_utils/common_utils.hpp" +#include "functional_test_utils/skip_tests_config.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" +// #include "test_utils/gpu_test_utils.hpp" +// #include "test_utils/fusing_test_utils.hpp" +// #include "utils/gen_pattern.hpp" + +// using namespace GPUTestUtils; +using namespace ov::gen_pattern; +using namespace ov; + +namespace ov { +namespace test { + +inline void CheckNumberOfNodesWithType(std::shared_ptr function, + const std::unordered_set& nodeTypes, + size_t expectedCount) { + ASSERT_NE(nullptr, function); + int num_ops = 0; + for (const auto& node : function->get_ordered_ops()) { + const auto& rt_info = node->get_rt_info(); + const auto layer_type = rt_info.find("layerType")->second.as(); + std::cout << layer_type << std::endl; + if (nodeTypes.count(layer_type)) { + num_ops++; + } + } + ASSERT_EQ(num_ops, expectedCount); +} + +static ov::OutputVector makeCosSinCache(int max_position_embeddings, int rotary_ndims) { + std::vector lut_sin(max_position_embeddings * rotary_ndims, 0.0f); + std::vector lut_cos(max_position_embeddings * rotary_ndims, 0.0f); + + // rotate_half style cos/sin table: + // y1 = cos(m*xita_i) * x1 - sin(m*xita_i) * x2 + // y2 = cos(m*xita_i) * x2 + sin(m*xita_i) * x1 + // + + for (int i = 0, k = 0; i < rotary_ndims; i += 2, k++) { + auto xita_i = 1.0 / std::pow(10000.0, static_cast(i) / rotary_ndims); + float* psin = lut_sin.data(); + float* pcos = lut_cos.data(); + for (int m = 0; m < max_position_embeddings; m++, psin += rotary_ndims, pcos += rotary_ndims) { + auto vsin = std::sin(xita_i * m); + auto vcos = std::cos(xita_i * m); + pcos[k] = pcos[k + rotary_ndims / 2] = vcos; + psin[k] = psin[k + rotary_ndims / 2] = vsin; + } + } + auto shape = ov::Shape({1, 1, static_cast(max_position_embeddings), static_cast(rotary_ndims)}); + auto Cos = makeConst(ov::element::f32, shape, lut_cos); + auto Sin = 
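+    // NOTE (editor): for position m and pair index k the loop above fills
+    //   theta_k = 10000^(-2k / rotary_ndims)
+    //   cos_tab[m][k] = cos_tab[m][k + rotary_ndims/2] = cos(m * theta_k)
+    // (likewise for sin), i.e. the rotate_half layout where the second half of the
+    // last dimension duplicates the first half.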
makeConst(ov::element::f32, shape, lut_sin); + return {Cos, Sin}; +} + +static std::shared_ptr buildROPE_Llama2(const int batch, + const int seq_length, + const int max_position_embeddings, + const int num_head, + const int ndims) { + auto input = std::make_shared(ov::element::f32, PartialShape{batch, -1, num_head, ndims}); + auto pos_id_end = std::make_shared(ov::element::i32, ov::Shape{}); + auto pos_ids = std::make_shared(ov::element::i32, PartialShape{1, -1}); + + auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims); + auto Constant582 = cos_sin_cache[0]; + auto Constant585 = cos_sin_cache[1]; + + // concat KV length + auto transpose_Transpose = makeOP({input, {0, 2, 1, 3}}); + auto slice_Unsqueeze_426 = makeOP({pos_id_end, 0}); + auto ScatterUpdate_152236 = makeOP({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}}); + auto slice_Slice = makeOP({Constant582, {0, 0, 0}, ScatterUpdate_152236, {1, 1, 1}}, + {{"begin_mask", {1, 1, 0}}, + {"end_mask", {1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto squeeze_Squeeze = makeOP({slice_Slice, 1}); + auto squeeze_Squeeze_435 = makeOP({squeeze_Squeeze, 0}); + auto index_441_Gather = makeOP({squeeze_Squeeze_435, pos_ids, 0}, {{"batch_dims", 0}}); + auto unsqueeze_Unsqueeze = makeOP({index_441_Gather, 1}); + auto mul_Multiply = + makeOP({transpose_Transpose, unsqueeze_Unsqueeze}, {{"auto_broadcast", "numpy"}}); + auto size_ShapeOf_448 = makeOP({transpose_Transpose}, {{"output_type", "i32"}}); + auto size_Gather_450 = makeOP({size_ShapeOf_448, 3, 0}, {{"batch_dims", 0}}); + auto floor_divide_Divide = + makeOP({size_Gather_450, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); + auto floor_divide_Floor = makeOP({floor_divide_Divide}); + auto slice_Unsqueeze_452 = makeOP({floor_divide_Floor, 0}); + auto ScatterUpdate_152312 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}}); + auto slice_Slice_459 = makeOP( + {transpose_Transpose, ScatterUpdate_152312, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Constant_182988 = makeConst(element::f32, + ov::Shape({ + 1, + 1, + 1, + 1, + }), + {-1.000000f}); + + auto neg_Multiply = makeOP({slice_Slice_459, Constant_182988}, {{"auto_broadcast", "numpy"}}); + auto ScatterUpdate_152368 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_452, {0}}); + auto slice_Slice2 = + makeOP({transpose_Transpose, {0, 0, 0, 0}, ScatterUpdate_152368, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat = makeOP({neg_Multiply, slice_Slice2}, {{"axis", -1}}); + auto ScatterUpdate_152421 = makeOP({{0, 0, 0}, {2}, slice_Unsqueeze_426, {0}}); + auto slice_Slice_433 = makeOP({Constant585, {0, 0, 0}, ScatterUpdate_152421, {1, 1, 1}}, + {{"begin_mask", {1, 1, 0}}, + {"end_mask", {1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto squeeze_Squeeze_436 = makeOP({slice_Slice_433, 1}); + auto squeeze_Squeeze_437 = makeOP({squeeze_Squeeze_436, 0}); + auto index_446_Gather = makeOP({squeeze_Squeeze_437, pos_ids, 0}, {{"batch_dims", 0}}); + auto unsqueeze_Unsqueeze_447 = makeOP({index_446_Gather, 1}); + auto mul_Multiply_463 = + makeOP({cat_Concat, unsqueeze_Unsqueeze_447}, {{"auto_broadcast", "numpy"}}); + auto add_Add = makeOP({mul_Multiply, mul_Multiply_463}, 
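+    // NOTE (editor): the subgraph above is the rotate_half formulation
+    //   y = x * cos(m*theta) + rotate_half(x) * sin(m*theta),
+    // where rotate_half(x) = concat(-x[..., d/2:], x[..., :d/2]) is built by
+    // slice_Slice_459 / neg_Multiply / slice_Slice2 / cat_Concat; this is exactly the
+    // kind of subgraph the RoPE fusion pass is expected to collapse into one op.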
{{"auto_broadcast", "numpy"}}); + + return std::make_shared(ov::NodeVector{add_Add}, ov::ParameterVector{input, pos_id_end, pos_ids}); +} + +class RoPEGPUTestLlama2 : public SubgraphBaseTest { +public: + ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1) { + auto tensor = ov::Tensor(ov::element::i32, shape); + auto* ptr = static_cast(tensor.data()); + for (size_t i = 0; i < tensor.get_size(); i++) { + ptr[i] = start; + start += step; + } + return tensor; + } + + void generate_inputs(const std::vector& targetInputStaticShapes) override { + const auto& funcInputs = function->inputs(); + + const int position_id_start = 15; + auto& input_shape = targetInputStaticShapes[0]; + auto seq_length = input_shape[1]; + + ov::test::utils::InputGenerateData in_data; + in_data.start_from = -1; + in_data.range = 2; + in_data.resolution = 32768; + ov::Tensor t_input = utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, in_data); + ov::Tensor t_position_id_end = create_i32_tensor(ov::Shape({}), position_id_start + seq_length); + ov::Tensor t_position_ids = create_i32_tensor(ov::Shape({1, seq_length}), position_id_start); + + inputs.clear(); + inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input}); + inputs.insert({funcInputs[1].get_node_shared_ptr(), t_position_id_end}); + inputs.insert({funcInputs[2].get_node_shared_ptr(), t_position_ids}); + } + +protected: + void SetUp() override { + targetDevice = ov::test::utils::DEVICE_GPU; + + const int batch = 2; + const int seq_length = 7; + const size_t max_position_embeddings = 2048; + const size_t ndims = 128; + const size_t num_head = 32; + + InputShape inpShape = {{batch, seq_length, num_head, ndims}, {{batch, seq_length, num_head, ndims}}}; + init_input_shapes({inpShape}); + function = buildROPE_Llama2(batch, seq_length, max_position_embeddings, num_head, ndims); + } +}; + +TEST_F(RoPEGPUTestLlama2, smoke_CompareWithRefs) { + run(); + std::shared_ptr function = compiledModel.get_runtime_model(); + CheckNumberOfNodesWithType(function, {"RoPE"}, 1); +} + +class RoPEGPUTestChatGLM : public SubgraphBaseTest { +public: + ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1) { + auto tensor = ov::Tensor(ov::element::i32, shape); + auto* ptr = static_cast(tensor.data()); + for (size_t i = 0; i < tensor.get_size(); i++) { + ptr[i] = start; + start += step; + } + return tensor; + } + + void generate_inputs(const std::vector& targetInputStaticShapes) override { + const auto& funcInputs = function->inputs(); + + auto& input_shape = targetInputStaticShapes[0]; + auto seq_length = input_shape[0]; + // auto batch = input_shape[1]; + + ov::Tensor t_input = + utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768); + ov::Tensor t_cos_sin_cache = + utils::create_and_fill_tensor(funcInputs[1].get_element_type(), {32768, 32, 2}, 2, -1.0f, 32768); + ov::Tensor t_position_ids = create_i32_tensor(ov::Shape({1, seq_length}), 15); + + inputs.clear(); + inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input}); + inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_sin_cache}); + inputs.insert({funcInputs[2].get_node_shared_ptr(), t_position_ids}); + } + +protected: + std::shared_ptr buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims) { + auto input = + std::make_shared(ov::element::f32, PartialShape{-1, batch, 4096 + 256 + 256}); + auto cos_sin_cache = std::make_shared(ov::element::f32, PartialShape{32768, 32, 2}); + auto position_ids = 
std::make_shared(ov::element::i32, PartialShape{-1, -1}); + + auto __module_transformer_index_67_Gather = + makeOP({cos_sin_cache, position_ids, 0}, {{"batch_dims", 0}}); + auto __module_transformer_transpose_Transpose = + makeOP({__module_transformer_index_67_Gather, {1, 0, 2, 3}}); + auto size_ShapeOf_110 = + makeOP({__module_transformer_transpose_Transpose}, {{"output_type", "i32"}}); + auto __getitem___Gather = makeOP({size_ShapeOf_110, -2, 0}, {{"batch_dims", 0}}); + auto mul_Multiply = makeOP({__getitem___Gather, 2}, {{"auto_broadcast", "numpy"}}); + auto slice_Unsqueeze_112 = makeOP({mul_Multiply, 0}); + + auto floordiv_Divide = + makeOP({mul_Multiply, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); + auto floordiv_Floor = makeOP({floordiv_Divide}); + auto ListConstruct_126_Reshape_2 = makeOP({floordiv_Floor, {-1}}, {{"special_zero", false}}); + + auto ListUnpack_321 = makeOP({input, -1, {4096, 256, 256}}); + auto view_Reshape = + makeOP({ListUnpack_321->output(0), {0, 0, 32, 128}}, {{"special_zero", true}}); + + auto ScatterUpdate_229053 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_112, {0}}); + auto slice_Slice_357 = + makeOP({view_Reshape, {0, 0, 0, 0}, ScatterUpdate_229053, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto size_ShapeOf_346 = makeOP({view_Reshape}, {{"output_type", "i32"}}); + auto size_Gather_348 = makeOP({size_ShapeOf_346, 0, 0}, {{"batch_dims", 0}}); + auto ListConstruct_372_Reshape = makeOP({size_Gather_348, {-1}}, {{"special_zero", false}}); + auto size_Gather_351 = makeOP({size_ShapeOf_346, {2}, 0}, {{"batch_dims", 0}}); + auto ListConstruct_372_Concat = + makeOP({ListConstruct_372_Reshape, {-1}, size_Gather_351, ListConstruct_126_Reshape_2, {2}}, + {{"axis", 0}}); + auto reshape_Reshape_373 = + makeOP({slice_Slice_357, ListConstruct_372_Concat}, {{"special_zero", false}}); + auto select_Gather_381 = makeOP({reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}}); + auto slice_Unsqueeze_367 = makeOP({size_Gather_348, 0}); + auto slice_Slice_369 = + makeOP({__module_transformer_transpose_Transpose, {0}, slice_Unsqueeze_367, {1}}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto size_ShapeOf_374 = makeOP({reshape_Reshape_373}, {{"output_type", "i32"}}); + auto size_Gather_376 = makeOP({size_ShapeOf_374, {3}, 0}, {{"batch_dims", 0}}); + auto ListConstruct_379_Concat = + makeOP({ListConstruct_372_Reshape, {-1}, {1}, size_Gather_376, {2}}, {{"axis", 0}}); + auto view_Reshape_380 = + makeOP({slice_Slice_369, ListConstruct_379_Concat}, {{"special_zero", false}}); + auto select_Gather_382 = makeOP({view_Reshape_380, 0, -1}, {{"batch_dims", 0}}); + auto mul_Multiply_383 = + makeOP({select_Gather_381, select_Gather_382}, {{"auto_broadcast", "numpy"}}); + auto select_Gather_384 = makeOP({reshape_Reshape_373, 1, -1}, {{"batch_dims", 0}}); + auto select_Gather_385 = makeOP({view_Reshape_380, 1, -1}, {{"batch_dims", 0}}); + auto mul_Multiply_386 = + makeOP({select_Gather_384, select_Gather_385}, {{"auto_broadcast", "numpy"}}); + auto sub_Subtract_389 = + makeOP({mul_Multiply_383, mul_Multiply_386}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_62716 = makeOP({sub_Subtract_389, -1}); + auto mul_Multiply_391 = + makeOP({select_Gather_384, select_Gather_382}, {{"auto_broadcast", "numpy"}}); + auto mul_Multiply_393 = + makeOP({select_Gather_381, select_Gather_385}, 
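+        // NOTE (editor): ChatGLM rotates interleaved pairs rather than halves:
+        // reshape_Reshape_373 exposes the rotary part as [..., d/2, 2], the Gathers
+        // split x and the cos/sin cache into even/odd lanes, and
+        //   y_even = x_even*cos - x_odd*sin,  y_odd = x_odd*cos + x_even*sin
+        // are recombined by the stack/flatten below.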
{{"auto_broadcast", "numpy"}}); + auto add_Add_396 = makeOP({mul_Multiply_391, mul_Multiply_393}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_62717 = makeOP({add_Add_396, -1}); + auto stack_401 = makeOP({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}}); + auto flatten_ShapeOf_402 = makeOP({stack_401}, {{"output_type", "i32"}}); + auto flatten_Slice_417 = makeOP({flatten_ShapeOf_402, {0}, {3}, {1}}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto flatten_Concat_420 = makeOP({flatten_Slice_417, {-1}}, {{"axis", 0}}); + auto flatten_Reshape_421 = makeOP({stack_401, flatten_Concat_420}, {{"special_zero", true}}); + auto ScatterUpdate_229067 = makeOP({{0, 0, 0, 0}, {3}, slice_Unsqueeze_112, {0}}); + auto slice_Slice_363 = + makeOP({view_Reshape, ScatterUpdate_229067, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat_425 = makeOP({flatten_Reshape_421, slice_Slice_363}, {{"axis", -1}}); + return std::make_shared(ov::NodeVector{cat_Concat_425}, + ov::ParameterVector{input, cos_sin_cache, position_ids}); + } + void SetUp() override { + targetDevice = ov::test::utils::DEVICE_GPU; + + const int batch = 2; + const int seq_length = 7; + const int num_head = 32; + const int rotary_dims = 64; + + InputShape inpShape = {{-1, batch, 4096 + 256 + 256}, {{seq_length, batch, 4096 + 256 + 256}}}; + init_input_shapes({inpShape}); + function = buildROPE_ChatGLM(batch, num_head, rotary_dims); + } +}; + +TEST_F(RoPEGPUTestChatGLM, smoke_CompareWithRefs) { + run(); + std::shared_ptr function = compiledModel.get_runtime_model(); + CheckNumberOfNodesWithType(function, {"RoPE"}, 1); +} + +class RoPEGPUTestQwen7b : public SubgraphBaseTest { +public: + void generate_inputs(const std::vector& targetInputStaticShapes) override { + const auto& funcInputs = function->inputs(); + + auto& input_shape = targetInputStaticShapes[0]; + + ov::Tensor t_input = + utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768); + ov::Tensor t_cos_cache = + utils::create_and_fill_tensor(funcInputs[1].get_element_type(), {1, 4096, 1, 128}, 2, -1.0f, 32768); + ov::Tensor t_sin_cache = + utils::create_and_fill_tensor(funcInputs[1].get_element_type(), {1, 4096, 1, 128}, 2, -1.0f, 32768); + + inputs.clear(); + inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input}); + inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_cache}); + inputs.insert({funcInputs[2].get_node_shared_ptr(), t_sin_cache}); + } + +protected: + std::shared_ptr buildROPE_QWen7b() { + auto input = + std::make_shared(ov::element::f32, PartialShape{-1, -1, 4096 + 4096 + 4096}); + auto cos_cache = std::make_shared(ov::element::f32, PartialShape{1, -1, 1, 128}); + auto sin_cache = std::make_shared(ov::element::f32, PartialShape{1, -1, 1, 128}); + + auto ListUnpack_389_VariadicSplit = makeOP({input, 2, {4096, 4096, -1}}); + auto view_Reshape = makeOP({ListUnpack_389_VariadicSplit->output(0), {0, 0, 32, 128}}, + {{"special_zero", true}}); + auto size_ShapeOf_414 = makeOP({view_Reshape}, {{"output_type", "i32"}}); + auto size_Gather_416 = makeOP({size_ShapeOf_414, 1, 0}, {{"batch_dims", 0}}); + auto neg_Multiply = makeOP({size_Gather_416, -1}, {{"auto_broadcast", "numpy"}}); + auto slice_Unsqueeze_422 = makeOP({neg_Multiply, 0}); + auto ScatterUpdate_261437 = makeOP({{0, 0}, {1}, slice_Unsqueeze_422, 
{0}}); + auto slice_Slice_425 = makeOP({cos_cache, ScatterUpdate_261437, {0ll, LLONG_MAX}, {1, 1}}, + {{"begin_mask", {1, 0}}, + {"end_mask", {1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto slice_Slice_431 = + makeOP({slice_Slice_425, {0, 0, 0}, {0ll, 0ll, LLONG_MAX}, {1, 1, 1}}, + {{"begin_mask", {1, 1, 0}}, + {"end_mask", {1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto slice_Slice_437 = + makeOP({slice_Slice_431, {0, 0, 0, 0}, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto size_ShapeOf_462 = makeOP({slice_Slice_437}, {{"output_type", "i32"}}); + auto size_Gather_464 = makeOP({size_ShapeOf_462, {3}, 0}, {{"batch_dims", 0}}); + auto ScatterUpdate_261533 = makeOP({{0, 0, 0, 0}, {3}, size_Gather_464, {0}}); + auto slice_Slice_470 = + makeOP({view_Reshape, {0, 0, 0, 0}, ScatterUpdate_261533, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto mul_Multiply = makeOP({slice_Slice_470, slice_Slice_437}, {{"auto_broadcast", "numpy"}}); + auto size_ShapeOf_478 = makeOP({slice_Slice_470}, {{"output_type", "i32"}}); + auto Gather_239390 = makeOP({size_ShapeOf_478, {0, 1, 2}, 0}, {{"batch_dims", 0}}); + auto size_Gather_489 = makeOP({size_ShapeOf_478, 3, 0}, {{"batch_dims", 0}}); + auto floor_divide_Divide = + makeOP({size_Gather_489, 2}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); + auto floor_divide_Floor = makeOP({floor_divide_Divide}); + auto ListConstruct_493_Reshape_3 = + makeOP({floor_divide_Floor, {-1}}, {{"special_zero", false}}); + auto ListConstruct_493_Concat = + makeOP({Gather_239390, {2}, ListConstruct_493_Reshape_3}, {{"axis", 0}}); + auto reshape_Reshape = + makeOP({slice_Slice_470, ListConstruct_493_Concat}, {{"special_zero", false}}); + auto ListUnpack_496_Split = makeOP({reshape_Reshape, -2}, {{"num_splits", 2}}); + auto ListUnpack_496_Squeeze_0 = makeOP({ListUnpack_496_Split->output(1), -2}); + auto Constant_296840_compressed = makeConst(element::f16, + ov::Shape({ + 1, + 1, + 1, + 1, + }), + {-1}); + auto Constant_296840 = makeOP({Constant_296840_compressed}, {{"destination_type", "f32"}}); + auto neg_Multiply_499 = + makeOP({ListUnpack_496_Squeeze_0, Constant_296840}, {{"auto_broadcast", "numpy"}}); + auto ListUnpack_496_Squeeze = makeOP({ListUnpack_496_Split->output(0), -2}); + auto cat_Concat = makeOP({neg_Multiply_499, ListUnpack_496_Squeeze}, {{"axis", -1}}); + auto slice_Slice_449 = makeOP({sin_cache, ScatterUpdate_261437, {0ll, LLONG_MAX}, {1, 1}}, + {{"begin_mask", {1, 0}}, + {"end_mask", {1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto slice_Slice_455 = + makeOP({slice_Slice_449, {0, 0, 0}, {0ll, 0ll, LLONG_MAX}, {1, 1, 1}}, + {{"begin_mask", {1, 1, 0}}, + {"end_mask", {1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto slice_Slice_461 = + makeOP({slice_Slice_455, {0, 0, 0, 0}, {0ll, 0ll, 0ll, LLONG_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto mul_Multiply_503 = makeOP({cat_Concat, slice_Slice_461}, {{"auto_broadcast", "numpy"}}); + auto add_Add = makeOP({mul_Multiply, mul_Multiply_503}, 
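+        // NOTE (editor): the QWen7b variant is rotate_half again, but (a) q is taken
+        // from the merged q/k/v tensor via VariadicSplit, (b) the last seq_length rows
+        // of the externally supplied cos/sin caches are selected (neg_Multiply computes
+        // the -seq_length offset), and (c) rotate_half is built via reshape to
+        // [..., 2, d/2] plus Split instead of StridedSlice.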
{{"auto_broadcast", "numpy"}}); + return std::make_shared(ov::NodeVector{add_Add}, ov::ParameterVector{input, cos_cache, sin_cache}); + } + void SetUp() override { + targetDevice = ov::test::utils::DEVICE_GPU; + const int batch = 2; + const int seq_length = 7; + InputShape inpShape = {{batch, -1, 4096 + 4096 + 4096}, {{batch, seq_length, 4096 + 4096 + 4096}}}; + init_input_shapes({inpShape}); + function = buildROPE_QWen7b(); + } +}; + +TEST_F(RoPEGPUTestQwen7b, smoke_CompareWithRefs) { + run(); + std::shared_ptr function = compiledModel.get_runtime_model(); + CheckNumberOfNodesWithType(function, {"RoPE"}, 1); +} + +} // namespace test +} // namespace ov diff --git a/src/plugins/intel_gpu/tests/unit/transformations/convert_to_rope.cpp b/src/plugins/intel_gpu/tests/unit/transformations/convert_to_rope.cpp new file mode 100644 index 00000000000000..924b129def5638 --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/transformations/convert_to_rope.cpp @@ -0,0 +1,564 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include "gen_pattern.hpp" + +#include "intel_gpu/op/rope.hpp" + +#include "common_test_utils/ov_test_utils.hpp" +#include "gen_pattern.hpp" +// #include "utils/print_model.hpp" + +using namespace testing; +using namespace ov::intel_gpu; +using namespace ov::gen_pattern; + +static ov::OutputVector makeCosSinCache(size_t max_position_embeddings, size_t rotary_ndims) { + std::vector lut_sin(max_position_embeddings * rotary_ndims, 0.0f); + std::vector lut_cos(max_position_embeddings * rotary_ndims, 0.0f); + + // rotate_half style cos/sin table: + // y1 = cos(m*xita_i) * x1 - sin(m*xita_i) * x2 + // y2 = cos(m*xita_i) * x2 + sin(m*xita_i) * x1 + // + for (size_t i = 0, k = 0; i < rotary_ndims; i += 2, k++) { + auto xita_i = 1.0 / std::pow(10000.0, static_cast(i) / rotary_ndims); + float* psin = lut_sin.data(); + float* pcos = lut_cos.data(); + for (size_t m = 0; m < max_position_embeddings; m++, psin += rotary_ndims, pcos += rotary_ndims) { + auto vsin = std::sin(xita_i * m); + auto vcos = std::cos(xita_i * m); + pcos[k] = pcos[k + rotary_ndims / 2] = vcos; + psin[k] = psin[k + rotary_ndims / 2] = vsin; + } + } + auto Cos = std::make_shared(ov::element::f32, ov::Shape({1, 1, max_position_embeddings, rotary_ndims}), lut_cos); + auto Sin = std::make_shared(ov::element::f32, ov::Shape({1, 1, max_position_embeddings, rotary_ndims}), lut_sin); + + return {Cos, Sin}; +} + +static std::shared_ptr buildROPE_Llama2(const size_t batch, + const size_t seq_length, + const size_t max_position_embeddings, + const size_t ndims, + bool sin_cos_preprocessing) { + auto input = std::make_shared(ov::element::f32, ov::Shape{batch, seq_length, 32, ndims}); + auto param_cos = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); + auto param_sin = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); + + auto seq_len = std::make_shared(ov::element::i32, ov::Shape{1}); + auto gather_id = std::make_shared(ov::element::i32, ov::Shape{1, seq_length}); + + auto gather_from_sin_cos = [&](const ov::Output& const_tab) { + auto ScatterUpdate_152236 = makeOP({{0, 0, 0}, {2}, seq_len, {0}}); + auto slice_Slice = makeOP({const_tab, {0, 0, 0}, ScatterUpdate_152236, {1, 1, 1}}, + {{"begin_mask", {1, 1, 0}}, + {"end_mask", {1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto squeeze_Squeeze_435 = + 
makeOP({slice_Slice, {-1, static_cast(ndims)}}, {{"special_zero", false}}); + auto index_441_Gather = makeOP({squeeze_Squeeze_435, gather_id, {0}}, {{"batch_dims", 0}}); + return makeOP({index_441_Gather, {1, 1, -1, static_cast(ndims)}}, + {{"special_zero", false}}); + }; + + ov::OutputVector cos_sin(2); + ov::ParameterVector parameters; + if (sin_cos_preprocessing) { + auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims); + cos_sin[0] = gather_from_sin_cos(cos_sin_cache[0]); + cos_sin[1] = gather_from_sin_cos(cos_sin_cache[1]); + parameters = ov::ParameterVector{input, seq_len, gather_id}; + } else { + cos_sin[0] = param_cos; + cos_sin[1] = param_sin; + parameters = ov::ParameterVector{input, param_cos, param_sin}; + } + + auto transpose_Transpose = makeOP({input, {0, 2, 1, 3}}); + auto mul_Multiply = makeOP({transpose_Transpose, cos_sin[0]}, {{"auto_broadcast", "numpy"}}); + auto slice_Slice_459 = + makeOP({transpose_Transpose, {0, 0, 0, 64}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Constant_182988 = makeConst(ov::element::f32, ov::Shape({1, 1, 1, 1}), std::vector({-1.000000f})); + auto neg_Multiply = makeOP({slice_Slice_459, Constant_182988}, {{"auto_broadcast", "numpy"}}); + auto slice_Slice = + makeOP({transpose_Transpose, {0, 0, 0, 0}, {0, 0, 0, 64}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat = makeOP({neg_Multiply, slice_Slice}, {{"axis", -1}}); + auto mul_Multiply_463 = makeOP({cat_Concat, cos_sin[1]}, {{"auto_broadcast", "numpy"}}); + auto add_Add = makeOP({mul_Multiply, mul_Multiply_463}, {{"auto_broadcast", "numpy"}}); + + return std::make_shared(ov::NodeVector{add_Add}, parameters); +} + +TEST_F(TransformationTestsF, ConvertToROPE_LLama2_no_gather) { + disable_rt_info_check(); + const int batch = 2; + const int seq_length = 16; + const size_t max_position_embeddings = 2048; + const size_t ndims = 128; + const size_t num_head = 32; + + model = buildROPE_Llama2(batch, seq_length, max_position_embeddings, ndims, false); + manager.register_pass(); + + { + auto hidden_states = + std::make_shared(ov::element::f32, ov::Shape{batch, seq_length, num_head, ndims}); + auto param_cos = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); + auto param_sin = std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length, ndims}); + auto add_Add = makeOP({hidden_states, param_cos, param_sin}, + {{"config.slice_start", 0}, + {"config.slice_stop", 0}, + {"config.input_trans0213", true}, + {"config.is_interleaved", false}, + {"config.is_chatglm", false}, + {"config.is_qwen", false}, + {"config.head_cnt", 0}, + {"config.head_size", 0}, + {"config.rotary_ndims", static_cast(ndims)}, + {"config.gather_position_arg_id", 0}}); + + model_ref = std::make_shared(ov::NodeVector{add_Add}, + ov::ParameterVector{hidden_states, param_cos, param_sin}); + } +} + +TEST_F(TransformationTestsF, ConvertToROPE_LLama2_with_gather) { + disable_rt_info_check(); + const int batch = 2; + const int seq_length = 16; + const size_t max_position_embeddings = 2048; + const size_t ndims = 128; + const size_t num_head = 32; + + model = buildROPE_Llama2(batch, seq_length, max_position_embeddings, ndims, true); + manager.register_pass(); + + { + auto hidden_states = + std::make_shared(ov::element::f32, ov::Shape{batch, 
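+            // NOTE (editor): the "config.*" keys used with makeOP on the RoPE op in
+            // these reference models map onto op::RoPE::Config from rope.hpp at the top
+            // of this patch: slice_start/slice_stop select the q|k|v channels,
+            // input_trans0213 requests the [B, L, H, S] -> [B, H, L, S] transpose,
+            // rotary_ndims is the embedded width d, and gather_position_arg_id == 3
+            // marks input 4 as position indices for gathering from the cos/sin tables.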
seq_length, num_head, ndims}); + auto seq_len = std::make_shared(ov::element::i32, ov::Shape{1}); + auto gather_id = std::make_shared(ov::element::i32, ov::Shape{1, seq_length}); + auto cos_sin_cache = makeCosSinCache(max_position_embeddings, ndims); + + auto add_Add = makeOP({hidden_states, cos_sin_cache[0], cos_sin_cache[1], gather_id}, + {{"config.slice_start", 0}, + {"config.slice_stop", 0}, + {"config.input_trans0213", true}, + {"config.is_interleaved", false}, + {"config.is_chatglm", false}, + {"config.is_qwen", false}, + {"config.head_cnt", 0}, + {"config.head_size", 0}, + {"config.rotary_ndims", static_cast(ndims)}, + {"config.gather_position_arg_id", 3}}); + + model_ref = std::make_shared(ov::NodeVector{add_Add}, + ov::ParameterVector{hidden_states, seq_len, gather_id}); + } +} + +static std::shared_ptr buildROPE_GPTNEOX(const int batch, + const int seq_length, + const int max_position_embeddings, + const int ndims, + const int num_heads, + const int rotary_ndims, + bool sin_cos_preprocessing) { + auto batch_s = static_cast(batch); + auto seq_length_s = static_cast(seq_length); + auto ndims_s = static_cast(ndims); + auto rotary_ndims_s = static_cast(rotary_ndims); + auto num_heads_s = static_cast(num_heads); + + auto input = std::make_shared(ov::element::f32, + ov::Shape{batch_s, seq_length_s, num_heads_s, ndims_s * 3}); + auto seq_len = std::make_shared(ov::element::i32, ov::Shape{1}); + auto gather_idx = + std::make_shared(ov::element::i32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s}); + auto batch_limit = std::make_shared(ov::element::i32, ov::Shape{1}); + + ov::ParameterVector parameters; + ov::OutputVector cos_sin(2); + if (sin_cos_preprocessing) { + auto cos_sin_lut = makeCosSinCache(max_position_embeddings, rotary_ndims); + auto ro_slice_Slice = makeOP({cos_sin_lut[0], {0}, batch_limit, {1}}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + cos_sin[0] = makeOP({ro_slice_Slice, gather_idx}, {{"axis", 2}}); + + auto ro_slice_Slice_385 = makeOP({cos_sin_lut[1], {0}, batch_limit, {1}}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + cos_sin[1] = makeOP({ro_slice_Slice_385, gather_idx}, {{"axis", 2}}); + parameters = ov::ParameterVector{input, gather_idx, batch_limit}; + } else { + auto param_cos = + std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s}); + auto param_sin = + std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_length_s, rotary_ndims_s}); + parameters = ov::ParameterVector{input, param_cos, param_sin}; + cos_sin[0] = param_cos; + cos_sin[1] = param_sin; + } + + auto slice_Slice = makeOP({input, {0, 0, 0, 0}, {0, 0, 0, ndims}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto permute_Transpose = makeOP({slice_Slice, {0, 2, 1, 3}}); + auto slice_Slice_351 = + makeOP({permute_Transpose, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto mul_Multiply = makeOP({slice_Slice_351, cos_sin[0]}, {{"auto_broadcast", "numpy"}}); + auto slice_Slice_420 = makeOP( + {slice_Slice_351, {0, 0, 0, rotary_ndims / 2}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", 
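+        // NOTE (editor): GPT-NeoX applies RoPE only to the first rotary_ndims channels
+        // (20 of 80 in this test); slice_Slice_357 / cat_Concat_458 below pass the rest
+        // of the head through untouched, and slice_Slice selects the q part of the
+        // merged q/k/v input, which the fused op records as config.slice_stop = ndims.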
{}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + + auto Constant_396096 = makeConst(ov::element::f32, ov::Shape({1, 1, 1, 1}), std::vector({-1.000000f})); + auto neg_Multiply = makeOP({slice_Slice_420, Constant_396096}, {{"auto_broadcast", "numpy"}}); + auto slice_Slice_414 = + makeOP({slice_Slice_351, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims / 2}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat = makeOP({neg_Multiply, slice_Slice_414}, {{"axis", -1}}); + auto mul_Multiply_424 = makeOP({cat_Concat, cos_sin[1]}, {{"auto_broadcast", "numpy"}}); + auto add_Add = makeOP({mul_Multiply, mul_Multiply_424}, {{"auto_broadcast", "numpy"}}); + auto slice_Slice_357 = + makeOP({permute_Transpose, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto cat_Concat_458 = makeOP({add_Add, slice_Slice_357}, {{"axis", -1}}); + + return std::make_shared(ov::NodeVector{cat_Concat_458}, parameters); +} + +TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_no_gather) { + disable_rt_info_check(); + const int batch = 2; + const int seq_len = 16; + const int ndims = 80; + const int num_heads = 32; + const int rotary_ndims = 20; + const int max_position_embeddings = 2048; + + model = buildROPE_GPTNEOX(batch, seq_len, max_position_embeddings, ndims, num_heads, rotary_ndims, false); + manager.register_pass(); + { + auto input = + std::make_shared(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims * 3}); + auto param_cos = + std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_len, rotary_ndims}); + auto param_sin = + std::make_shared(ov::element::f32, ov::Shape{1, 1, seq_len, rotary_ndims}); + auto rope = makeOP({input, param_cos, param_sin}, + {{"config.slice_start", 0}, + {"config.slice_stop", ndims}, + {"config.input_trans0213", true}, + {"config.is_interleaved", false}, + {"config.is_chatglm", false}, + {"config.is_qwen", false}, + {"config.head_cnt", 0}, + {"config.head_size", 0}, + {"config.rotary_ndims", rotary_ndims}, + {"config.gather_position_arg_id", 0}}); + model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, param_cos, param_sin}); + } +} + +TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_with_gather) { + disable_rt_info_check(); + const int batch = 2; + const int seq_len = 16; + const int ndims = 80; + const int rotary_ndims = 20; + const int num_heads = 32; + const int max_position_embeddings = 2048; + + model = buildROPE_GPTNEOX(batch, seq_len, max_position_embeddings, ndims, num_heads, rotary_ndims, true); + manager.register_pass(); + { + auto cos_sin = makeCosSinCache(max_position_embeddings, rotary_ndims); + auto input = + std::make_shared(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims * 3}); + auto gather_idx = + std::make_shared(ov::element::i32, ov::Shape{1, 1, seq_len, rotary_ndims}); + auto batch_limit = std::make_shared(ov::element::i32, ov::Shape{1}); + + auto rope = makeOP({input, cos_sin[0], cos_sin[1], gather_idx}, + {{"config.slice_start", 0}, + {"config.slice_stop", ndims}, + {"config.input_trans0213", true}, + {"config.is_interleaved", false}, + {"config.is_chatglm", false}, + {"config.is_qwen", false}, + {"config.head_cnt", 0}, + {"config.head_size", 0}, + {"config.rotary_ndims", rotary_ndims}, + {"config.gather_position_arg_id", 3}}); + model_ref 
+            std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, gather_idx, batch_limit});
+    }
+}
+
+TEST_F(TransformationTestsF, ConvertToROPE_GPTJ) {
+    disable_rt_info_check();
+    const int batch = 2;
+    const int seq_len = 7;
+    const int num_heads = 16;
+    const int ndims = 256;
+    const int rotary_ndims = 64;
+    {
+        // index vector {0, 0, 1, 1, 2, 2, ...} used to repeat-interleave the sin/cos tables
+        std::vector<int32_t> rpi_idx(rotary_ndims);
+        for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) {
+            rpi_idx[i] = index;
+            rpi_idx[i + 1] = index;
+        }
+        auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx);
+
+        auto input =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims});
+        auto gather_sin_cos =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, seq_len, rotary_ndims});
+
+        auto split = makeOP<ov::opset1::VariadicSplit>({gather_sin_cos, {-1}, {rotary_ndims / 2, -1}});
+        auto sin_tab =
+            makeOP<ov::opset1::Reshape>({split->output(0), {1, -1, 1, rotary_ndims / 2}}, {{"special_zero", false}});
+        auto cos_tab =
+            makeOP<ov::opset1::Reshape>({split->output(1), {1, -1, 1, rotary_ndims / 2}}, {{"special_zero", false}});
+
+        auto slice_Slice_576 =
+            makeOP<ov::opset1::StridedSlice>({input, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        auto repeat_interleave_Cos =
+            makeOP<ov::opset8::Gather>({cos_tab, repeat_interleave_index, {3}}, {{"batch_dims", 0}});
+        auto mul_Multiply_757 =
+            makeOP<ov::opset1::Multiply>({slice_Slice_576, repeat_interleave_Cos}, {{"auto_broadcast", "numpy"}});
+
+        // odd channels (offset 1, stride 2) are negated, then re-interleaved with the even ones
+        auto slice_Slice_787 =
+            makeOP<ov::opset1::StridedSlice>({slice_Slice_576, {0, 0, 0, 1}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        auto Constant_191672 = makeConst(ov::element::f32, ov::Shape({1, 1, 1, 1}), std::vector<float>({-1.000000f}));
+        auto neg_Multiply_790 =
+            makeOP<ov::opset1::Multiply>({slice_Slice_787, Constant_191672}, {{"auto_broadcast", "numpy"}});
+        auto Unsqueeze_61918 = makeOP<ov::opset1::Unsqueeze>({neg_Multiply_790, {-1}});
+        auto slice_Slice_781 =
+            makeOP<ov::opset1::StridedSlice>({slice_Slice_576, {0, 0, 0, 0}, {0, 0, 0, INT_MAX}, {1, 1, 1, 2}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        auto Unsqueeze_61919 = makeOP<ov::opset1::Unsqueeze>({slice_Slice_781, {-1}});
+        auto stack_795 = makeOP<ov::opset1::Concat>({Unsqueeze_61918, Unsqueeze_61919}, {{"axis", -1}});
+        auto ShapeOf_165368 = makeOP<ov::op::TypeRelaxed<ov::opset1::ShapeOf>>(
+            {stack_795},
+            {{"type_relax", true}, {"input_data_types", {}}, {"output_data_types", {ov::element::i32}}});
+        auto flatten_Slice_811 = makeOP<ov::opset1::StridedSlice>({ShapeOf_165368, {0}, {3}, {1}},
+                                                                  {{"begin_mask", {0}},
+                                                                   {"end_mask", {0}},
+                                                                   {"new_axis_mask", {}},
+                                                                   {"shrink_axis_mask", {}},
+                                                                   {"ellipsis_mask", {}}});
+        auto flatten_Concat_814 = makeOP<ov::opset1::Concat>({flatten_Slice_811, {-1}}, {{"axis", 0}});
+        auto flatten_Reshape_815 =
+            makeOP<ov::opset1::Reshape>({stack_795, flatten_Concat_814}, {{"special_zero", true}});
+        auto repeat_interleave_Sin =
+            makeOP<ov::opset8::Gather>({sin_tab, repeat_interleave_index, {3}}, {{"batch_dims", 0}});
+        auto mul_Multiply_816 =
+            makeOP<ov::opset1::Multiply>({flatten_Reshape_815, repeat_interleave_Sin}, {{"auto_broadcast", "numpy"}});
+        auto add_Add_819 =
+            makeOP<ov::opset1::Add>({mul_Multiply_757, mul_Multiply_816}, {{"auto_broadcast", "numpy"}});
+        auto slice_Slice_582 =
+            makeOP<ov::opset1::StridedSlice>({input, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        auto cat_Concat_826 = makeOP<ov::opset1::Concat>({add_Add_819, slice_Slice_582}, {{"axis", -1}});
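+        // The trailing {0, 2, 1, 3} Transpose is part of the GPT-J pattern; the fused
+        // reference op below absorbs it (config.is_interleaved = true).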
+        auto permute_Transpose_828 = makeOP<ov::opset1::Transpose>({cat_Concat_826, {0, 2, 1, 3}});
+        model = std::make_shared<ov::Model>(ov::NodeVector{permute_Transpose_828},
+                                            ov::ParameterVector{input, gather_sin_cos});
+    }
+    manager.register_pass<ov::intel_gpu::RoPEFusion>();
+    {
+        auto input =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{batch, seq_len, num_heads, ndims});
+        auto cos_sin =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, seq_len, rotary_ndims});
+        auto rope = makeOP<ov::intel_gpu::op::RoPE>({input, cos_sin, cos_sin},
+                                                    {{"config.slice_start", 0},
+                                                     {"config.slice_stop", 0},
+                                                     {"config.input_trans0213", false},
+                                                     {"config.is_interleaved", true},
+                                                     {"config.is_chatglm", false},
+                                                     {"config.is_qwen", false},
+                                                     {"config.head_cnt", 0},
+                                                     {"config.head_size", 0},
+                                                     {"config.rotary_ndims", rotary_ndims},
+                                                     {"config.gather_position_arg_id", 0}});
+        model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, cos_sin});
+    }
+}
+
+TEST_F(TransformationTestsF, ConvertToROPE_chatGLM) {
+    disable_rt_info_check();
+    const int batch = 2;
+    const int seq_len = 7;
+    const int num_heads = 32;
+    const int ndims = 128;
+    const int rotary_ndims = 64;
+    const int max_pos_length = 2048;
+    {
+        auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{seq_len, batch, 4608});
+        auto seq_length = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1});
+        auto cos_sin_cache =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32,
+                                                    ov::Shape{max_pos_length, batch, rotary_ndims / 2, 2});
+
+        // merged q/k/v of width 4608 = 4096 (q: num_heads * ndims) + 256 (k) + 256 (v)
+        auto ListUnpack_321 = makeOP<ov::opset1::VariadicSplit>({input, -1, {4096, 256, 256}});
+        auto view_Reshape = makeOP<ov::opset1::Reshape>({ListUnpack_321->output(0), {0, 0, num_heads, ndims}},
+                                                        {{"special_zero", true}});
+        auto aten_slice_Slice_357 =
+            makeOP<ov::opset1::StridedSlice>({view_Reshape, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        auto ListConstruct_372_Concat =
+            makeOP<ov::opset1::Concat>({seq_length, {-1}, {num_heads}, {rotary_ndims / 2}, {2}}, {{"axis", 0}});
+        auto aten_reshape_Reshape_373 =
+            makeOP<ov::opset1::Reshape>({aten_slice_Slice_357, ListConstruct_372_Concat}, {{"special_zero", false}});
+        auto aten_select_Gather_381 =
+            makeOP<ov::opset8::Gather>({aten_reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}});
+        auto aten_slice_Slice_369 = makeOP<ov::opset1::StridedSlice>({cos_sin_cache, {0}, seq_length, {1}},
+                                                                     {{"begin_mask", {0}},
+                                                                      {"end_mask", {0}},
+                                                                      {"new_axis_mask", {}},
+                                                                      {"shrink_axis_mask", {}},
+                                                                      {"ellipsis_mask", {}}});
+        auto ListConstruct_379_Concat =
+            makeOP<ov::opset1::Concat>({seq_length, {-1}, {1}, {rotary_ndims / 2}, {2}}, {{"axis", 0}});
+        auto aten_view_Reshape_380 =
+            makeOP<ov::opset1::Reshape>({aten_slice_Slice_369, ListConstruct_379_Concat}, {{"special_zero", false}});
+        auto aten_select_Gather_382 =
+            makeOP<ov::opset8::Gather>({aten_view_Reshape_380, 0, -1}, {{"batch_dims", 0}});
+        // y0 = x0 * cos - x1 * sin; the subtraction is expressed as adding the negated product
+        auto aten_mul_Multiply_383 = makeOP<ov::opset1::Multiply>({aten_select_Gather_381, aten_select_Gather_382},
+                                                                  {{"auto_broadcast", "numpy"}});
+        auto aten_select_Gather_384 =
+            makeOP<ov::opset8::Gather>({aten_reshape_Reshape_373, 1, -1}, {{"batch_dims", 0}});
+        auto aten_select_Gather_385 =
+            makeOP<ov::opset8::Gather>({aten_view_Reshape_380, 1, -1}, {{"batch_dims", 0}});
+        auto aten_mul_Multiply_386 = makeOP<ov::opset1::Multiply>({aten_select_Gather_384, aten_select_Gather_385},
+                                                                  {{"auto_broadcast", "numpy"}});
+        auto Multiply_101315 =
+            makeOP<ov::opset1::Multiply>({aten_mul_Multiply_386, -1.000000f}, {{"auto_broadcast", "numpy"}});
+        auto aten_sub_Subtract_389 =
+            makeOP<ov::opset1::Add>({aten_mul_Multiply_383, Multiply_101315}, {{"auto_broadcast", "numpy"}});
+        auto Unsqueeze_62716 = makeOP<ov::opset1::Unsqueeze>({aten_sub_Subtract_389, -1});
+        auto aten_mul_Multiply_391 = makeOP<ov::opset1::Multiply>({aten_select_Gather_384, aten_select_Gather_382},
+                                                                  {{"auto_broadcast", "numpy"}});
+        auto aten_mul_Multiply_393 = makeOP<ov::opset1::Multiply>({aten_select_Gather_381, aten_select_Gather_385},
+                                                                  {{"auto_broadcast", "numpy"}});
+        auto aten_add_Add_396 =
+            makeOP<ov::opset1::Add>({aten_mul_Multiply_391, aten_mul_Multiply_393}, {{"auto_broadcast", "numpy"}});
+        auto Unsqueeze_62717 = makeOP<ov::opset1::Unsqueeze>({aten_add_Add_396, -1});
+        auto aten_stack_401 = makeOP<ov::opset1::Concat>({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}});
+        auto ShapeOf_134820 = makeOP<ov::op::TypeRelaxed<ov::opset1::ShapeOf>>(
+            {aten_stack_401},
+            {{"type_relax", true}, {"input_data_types", {}}, {"output_data_types", {ov::element::i32}}});
+        auto aten_flatten_Slice_417 = makeOP<ov::opset1::StridedSlice>({ShapeOf_134820, {0}, {3}, {1}},
+                                                                       {{"begin_mask", {0}},
+                                                                        {"end_mask", {0}},
+                                                                        {"new_axis_mask", {}},
+                                                                        {"shrink_axis_mask", {}},
+                                                                        {"ellipsis_mask", {}}});
+        auto aten_flatten_Concat_420 = makeOP<ov::opset1::Concat>({aten_flatten_Slice_417, {-1}}, {{"axis", 0}});
+        auto aten_flatten_Reshape_421 =
+            makeOP<ov::opset1::Reshape>({aten_stack_401, aten_flatten_Concat_420}, {{"special_zero", true}});
+        auto aten_slice_Slice_363 =
+            makeOP<ov::opset1::StridedSlice>({view_Reshape, {0, 0, 0, rotary_ndims}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        auto aten_cat_Concat_425 =
+            makeOP<ov::opset1::Concat>({aten_flatten_Reshape_421, aten_slice_Slice_363}, {{"axis", -1}});
+        model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat_425},
+                                            ov::ParameterVector{input, seq_length, cos_sin_cache});
+    }
+    manager.register_pass<ov::intel_gpu::RoPEFusion>();
+    {
+        auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{seq_len, batch, 4608});
+        auto seq_length = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::Shape{1});
+        auto cos_sin_cache =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32,
+                                                    ov::Shape{max_pos_length, batch, rotary_ndims / 2, 2});
+        auto rope = makeOP<ov::intel_gpu::op::RoPE>({input, cos_sin_cache, cos_sin_cache},
+                                                    {{"config.slice_start", 0},
+                                                     {"config.slice_stop", 4096},
+                                                     {"config.input_trans0213", false},
+                                                     {"config.is_interleaved", false},
+                                                     {"config.rotary_ndims", rotary_ndims},
+                                                     {"config.is_chatglm", true},
+                                                     {"config.is_qwen", false},
+                                                     {"config.head_cnt", num_heads},
+                                                     {"config.head_size", ndims},
+                                                     {"config.gather_position_arg_id", 0}});
+        model_ref =
+            std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, seq_length, cos_sin_cache});
+    }
+}