Skip to content

Commit

Permalink
[ShapeInfer][Op] Internal GLU op - common shape_infer (#27750)
Browse files Browse the repository at this point in the history
### Details:
 - Align internal GLU shape_infer
 - Move shape_infer from the op to the shape_inference directory 
 - Update shape_infer to take the shape type as a template parameter
 - Remove the VariadicSplit object creation previously needed to call the variadic split shape_infer
 - Register GLU shape_infer for CPU
 - Update GPU calc_output to use common GLU shape_infer
 
### Tickets:
 - 157615
  • Loading branch information
mitruska authored Dec 3, 2024
1 parent f5d2214 commit 5d2317d
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 34 deletions.
4 changes: 0 additions & 4 deletions src/common/transformations/include/ov_ops/glu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,6 @@ class TRANSFORMATIONS_API GLU : public ov::op::Op {
ov::element::Type m_output_type{};
};

// TODO 157615: Move to shape_inference
TRANSFORMATIONS_API std::vector<ov::PartialShape> shape_infer(const GLU* op,
std::vector<ov::PartialShape> input_shapes);

} // namespace internal
} // namespace op
} // namespace ov
26 changes: 4 additions & 22 deletions src/common/transformations/src/ov_ops/glu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

#include "ov_ops/glu.hpp"

#include "glu_shape_inference.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/variadic_split.hpp"
#include "variadic_split_shape_inference.hpp"

namespace ov {
namespace op {
Expand Down Expand Up @@ -38,11 +37,9 @@ bool GLU::visit_attributes(ov::AttributeVisitor& visitor) {
void GLU::validate_and_infer_types() {
auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type;

std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0),
ov::PartialShape(ov::Shape{}),
ov::PartialShape(ov::Shape{2})};

set_output_type(0, output_type, shape_infer(this, input_shapes)[0]);
const auto input_shapes = ov::util::get_node_input_partial_shapes(*this);
const auto output_shapes = shape_infer(this, input_shapes);
set_output_type(0, output_type, output_shapes[0]);
}

std::shared_ptr<Node> GLU::clone_with_new_inputs(const ov::OutputVector& new_args) const {
Expand All @@ -54,21 +51,6 @@ std::shared_ptr<Node> GLU::clone_with_new_inputs(const ov::OutputVector& new_arg
m_split_to_glu_idx,
m_output_type);
}

std::vector<ov::PartialShape> shape_infer(const GLU* op, std::vector<ov::PartialShape> input_shapes) {
ov::op::v1::VariadicSplit variadic_split;
std::vector<int64_t> axis = {op->get_axis()};
std::vector<int64_t> split_lengths = {op->get_split_lengths(), -1};

std::unordered_map<size_t, ov::Tensor> const_data;
const_data.emplace(1, ov::Tensor(ov::element::i64, ov::Shape{}, static_cast<void*>(axis.data())));
const_data.emplace(
2,
ov::Tensor(ov::element::i64, ov::Shape{split_lengths.size()}, static_cast<void*>(split_lengths.data())));

return ov::op::v1::shape_infer(&variadic_split, input_shapes, ov::make_tensor_accessor(const_data));
}

} // namespace internal
} // namespace op
} // namespace ov
34 changes: 34 additions & 0 deletions src/core/shape_inference/include/glu_shape_inference.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ov_ops/glu.hpp"
#include "utils.hpp"
#include "variadic_split_shape_inference.hpp"

