From f6e0ba02a580cebfb4a4f9305b07a806c5e9509f Mon Sep 17 00:00:00 2001
From: Katarzyna Mitrus
Date: Fri, 22 Nov 2024 10:07:15 +0100
Subject: [PATCH] [Op][Internal] Rename SwiGLU to GLU (#27683)

### Details:
- Rename internal op SwiGLU to GLU (no naming changes for the GPU swiglu kernel in this PR)

The current SwiGLU op can also act as GeGLU, depending on its glu_type member. Several reviewers have therefore proposed renaming the op to the more generic GLU.
Related comment: https://github.com/openvinotoolkit/openvino/pull/27579#discussion_r1846130138

A minimal construction sketch of the renamed op is appended after the diff.

### Tickets:
- 157623
---
 .../include/ov_ops/{swiglu.hpp => glu.hpp}    | 24 +++---
 .../{swiglu_fusion.hpp => glu_fusion.hpp}     |  6 +-
 .../src/ov_ops/{swiglu.cpp => glu.cpp}        | 34 ++++-----
 .../{swiglu_fusion.cpp => glu_fusion.cpp}     | 32 ++++----
 ...lu_fusion_test.cpp => glu_fusion_test.cpp} | 74 +++++++++----------
 .../intel_gpu/plugin/primitives_list.hpp      |  2 +-
 .../include/intel_gpu/primitives/swiglu.hpp   |  6 +-
 src/plugins/intel_gpu/src/graph/swiglu.cpp    |  4 +-
 .../kernels/swiglu/swiglu_kernel_base.cpp     |  4 +-
 .../kernels/swiglu/swiglu_kernel_base.h       |  6 +-
 .../intel_gpu/src/plugin/ops/swiglu.cpp       |  8 +-
 .../src/plugin/transformations_pipeline.cpp   |  4 +-
 .../tests/unit/test_cases/swiglu_gpu_test.cpp |  4 +-
 13 files changed, 104 insertions(+), 104 deletions(-)
 rename src/common/transformations/include/ov_ops/{swiglu.hpp => glu.hpp} (80%)
 rename src/common/transformations/include/transformations/common_optimizations/{swiglu_fusion.hpp => glu_fusion.hpp} (69%)
 rename src/common/transformations/src/ov_ops/{swiglu.cpp => glu.cpp} (67%)
 rename src/common/transformations/src/transformations/common_optimizations/{swiglu_fusion.cpp => glu_fusion.cpp} (83%)
 rename src/common/transformations/tests/common_optimizations/{swiglu_fusion_test.cpp => glu_fusion_test.cpp} (70%)

