From 5d2317d5f5c3cd037fd59b300f9e0ab624decde4 Mon Sep 17 00:00:00 2001 From: Katarzyna Mitrus Date: Tue, 3 Dec 2024 10:40:25 +0100 Subject: [PATCH] [ShapeInfer][Op] Internal GLU op - common shape_infer (#27750) ### Details: - Align internal GLU shape_infer - Move shape_infer from the op to the shape_inference directory - Update shape_infer to use template as a shape type - Remove VariadicSplit object creation to call variadic split shape_infer - Register GLU shape_infer for CPU - Update GPU calc_output to use common GLU shape_infer ### Tickets: - 157615 --- .../transformations/include/ov_ops/glu.hpp | 4 -- src/common/transformations/src/ov_ops/glu.cpp | 26 ++--------- .../include/glu_shape_inference.hpp | 34 ++++++++++++++ .../variadic_split_shape_inference.hpp | 14 ++++-- .../src/shape_inference/shape_inference.cpp | 2 + .../glu_shape_inference_test.cpp | 46 +++++++++++++++++++ src/plugins/intel_gpu/src/graph/swiglu.cpp | 7 +-- 7 files changed, 99 insertions(+), 34 deletions(-) create mode 100644 src/core/shape_inference/include/glu_shape_inference.hpp create mode 100644 src/plugins/intel_cpu/tests/unit/shape_inference_test/glu_shape_inference_test.cpp diff --git a/src/common/transformations/include/ov_ops/glu.hpp b/src/common/transformations/include/ov_ops/glu.hpp index 760641978b574d..add8c3a0582525 100644 --- a/src/common/transformations/include/ov_ops/glu.hpp +++ b/src/common/transformations/include/ov_ops/glu.hpp @@ -75,10 +75,6 @@ class TRANSFORMATIONS_API GLU : public ov::op::Op { ov::element::Type m_output_type{}; }; -// TODO 157615: Move to shape_inference -TRANSFORMATIONS_API std::vector shape_infer(const GLU* op, - std::vector input_shapes); - } // namespace internal } // namespace op } // namespace ov diff --git a/src/common/transformations/src/ov_ops/glu.cpp b/src/common/transformations/src/ov_ops/glu.cpp index bc3dfb89ab8b9b..9b5fb780d36bb8 100644 --- a/src/common/transformations/src/ov_ops/glu.cpp +++ b/src/common/transformations/src/ov_ops/glu.cpp @@ -4,10 +4,9 @@ #include "ov_ops/glu.hpp" +#include "glu_shape_inference.hpp" #include "openvino/core/partial_shape.hpp" #include "openvino/core/validation_util.hpp" -#include "openvino/op/variadic_split.hpp" -#include "variadic_split_shape_inference.hpp" namespace ov { namespace op { @@ -38,11 +37,9 @@ bool GLU::visit_attributes(ov::AttributeVisitor& visitor) { void GLU::validate_and_infer_types() { auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type; - std::vector input_shapes = {get_input_partial_shape(0), - ov::PartialShape(ov::Shape{}), - ov::PartialShape(ov::Shape{2})}; - - set_output_type(0, output_type, shape_infer(this, input_shapes)[0]); + const auto input_shapes = ov::util::get_node_input_partial_shapes(*this); + const auto output_shapes = shape_infer(this, input_shapes); + set_output_type(0, output_type, output_shapes[0]); } std::shared_ptr GLU::clone_with_new_inputs(const ov::OutputVector& new_args) const { @@ -54,21 +51,6 @@ std::shared_ptr GLU::clone_with_new_inputs(const ov::OutputVector& new_arg m_split_to_glu_idx, m_output_type); } - -std::vector shape_infer(const GLU* op, std::vector input_shapes) { - ov::op::v1::VariadicSplit variadic_split; - std::vector axis = {op->get_axis()}; - std::vector split_lengths = {op->get_split_lengths(), -1}; - - std::unordered_map const_data; - const_data.emplace(1, ov::Tensor(ov::element::i64, ov::Shape{}, static_cast(axis.data()))); - const_data.emplace( - 2, - ov::Tensor(ov::element::i64, ov::Shape{split_lengths.size()}, static_cast(split_lengths.data()))); - - return ov::op::v1::shape_infer(&variadic_split, input_shapes, ov::make_tensor_accessor(const_data)); -} - } // namespace internal } // namespace op } // namespace ov diff --git a/src/core/shape_inference/include/glu_shape_inference.hpp b/src/core/shape_inference/include/glu_shape_inference.hpp new file mode 100644 index 00000000000000..365b57244036a2 --- /dev/null +++ b/src/core/shape_inference/include/glu_shape_inference.hpp @@ -0,0 +1,34 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ov_ops/glu.hpp" +#include "utils.hpp" +#include "variadic_split_shape_inference.hpp" + +namespace ov { +namespace op { +namespace internal { +template > +std::vector shape_infer(const GLU* op, const std::vector& input_shapes) { + const auto inputs_count = input_shapes.size(); + NODE_SHAPE_INFER_CHECK(op, input_shapes, inputs_count == 1); + + int64_t axis = op->get_axis(); + std::vector split_lengths = {op->get_split_lengths(), -1}; + std::unordered_map const_data; + const_data.emplace(1, ov::Tensor(ov::element::i64, ov::Shape{}, &axis)); + const_data.emplace(2, ov::Tensor(ov::element::i64, ov::Shape{split_lengths.size()}, split_lengths.data())); + + const ov::Shape split_len_size{split_lengths.size()}; + const ov::Shape scalar{}; + std::vector variadic_split_input_shapes{input_shapes[0], scalar, split_len_size}; + + return {std::move( + ov::op::variadic_split::shape_infer(op, variadic_split_input_shapes, ov::make_tensor_accessor(const_data))[0])}; +} +} // namespace internal +} // namespace op +} // namespace ov diff --git a/src/core/shape_inference/include/variadic_split_shape_inference.hpp b/src/core/shape_inference/include/variadic_split_shape_inference.hpp index a0eff51f238e61..e0cd837003a331 100644 --- a/src/core/shape_inference/include/variadic_split_shape_inference.hpp +++ b/src/core/shape_inference/include/variadic_split_shape_inference.hpp @@ -10,10 +10,9 @@ namespace ov { namespace op { -namespace v1 { - +namespace variadic_split { template > -std::vector shape_infer(const VariadicSplit* op, +std::vector shape_infer(const Node* op, const std::vector& input_shapes, const ITensorAccessor& ta = make_tensor_accessor()) { constexpr bool is_dynamic_shape = std::is_base_of::value; @@ -120,6 +119,15 @@ std::vector shape_infer(const VariadicSplit* op, } return output_shapes; } +} // namespace variadic_split + +namespace v1 { +template > +std::vector shape_infer(const VariadicSplit* op, + const std::vector& input_shapes, + const ITensorAccessor& ta = make_tensor_accessor()) { + return op::variadic_split::shape_infer(op, input_shapes, ta); +} } // namespace v1 } // namespace op diff --git a/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp b/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp index b99e1bc62c4b11..2dccce257ae116 100644 --- a/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp @@ -56,6 +56,7 @@ #include "gather_nd_shape_inference.hpp" #include "gather_shape_inference.hpp" #include "gather_tree_shape_inference.hpp" +#include "glu_shape_inference.hpp" #include "grid_sample_shape_inference.hpp" #include "group_convolution_backprop_shape_inference.hpp" #include "group_convolution_shape_inference.hpp" @@ -575,6 +576,7 @@ const IStaticShapeInferFactory::TRegistry IStaticShapeInferFactory::registry{ _OV_OP_SHAPE_INFER_MASK_REG(ov::op::internal::AUGRUCell, ShapeInferTA, util::bit::mask()), _OV_OP_SHAPE_INFER_MASK_REG(ov::op::internal::AUGRUSequence, ShapeInferTA, util::bit::mask()), _OV_OP_SHAPE_INFER_MASK_REG(ov::op::internal::RMSNorm, ShapeInferTA, util::bit::mask(1)), + _OV_OP_SHAPE_INFER_MASK_REG(ov::op::internal::GLU, ShapeInferTA, util::bit::mask()), }; // clang-format on diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/glu_shape_inference_test.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/glu_shape_inference_test.cpp new file mode 100644 index 00000000000000..f7647d52dc5bae --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/shape_inference_test/glu_shape_inference_test.cpp @@ -0,0 +1,46 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "common_test_utils/test_assertions.hpp" +#include "ov_ops/glu.hpp" +#include "utils.hpp" + +using namespace ov; +using namespace ov::intel_cpu; +using ov::op::v0::Constant; +using ov::op::v0::Parameter; +using testing::HasSubstr; + +TEST(StaticShapeInferenceTest, GLUStaticShapeInferenceTestDefaultCtor) { + constexpr int64_t axis = -1; + constexpr int64_t split_lengths = 48; + + const auto op = std::make_shared(); + const auto data = std::make_shared(element::f16, PartialShape::dynamic()); + + op->set_arguments(ov::OutputVector{data}); + op->set_axis(axis); + op->set_split_lengths(split_lengths); + + std::vector static_input_shapes = {StaticShape{20, 1, 96}}; + const auto static_output_shapes = shape_inference(op.get(), static_input_shapes); + ASSERT_EQ(static_output_shapes.size(), 1); + EXPECT_EQ(static_output_shapes[0], StaticShape({20, 1, 48})); +} + +TEST(StaticShapeInferenceTest, GLUStaticShapeInferenceTestBasic) { + constexpr int64_t axis = -1; + constexpr int64_t split_lengths = 48; + const auto glu_type = ov::op::internal::GLU::GluType::Swish; + + const auto data = std::make_shared(element::f16, PartialShape::dynamic()); + const auto op = std::make_shared(data, axis, split_lengths, glu_type, 1); + + std::vector static_input_shapes = {StaticShape{20, 1, 96}}; + const auto static_output_shapes = shape_inference(op.get(), static_input_shapes); + ASSERT_EQ(static_output_shapes.size(), 1); + EXPECT_EQ(static_output_shapes[0], StaticShape({20, 1, 48})); +} diff --git a/src/plugins/intel_gpu/src/graph/swiglu.cpp b/src/plugins/intel_gpu/src/graph/swiglu.cpp index e82e4e974b1868..ffd5333318cee4 100644 --- a/src/plugins/intel_gpu/src/graph/swiglu.cpp +++ b/src/plugins/intel_gpu/src/graph/swiglu.cpp @@ -3,6 +3,7 @@ // #include "ov_ops/glu.hpp" +#include "glu_shape_inference.hpp" #include "swiglu_inst.h" #include "primitive_type_base.h" @@ -32,11 +33,7 @@ std::vector swiglu_inst::calc_output_layouts(swiglu_node const& /*node*/ op.set_axis(desc->axis); op.set_split_lengths(desc->split_lengths); - std::vector input_shapes = { - impl_param.get_input_layout(0).get(), - ShapeType(ov::Shape({})), - ShapeType(ov::Shape{2}) - }; + std::vector input_shapes = {impl_param.get_input_layout(0).get()}; std::vector output_shapes = shape_infer(&op, input_shapes);