namespace ov {
namespace op {
namespace internal {
/// @brief Infers the output shape of the internal GLU operation.
///
/// The data shape is handed to the shared variadic-split shape inference together with two
/// constant i64 tensors: the axis (input 1, value op->get_axis()) and the split lengths
/// (input 2, values {op->get_split_lengths(), -1}). Only the first resulting shape is returned.
///
/// @tparam TShape   Shape type of the inputs.
/// @tparam TRShape  Shape type of the result.
/// @param op            GLU node providing the axis and split-length attributes.
/// @param input_shapes  Input shapes; exactly one entry (the data shape) is accepted.
/// @return Vector holding a single output shape.
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> shape_infer(const GLU* op, const std::vector<TShape>& input_shapes) {
    const auto inputs_count = input_shapes.size();
    NODE_SHAPE_INFER_CHECK(op, input_shapes, inputs_count == 1);

    // The tensors below only wrap these locals, so the locals must outlive the inference call.
    int64_t axis = op->get_axis();
    std::vector<int64_t> split_lengths = {op->get_split_lengths(), -1};
    std::unordered_map<size_t, ov::Tensor> const_data;
    const_data.emplace(1, ov::Tensor(ov::element::i64, ov::Shape{}, &axis));
    const_data.emplace(2, ov::Tensor(ov::element::i64, ov::Shape{split_lengths.size()}, split_lengths.data()));

    // Shapes of the axis (scalar) and split-lengths (1D) inputs passed alongside the data shape.
    const ov::Shape split_len_size{split_lengths.size()};
    const ov::Shape scalar{};
    std::vector<TShape> variadic_split_input_shapes{input_shapes[0], scalar, split_len_size};

    auto split_shapes =
        ov::op::variadic_split::shape_infer(op, variadic_split_input_shapes, ov::make_tensor_accessor(const_data));

    // Build the result explicitly: `return {std::move(x)}` goes through std::initializer_list,
    // whose elements are const, so the shape would be copied rather than moved.
    std::vector<TRShape> output_shapes;
    output_shapes.reserve(1);
    output_shapes.emplace_back(std::move(split_shapes[0]));
    return output_shapes;
}
} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@

namespace ov {
namespace op {
namespace v1 {

namespace variadic_split {
template <typename T, class TRShape = result_shape_t<T>>
std::vector<TRShape> shape_infer(const VariadicSplit* op,
std::vector<TRShape> shape_infer(const Node* op,
const std::vector<T>& input_shapes,
const ITensorAccessor& ta = make_tensor_accessor()) {
constexpr bool is_dynamic_shape = std::is_base_of<ov::PartialShape, T>::value;
Expand Down Expand Up @@ -120,6 +119,15 @@ std::vector<TRShape> shape_infer(const VariadicSplit* op,
}
return output_shapes;
}
} // namespace variadic_split

namespace v1 {
template <typename T, class TRShape = result_shape_t<T>>
std::vector<TRShape> shape_infer(const VariadicSplit* op,
const std::vector<T>& input_shapes,
const ITensorAccessor& ta = make_tensor_accessor()) {
return op::variadic_split::shape_infer(op, input_shapes, ta);
}

} // namespace v1
} // namespace op
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/shape_inference/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "gather_nd_shape_inference.hpp"
#include "gather_shape_inference.hpp"
#include "gather_tree_shape_inference.hpp"
#include "glu_shape_inference.hpp"
#include "grid_sample_shape_inference.hpp"
#include "group_convolution_backprop_shape_inference.hpp"
#include "group_convolution_shape_inference.hpp"
Expand Down Expand Up @@ -575,6 +576,7 @@ const IStaticShapeInferFactory::TRegistry IStaticShapeInferFactory::registry{
_OV_OP_SHAPE_INFER_MASK_REG(ov::op::internal::AUGRUCell, ShapeInferTA, util::bit::mask()),
_OV_OP_SHAPE_INFER_MASK_REG(ov::op::internal::AUGRUSequence, ShapeInferTA, util::bit::mask()),
_OV_OP_SHAPE_INFER_MASK_REG(ov::op::internal::RMSNorm, ShapeInferTA, util::bit::mask(1)),
_OV_OP_SHAPE_INFER_MASK_REG(ov::op::internal::GLU, ShapeInferTA, util::bit::mask()),
};
// clang-format on

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include "common_test_utils/test_assertions.hpp"
#include "ov_ops/glu.hpp"
#include "utils.hpp"

using namespace ov;
using namespace ov::intel_cpu;
using ov::op::v0::Constant;
using ov::op::v0::Parameter;
using testing::HasSubstr;

// Static shape inference for a GLU created with the default constructor and configured via setters:
// a {20, 1, 96} input with axis -1 and split length 48 is expected to give {20, 1, 48}.
TEST(StaticShapeInferenceTest, GLUStaticShapeInferenceTestDefaultCtor) {
    constexpr int64_t split_axis = -1;
    constexpr int64_t glu_length = 48;

    const auto input = std::make_shared<Parameter>(element::f16, PartialShape::dynamic());
    const auto glu = std::make_shared<op::internal::GLU>();
    glu->set_arguments(ov::OutputVector{input});
    glu->set_axis(split_axis);
    glu->set_split_lengths(glu_length);

    std::vector<StaticShape> input_shapes{StaticShape{20, 1, 96}};
    const auto output_shapes = shape_inference(glu.get(), input_shapes);
    const StaticShape expected_shape({20, 1, 48});

    ASSERT_EQ(output_shapes.size(), 1);
    EXPECT_EQ(output_shapes[0], expected_shape);
}

// Static shape inference for a GLU built through its full constructor (Swish variant):
// a {20, 1, 96} input with axis -1 and split length 48 is expected to give {20, 1, 48}.
TEST(StaticShapeInferenceTest, GLUStaticShapeInferenceTestBasic) {
    constexpr int64_t split_axis = -1;
    constexpr int64_t glu_length = 48;
    const auto activation = ov::op::internal::GLU::GluType::Swish;

    const auto input = std::make_shared<Parameter>(element::f16, PartialShape::dynamic());
    // Trailing `1`: presumably the split_to_glu_idx argument — confirm against the GLU constructor.
    const auto glu = std::make_shared<op::internal::GLU>(input, split_axis, glu_length, activation, 1);

    std::vector<StaticShape> input_shapes{StaticShape{20, 1, 96}};
    const auto output_shapes = shape_inference(glu.get(), input_shapes);
    const StaticShape expected_shape({20, 1, 48});

    ASSERT_EQ(output_shapes.size(), 1);
    EXPECT_EQ(output_shapes[0], expected_shape);
}
7 changes: 2 additions & 5 deletions src/plugins/intel_gpu/src/graph/swiglu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#include "ov_ops/glu.hpp"
#include "glu_shape_inference.hpp"
#include "swiglu_inst.h"

#include "primitive_type_base.h"
Expand Down Expand Up @@ -32,11 +33,7 @@ std::vector<layout> swiglu_inst::calc_output_layouts(swiglu_node const& /*node*/
op.set_axis(desc->axis);
op.set_split_lengths(desc->split_lengths);

std::vector<ShapeType> input_shapes = {
impl_param.get_input_layout(0).get<ShapeType>(),
ShapeType(ov::Shape({})),
ShapeType(ov::Shape{2})
};
std::vector<ShapeType> input_shapes = {impl_param.get_input_layout(0).get<ShapeType>()};

std::vector<ShapeType> output_shapes = shape_infer(&op, input_shapes);

Expand Down

0 comments on commit 5d2317d

Please sign in to comment.