diff --git a/src/common/transformations/include/ov_ops/swiglu.hpp b/src/common/transformations/include/ov_ops/glu.hpp
similarity index 80%
rename from src/common/transformations/include/ov_ops/swiglu.hpp
rename to src/common/transformations/include/ov_ops/glu.hpp
index f03c1ac1a26666..760641978b574d 100644
--- a/src/common/transformations/include/ov_ops/swiglu.hpp
+++ b/src/common/transformations/include/ov_ops/glu.hpp
@@ -11,16 +11,16 @@ namespace ov {
 namespace op {
 namespace internal {
 
-/// \brief Operator performing Swish Gated Linear Unit Activation
+/// \brief Operator performing Gated Linear Unit Activation
 ///        This operation performs gated linear unit activation that combines swish or gelu activation function
-class TRANSFORMATIONS_API SwiGLU : public ov::op::Op {
+class TRANSFORMATIONS_API GLU : public ov::op::Op {
 public:
-    OPENVINO_OP("SwiGLU", "ie_internal_opset");
+    OPENVINO_OP("GLU", "ie_internal_opset");
 
     enum GluType { Swish = 0, Gelu, Gelu_Tanh };
 
-    SwiGLU() = default;
-    /// \brief Constructs an SwiGLU operation.
+    GLU() = default;
+    /// \brief Constructs a GLU operation.
/// /// \param data Input tensor with data /// \param axis The index of an axis in "data" along which to perform the split @@ -28,12 +28,12 @@ class TRANSFORMATIONS_API SwiGLU : public ov::op::Op { /// \param glu_type GLU type, one of Swish, Gelu and Gelu_Tanh /// \param split_to_glu_idx Output index of variadic split, which is connected to GLU /// \param output_type Output element type - SwiGLU(const Output& data, - int64_t axis, - int64_t split_lengths, - const GluType glu_type, - const size_t split_to_glu_idx, - const ov::element::Type output_type = ov::element::undefined); + GLU(const Output& data, + int64_t axis, + int64_t split_lengths, + const GluType glu_type, + const size_t split_to_glu_idx, + const ov::element::Type output_type = ov::element::undefined); bool visit_attributes(ov::AttributeVisitor& visitor) override; @@ -76,7 +76,7 @@ class TRANSFORMATIONS_API SwiGLU : public ov::op::Op { }; // TODO 157615: Move to shape_inference -TRANSFORMATIONS_API std::vector shape_infer(const SwiGLU* op, +TRANSFORMATIONS_API std::vector shape_infer(const GLU* op, std::vector input_shapes); } // namespace internal diff --git a/src/common/transformations/include/transformations/common_optimizations/swiglu_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/glu_fusion.hpp similarity index 69% rename from src/common/transformations/include/transformations/common_optimizations/swiglu_fusion.hpp rename to src/common/transformations/include/transformations/common_optimizations/glu_fusion.hpp index 18205bd1a1e8e2..7ec71a05027d80 100644 --- a/src/common/transformations/include/transformations/common_optimizations/swiglu_fusion.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/glu_fusion.hpp @@ -11,10 +11,10 @@ namespace ov { namespace pass { -class TRANSFORMATIONS_API SwiGLUFusion : public ov::pass::MatcherPass { +class TRANSFORMATIONS_API GLUFusion : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("SwiGLUFusion", "0"); - SwiGLUFusion(); + OPENVINO_RTTI("GLUFusion", "0"); + GLUFusion(); }; } // namespace pass diff --git a/src/common/transformations/src/ov_ops/swiglu.cpp b/src/common/transformations/src/ov_ops/glu.cpp similarity index 67% rename from src/common/transformations/src/ov_ops/swiglu.cpp rename to src/common/transformations/src/ov_ops/glu.cpp index b3b9e71076aee0..bc3dfb89ab8b9b 100644 --- a/src/common/transformations/src/ov_ops/swiglu.cpp +++ b/src/common/transformations/src/ov_ops/glu.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ov_ops/swiglu.hpp" +#include "ov_ops/glu.hpp" #include "openvino/core/partial_shape.hpp" #include "openvino/core/validation_util.hpp" @@ -13,12 +13,12 @@ namespace ov { namespace op { namespace internal { -SwiGLU::SwiGLU(const Output& data, - int64_t axis, - int64_t split_lengths, - const GluType glu_type, - const size_t split_to_glu_idx, - const ov::element::Type output_type) +GLU::GLU(const Output& data, + int64_t axis, + int64_t split_lengths, + const GluType glu_type, + const size_t split_to_glu_idx, + const ov::element::Type output_type) : Op({data}), m_axis(axis), m_split_lengths(split_lengths), @@ -28,14 +28,14 @@ SwiGLU::SwiGLU(const Output& data, validate_and_infer_types(); } -bool SwiGLU::visit_attributes(ov::AttributeVisitor& visitor) { +bool GLU::visit_attributes(ov::AttributeVisitor& visitor) { visitor.on_attribute("axis", m_axis); visitor.on_attribute("split_lengths", m_split_lengths); visitor.on_attribute("output_type", m_output_type); return true; } 
-void SwiGLU::validate_and_infer_types() { +void GLU::validate_and_infer_types() { auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type; std::vector input_shapes = {get_input_partial_shape(0), @@ -45,17 +45,17 @@ void SwiGLU::validate_and_infer_types() { set_output_type(0, output_type, shape_infer(this, input_shapes)[0]); } -std::shared_ptr SwiGLU::clone_with_new_inputs(const ov::OutputVector& new_args) const { +std::shared_ptr GLU::clone_with_new_inputs(const ov::OutputVector& new_args) const { check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), - m_axis, - m_split_lengths, - m_glu_type, - m_split_to_glu_idx, - m_output_type); + return std::make_shared(new_args.at(0), + m_axis, + m_split_lengths, + m_glu_type, + m_split_to_glu_idx, + m_output_type); } -std::vector shape_infer(const SwiGLU* op, std::vector input_shapes) { +std::vector shape_infer(const GLU* op, std::vector input_shapes) { ov::op::v1::VariadicSplit variadic_split; std::vector axis = {op->get_axis()}; std::vector split_lengths = {op->get_split_lengths(), -1}; diff --git a/src/common/transformations/src/transformations/common_optimizations/swiglu_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/glu_fusion.cpp similarity index 83% rename from src/common/transformations/src/transformations/common_optimizations/swiglu_fusion.cpp rename to src/common/transformations/src/transformations/common_optimizations/glu_fusion.cpp index 84c6dbceb39f2f..2b6c2092a054c2 100644 --- a/src/common/transformations/src/transformations/common_optimizations/swiglu_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/glu_fusion.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/common_optimizations/swiglu_fusion.hpp" +#include "transformations/common_optimizations/glu_fusion.hpp" #include "openvino/core/rt_info.hpp" #include "openvino/op/constant.hpp" @@ -13,13 +13,13 @@ #include "openvino/pass/manager.hpp" #include "openvino/pass/pattern/op/or.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" -#include "ov_ops/swiglu.hpp" +#include "ov_ops/glu.hpp" #include "transformations/utils/utils.hpp" namespace ov { namespace pass { -SwiGLUFusion::SwiGLUFusion() { +GLUFusion::GLUFusion() { using namespace ov::pass::pattern; using ov::pass::pattern::op::Or; @@ -28,8 +28,8 @@ SwiGLUFusion::SwiGLUFusion() { return out_ps.rank().is_static() && out_ps[out_ps.rank().get_length() - 1].is_static() && out_ps.size() <= 5; }; - // Detect SwiGLU decomposition pattern - // SwiGLU(Xw, Xv, beta) = (Xw * (1.0 + exp(-beta * Xw))) * Xv + // Detect GLU decomposition pattern + // GLU(Xw, Xv, beta) = (Xw * (1.0 + exp(-beta * Xw))) * Xv auto data_m = any_input(last_dim_static); // VariadicSplit(X, axis, split_lengths) = Xw, Xv @@ -60,11 +60,11 @@ SwiGLUFusion::SwiGLUFusion() { auto isSwiGLU = pattern_map.count(swish_m); auto isGeGLU = pattern_map.count(gelu_m); size_t split_to_glu_idx = 0; - ov::op::internal::SwiGLU::GluType glu_type = ov::op::internal::SwiGLU::GluType::Swish; + ov::op::internal::GLU::GluType glu_type = ov::op::internal::GLU::GluType::Swish; if (isSwiGLU) { auto swish = std::dynamic_pointer_cast(pattern_map.at(swish_m).get_node_shared_ptr()); - glu_type = ov::op::internal::SwiGLU::GluType::Swish; + glu_type = ov::op::internal::GLU::GluType::Swish; split_to_glu_idx = swish->input_value(0).get_index(); size_t split_in_idx = ov::is_type(mul->get_input_node_shared_ptr(0)) ? 
1 : 0; @@ -73,8 +73,8 @@ SwiGLUFusion::SwiGLUFusion() { } else if (isGeGLU) { auto gelu = std::dynamic_pointer_cast(pattern_map.at(gelu_m).get_node_shared_ptr()); glu_type = (gelu->get_approximation_mode() == ov::op::GeluApproximationMode::ERF) - ? ov::op::internal::SwiGLU::GluType::Gelu - : ov::op::internal::SwiGLU::GluType::Gelu_Tanh; + ? ov::op::internal::GLU::GluType::Gelu + : ov::op::internal::GLU::GluType::Gelu_Tanh; split_to_glu_idx = gelu->input_value(0).get_index(); size_t split_in_idx = ov::is_type(mul->get_input_node_shared_ptr(0)) ? 1 : 0; @@ -107,12 +107,12 @@ SwiGLUFusion::SwiGLUFusion() { auto data = pattern_map.at(data_m); auto output_type = m.get_match_root()->get_output_element_type(0); - auto swiglu = std::make_shared(data, - axis_value, - split_lengths_value, - glu_type, - split_to_glu_idx, - output_type); + auto swiglu = std::make_shared(data, + axis_value, + split_lengths_value, + glu_type, + split_to_glu_idx, + output_type); swiglu->set_friendly_name(m.get_match_root()->get_friendly_name()); ov::copy_runtime_info(m.get_matched_nodes(), swiglu); ov::replace_node(m.get_match_root(), swiglu); @@ -120,7 +120,7 @@ SwiGLUFusion::SwiGLUFusion() { return true; }; - auto m = std::make_shared(mul_m, "SwiGLUFusion"); + auto m = std::make_shared(mul_m, "GLUFusion"); this->register_matcher(m, callback); } diff --git a/src/common/transformations/tests/common_optimizations/swiglu_fusion_test.cpp b/src/common/transformations/tests/common_optimizations/glu_fusion_test.cpp similarity index 70% rename from src/common/transformations/tests/common_optimizations/swiglu_fusion_test.cpp rename to src/common/transformations/tests/common_optimizations/glu_fusion_test.cpp index 75c8fba75024c3..4d879be57672cd 100644 --- a/src/common/transformations/tests/common_optimizations/swiglu_fusion_test.cpp +++ b/src/common/transformations/tests/common_optimizations/glu_fusion_test.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/common_optimizations/swiglu_fusion.hpp" +#include "transformations/common_optimizations/glu_fusion.hpp" #include @@ -18,13 +18,13 @@ #include "openvino/op/swish.hpp" #include "openvino/op/variadic_split.hpp" #include "openvino/pass/manager.hpp" -#include "ov_ops/swiglu.hpp" +#include "ov_ops/glu.hpp" #include "transformations/utils/utils.hpp" using namespace testing; using namespace ov::pass; -TEST_F(TransformationTestsF, SwiGLUFusionTest1) { +TEST_F(TransformationTestsF, GLUFusionTest1) { { auto input = std::make_shared(ov::element::f16, ov::PartialShape{2, 1, 6}); auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}); @@ -34,24 +34,24 @@ TEST_F(TransformationTestsF, SwiGLUFusionTest1) { auto mul = std::make_shared(swish, variadic_split->output(1)); model = std::make_shared(ov::NodeVector{mul}, ov::ParameterVector{input}); - manager.register_pass(); + manager.register_pass(); } { int64_t axis = -1; int64_t split_lenghts = 3; auto input = std::make_shared(ov::element::f16, ov::PartialShape{2, 1, 6}); - auto swiglu = std::make_shared(input, - axis, - split_lenghts, - ov::op::internal::SwiGLU::GluType::Swish, - 0, - ov::element::f16); + auto swiglu = std::make_shared(input, + axis, + split_lenghts, + ov::op::internal::GLU::GluType::Swish, + 0, + ov::element::f16); model_ref = std::make_shared(ov::NodeVector{swiglu}, ov::ParameterVector{input}); } } -TEST_F(TransformationTestsF, SwiGLUFusionTest2) { +TEST_F(TransformationTestsF, GLUFusionTest2) { { auto input = std::make_shared(ov::element::f16, 
ov::PartialShape{-1, -1, 6}); auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {0}); @@ -61,11 +61,11 @@ TEST_F(TransformationTestsF, SwiGLUFusionTest2) { auto mul = std::make_shared(swish, variadic_split->output(1)); model = std::make_shared(ov::NodeVector{mul}, ov::ParameterVector{input}); - manager.register_pass(); + manager.register_pass(); } } -TEST_F(TransformationTestsF, SwiGLUFusionTest3) { +TEST_F(TransformationTestsF, GLUFusionTest3) { { auto input = std::make_shared(ov::element::f16, ov::PartialShape{-1, -1, 6}); auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}); @@ -75,24 +75,24 @@ TEST_F(TransformationTestsF, SwiGLUFusionTest3) { auto mul = std::make_shared(swish, variadic_split->output(1)); model = std::make_shared(ov::NodeVector{mul}, ov::ParameterVector{input}); - manager.register_pass(); + manager.register_pass(); } { int64_t axis = -1; int64_t split_lenghts = 3; auto input = std::make_shared(ov::element::f16, ov::PartialShape{-1, -1, 6}); - auto swiglu = std::make_shared(input, - axis, - split_lenghts, - ov::op::internal::SwiGLU::GluType::Swish, - 0, - ov::element::f16); + auto swiglu = std::make_shared(input, + axis, + split_lenghts, + ov::op::internal::GLU::GluType::Swish, + 0, + ov::element::f16); model_ref = std::make_shared(ov::NodeVector{swiglu}, ov::ParameterVector{input}); } } -TEST_F(TransformationTestsF, SwiGLUFusionTest3ReverseOrder) { +TEST_F(TransformationTestsF, GLUFusionTest3ReverseOrder) { { auto input = std::make_shared(ov::element::f16, ov::PartialShape{-1, -1, 6}); auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}); @@ -102,24 +102,24 @@ TEST_F(TransformationTestsF, SwiGLUFusionTest3ReverseOrder) { auto mul = std::make_shared(variadic_split->output(1), swish); model = std::make_shared(ov::NodeVector{mul}, ov::ParameterVector{input}); - manager.register_pass(); + manager.register_pass(); } { int64_t axis = -1; int64_t split_lenghts = 3; auto input = std::make_shared(ov::element::f16, ov::PartialShape{-1, -1, 6}); - auto swiglu = std::make_shared(input, - axis, - split_lenghts, - ov::op::internal::SwiGLU::GluType::Swish, - 0, - ov::element::f16); + auto swiglu = std::make_shared(input, + axis, + split_lenghts, + ov::op::internal::GLU::GluType::Swish, + 0, + ov::element::f16); model_ref = std::make_shared(ov::NodeVector{swiglu}, ov::ParameterVector{input}); } } -TEST_F(TransformationTestsF, SwiGLUFusionTest4) { +TEST_F(TransformationTestsF, GLUFusionTest4) { { auto input = std::make_shared(ov::element::f16, ov::PartialShape{-1, -1, 6}); auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}); @@ -129,7 +129,7 @@ TEST_F(TransformationTestsF, SwiGLUFusionTest4) { auto mul = std::make_shared(swish, variadic_split->output(0)); model = std::make_shared(ov::NodeVector{mul}, ov::ParameterVector{input}); - manager.register_pass(); + manager.register_pass(); } } @@ -143,18 +143,18 @@ TEST_F(TransformationTestsF, GeGLUFusionTest1) { auto mul = std::make_shared(variadic_split->output(0), gelu); model = std::make_shared(ov::NodeVector{mul}, ov::ParameterVector{input}); - manager.register_pass(); + manager.register_pass(); } { int64_t axis = -1; int64_t split_lenghts = 3; auto input = std::make_shared(ov::element::f16, ov::PartialShape{2, 1, 6}); - auto swiglu = std::make_shared(input, - axis, - split_lenghts, - ov::op::internal::SwiGLU::GluType::Gelu, - 1, - ov::element::f16); + auto swiglu = std::make_shared(input, + axis, + split_lenghts, + 
ov::op::internal::GLU::GluType::Gelu, + 1, + ov::element::f16); model_ref = std::make_shared(ov::NodeVector{swiglu}, ov::ParameterVector{input}); } diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp index 27e5540a3786ab..ced915d25610e8 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp @@ -287,7 +287,7 @@ REGISTER_FACTORY(internal, KVCacheCompressed); REGISTER_FACTORY(internal, ReadValue); REGISTER_FACTORY(internal, ReadValues); REGISTER_FACTORY(internal, Gemm); -REGISTER_FACTORY(internal, SwiGLU); +REGISTER_FACTORY(internal, GLU); REGISTER_FACTORY(internal, IndirectGemm); REGISTER_FACTORY(internal, Convolution); REGISTER_FACTORY(internal, Placeholder); diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/swiglu.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/swiglu.hpp index 8e9ea5aff03902..1a72e36d471dfc 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/swiglu.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/swiglu.hpp @@ -3,7 +3,7 @@ // #pragma once -#include "ov_ops/swiglu.hpp" +#include "ov_ops/glu.hpp" #include "primitive.hpp" namespace cldnn { @@ -25,7 +25,7 @@ struct swiglu : public primitive_base { const input_info& input, const int64_t& axis, const int64_t& split_lengths, - const ov::op::internal::SwiGLU::GluType glu_type, + const ov::op::internal::GLU::GluType glu_type, const size_t split_to_glu_idx, const tensor output_size) : primitive_base(id, {input}), @@ -37,7 +37,7 @@ struct swiglu : public primitive_base { int64_t axis = 0; int64_t split_lengths = 0; - ov::op::internal::SwiGLU::GluType glu_type = ov::op::internal::SwiGLU::GluType::Swish; + ov::op::internal::GLU::GluType glu_type = ov::op::internal::GLU::GluType::Swish; size_t split_to_glu_idx = 0; tensor output_size; diff --git a/src/plugins/intel_gpu/src/graph/swiglu.cpp b/src/plugins/intel_gpu/src/graph/swiglu.cpp index 127b8645870157..e82e4e974b1868 100644 --- a/src/plugins/intel_gpu/src/graph/swiglu.cpp +++ b/src/plugins/intel_gpu/src/graph/swiglu.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ov_ops/swiglu.hpp" +#include "ov_ops/glu.hpp" #include "swiglu_inst.h" #include "primitive_type_base.h" @@ -28,7 +28,7 @@ std::vector swiglu_inst::calc_output_layouts(swiglu_node const& /*node*/ auto output_type = impl_param.desc->output_data_types[0].value_or(input_layout.data_type); auto output_format = input_layout.format; - ov::op::internal::SwiGLU op; + ov::op::internal::GLU op; op.set_axis(desc->axis); op.set_split_lengths(desc->split_lengths); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.cpp index b3c31f31128c49..b6b67bd4ed222d 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.cpp @@ -25,10 +25,10 @@ JitConstants SwiGLUKernelBase::GetJitConstants(const swiglu_params& params, cons jit.AddConstants({MakeJitConstant("LWS1", dispatchData.lws[1])}); jit.AddConstants({MakeJitConstant("LWS2", dispatchData.lws[2])}); const std::string type_suffix = (GetAccumulatorType(params) == Datatype::F32) ? 
"f" : "h"; - if (params.glu_type == ov::op::internal::SwiGLU::GluType::Gelu) { + if (params.glu_type == ov::op::internal::GLU::GluType::Gelu) { jit.AddConstants({MakeJitConstant("GEGLU_HALF", "0.5" + type_suffix)}); jit.AddConstants({MakeJitConstant("GEGLU_MULT", "0.7071067811865475" + type_suffix)}); - } else if (params.glu_type == ov::op::internal::SwiGLU::GluType::Gelu_Tanh) { + } else if (params.glu_type == ov::op::internal::GLU::GluType::Gelu_Tanh) { jit.AddConstants({MakeJitConstant("GEGLU_HALF", "0.5" + type_suffix)}); jit.AddConstants({MakeJitConstant("GEGLU_MULT", "0.044715" + type_suffix)}); jit.AddConstants({MakeJitConstant("GEGLU_SQUARE_2_OVER_PI", "0.79788458347320556640625" + type_suffix)}); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.h index 73d679c8a643fb..2f5c046690f78d 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.h @@ -6,7 +6,7 @@ #include "kernel_base_opencl.h" #include "kernel_selector_params.h" -#include "ov_ops/swiglu.hpp" +#include "ov_ops/glu.hpp" namespace kernel_selector { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -14,10 +14,10 @@ namespace kernel_selector { /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// struct swiglu_params : public base_params { swiglu_params() : base_params(KernelType::SWIGLU), axis(0), split_length(0), - glu_type(ov::op::internal::SwiGLU::GluType::Swish), split_to_glu_idx(0) {} + glu_type(ov::op::internal::GLU::GluType::Swish), split_to_glu_idx(0) {} int32_t axis; int32_t split_length; - ov::op::internal::SwiGLU::GluType glu_type; + ov::op::internal::GLU::GluType glu_type; int32_t split_to_glu_idx; }; diff --git a/src/plugins/intel_gpu/src/plugin/ops/swiglu.cpp b/src/plugins/intel_gpu/src/plugin/ops/swiglu.cpp index 32d2f296670a91..5df2cafd41a41f 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/swiglu.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/swiglu.cpp @@ -6,14 +6,14 @@ #include "intel_gpu/plugin/common_utils.hpp" #include "intel_gpu/primitives/swiglu.hpp" -#include "ov_ops/swiglu.hpp" +#include "ov_ops/glu.hpp" -using SwiGLU = ov::op::internal::SwiGLU; +using GLU = ov::op::internal::GLU; namespace ov { namespace intel_gpu { -static void CreateSwiGLUOp(ProgramBuilder& p, const std::shared_ptr& op) { +static void CreateGLUOp(ProgramBuilder& p, const std::shared_ptr& op) { validate_inputs_count(op, {1}); auto inputs = p.GetInputInfo(op); std::string primitive_name = layer_type_name_ID(op); @@ -41,7 +41,7 @@ static void CreateSwiGLUOp(ProgramBuilder& p, const std::shared_ptr& op) } } -REGISTER_FACTORY_IMPL(internal, SwiGLU); +REGISTER_FACTORY_IMPL(internal, GLU); } // namespace intel_gpu } // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 01696615e545f8..f4ec7afb5c3d1e 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -93,7 +93,7 @@ #include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp" #include "transformations/common_optimizations/mvn_fusion.hpp" #include 
"transformations/common_optimizations/softmax_fusion.hpp" -#include "transformations/common_optimizations/swiglu_fusion.hpp" +#include "transformations/common_optimizations/glu_fusion.hpp" #include "transformations/common_optimizations/transpose_sinking.hpp" #include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp" #include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp" @@ -943,7 +943,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { } manager.register_pass(); - manager.register_pass(); + manager.register_pass(); manager.register_pass(); auto kv_cache_compression_dt = config.get_property(ov::hint::kv_cache_precision); diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/swiglu_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/swiglu_gpu_test.cpp index 0d96a165108972..11bca6e27ba942 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/swiglu_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/swiglu_gpu_test.cpp @@ -7,7 +7,7 @@ #include #include #include -#include "ov_ops/swiglu.hpp" +#include "ov_ops/glu.hpp" #include "swiglu_inst.h" using namespace cldnn; @@ -64,7 +64,7 @@ TEST(swiglu_gpu_test, swiglu_test_bfyx_dyn) { topology topology; topology.add(input_layout("input", input_layout_dynamic)); - topology.add(swiglu("swiglu", input_info("input"), -1, 3, ov::op::internal::SwiGLU::GluType::Swish, 0, tensor())); + topology.add(swiglu("swiglu", input_info("input"), -1, 3, ov::op::internal::GLU::GluType::Swish, 0, tensor())); ExecutionConfig config = get_test_default_config(engine); config.set_property(ov::intel_gpu::allow_new_shape_infer(true));