diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/gemm.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/gemm.hpp index 18380badcfe196..0f1e690483119e 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/op/gemm.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/op/gemm.hpp @@ -26,15 +26,30 @@ class Gemm : public ov::op::v0::MatMul { const std::vector& order_c, const ov::element::Type output_type = ov::element::undefined); + Gemm(const ov::Output& A, + const ov::Output& B, + const std::vector& target_shape_a, + const std::vector& target_shape_b, + const std::vector& output_pattern_a, + const std::vector& output_pattern_b, + const std::vector& order_a, + const std::vector& order_b, + const std::vector& order_c, + const ov::element::Type output_type = ov::element::undefined); + bool visit_attributes(ov::AttributeVisitor &visitor) override; void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; - std::vector get_input0_order() const { return m_order_a; } - std::vector get_input1_order() const { return m_order_b; } - std::vector get_output_order() const { return m_order_c; } + std::vector get_input0_broadcast_target_shape() const { return m_target_shape_a; } + std::vector get_input1_broadcast_target_shape() const { return m_target_shape_b; } + std::vector get_input0_reshape_pattern() const { return m_output_pattern_a; } + std::vector get_input1_reshape_pattern() const { return m_output_pattern_b; } + std::vector get_input0_transpose_order() const { return m_order_a; } + std::vector get_input1_transpose_order() const { return m_order_b; } + std::vector get_output_transpose_order() const { return m_order_c; } ov::element::Type get_output_type() const { return m_output_type; } static std::vector default_order(size_t rank) { @@ -44,6 +59,10 @@ class Gemm : public ov::op::v0::MatMul { } protected: + std::vector m_target_shape_a; + std::vector m_target_shape_b; + std::vector m_output_pattern_a; + 
std::vector m_output_pattern_b; std::vector m_order_a; std::vector m_order_b; std::vector m_order_c; @@ -52,6 +71,10 @@ class Gemm : public ov::op::v0::MatMul { std::vector shape_infer(const Gemm* op, std::vector input_shapes, + const std::vector& target_shape_a, + const std::vector& target_shape_b, + const std::vector& output_pattern_a, + const std::vector& output_pattern_b, const std::vector& order_a, const std::vector& order_b, const std::vector& order_c); diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp index 41b08e16a0b466..15dd92cd23f6d9 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp @@ -54,6 +54,10 @@ struct gemm : public primitive_base { : primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}), transpose_input0(transpose_input0 ? 1 : 0), transpose_input1(transpose_input1 ? 1 : 0), + input0_broadcast_target_shape({}), + input1_broadcast_target_shape({}), + input0_reshape_pattern({}), + input1_reshape_pattern({}), alpha(alpha), beta(beta), input_rank(input_rank), @@ -70,9 +74,9 @@ struct gemm : public primitive_base { return order; }; - input0_order = get_transposed_order(input_rank, transpose_input0); - input1_order = get_transposed_order(weight_rank, transpose_input1); - output_order = {}; + input0_transpose_order = get_transposed_order(input_rank, transpose_input0); + input1_transpose_order = get_transposed_order(weight_rank, transpose_input1); + output_transpose_order = {}; } /// @brief Constructs gemm layer. 
@@ -86,48 +90,60 @@ struct gemm : public primitive_base { gemm(const primitive_id& id, const std::vector& inputs, const data_types data_type, - const std::vector& input0_order = {0, 1, 2, 3}, - const std::vector& input1_order = {0, 1, 2, 3}, - const std::vector& output_order = {}, + const std::vector& input0_broadcast_target_shape = {}, + const std::vector& input1_broadcast_target_shape = {}, + const std::vector& input0_reshape_pattern = {}, + const std::vector& input1_reshape_pattern = {}, + const std::vector& input0_transpose_order = {0, 1, 2, 3}, + const std::vector& input1_transpose_order = {0, 1, 2, 3}, + const std::vector& output_transpose_order = {}, const float alpha = 1.0f, const float beta = 0.0f, const padding& output_padding = padding()) : primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}), - input0_order(input0_order), - input1_order(input1_order), - output_order(output_order), + input0_broadcast_target_shape(input0_broadcast_target_shape), + input1_broadcast_target_shape(input1_broadcast_target_shape), + input0_reshape_pattern(input0_reshape_pattern), + input1_reshape_pattern(input1_reshape_pattern), + input0_transpose_order(input0_transpose_order), + input1_transpose_order(input1_transpose_order), + output_transpose_order(output_transpose_order), alpha(alpha), beta(beta), - input_rank(input0_order.size()), - weight_rank(input1_order.size()) { + input_rank(input0_transpose_order.size()), + weight_rank(input1_transpose_order.size()) { if (inputs.size() != 2 && inputs.size() != 3) { throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs"); } - transpose_input0 = get_transpose_mode(input0_order); - transpose_input1 = get_transpose_mode(input1_order); + transpose_input0 = get_transpose_mode(input0_transpose_order); + transpose_input1 = get_transpose_mode(input1_transpose_order); } gemm(const primitive_id& id, const std::vector& inputs, const input_info& beam_table, const data_types 
data_type, - const std::vector& input0_order, - const std::vector& input1_order, - const std::vector& output_order, + const std::vector& input0_transpose_order, + const std::vector& input1_transpose_order, + const std::vector& output_transpose_order, bool indirect_a, bool indirect_b, const float alpha = 1.0f, const float beta = 0.0f, const padding& output_padding = padding()) : primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}), - input0_order(input0_order), - input1_order(input1_order), - output_order(output_order), + input0_broadcast_target_shape({}), + input1_broadcast_target_shape({}), + input0_reshape_pattern({}), + input1_reshape_pattern({}), + input0_transpose_order(input0_transpose_order), + input1_transpose_order(input1_transpose_order), + output_transpose_order(output_transpose_order), alpha(alpha), beta(beta), - input_rank(input0_order.size()), - weight_rank(input1_order.size()), + input_rank(input0_transpose_order.size()), + weight_rank(input1_transpose_order.size()), beam_table(beam_table), indirect_a(indirect_a), indirect_b(indirect_b) { @@ -135,20 +151,28 @@ struct gemm : public primitive_base { throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs"); } - transpose_input0 = get_transpose_mode(input0_order); - transpose_input1 = get_transpose_mode(input1_order); + transpose_input0 = get_transpose_mode(input0_transpose_order); + transpose_input1 = get_transpose_mode(input1_transpose_order); } /// @brief Flag for transposing first input matrix uint32_t transpose_input0 = 0; /// @brief Flag for transposing second input matrix uint32_t transpose_input1 = 0; + /// @brief broadcasted target shape of input 0 + std::vector input0_broadcast_target_shape; + /// @brief broadcasted target shape of input 1 + std::vector input1_broadcast_target_shape; + /// @brief reshaped output pattern of input 0 + std::vector input0_reshape_pattern; + /// @brief reshaped output pattern of input 1 + std::vector 
input1_reshape_pattern; /// @brief order of input 0 - std::vector input0_order; + std::vector input0_transpose_order; /// @brief order of input 1 - std::vector input1_order; + std::vector input1_transpose_order; /// @brief order of output - std::vector output_order; + std::vector output_transpose_order; /// @brief Variable containing ALPHA parameter float alpha = 1.0f; /// @brief Variable containing BETA parameter @@ -169,12 +193,13 @@ struct gemm : public primitive_base { seed = hash_combine(seed, transpose_input1); seed = hash_combine(seed, indirect_a); seed = hash_combine(seed, indirect_b); - for (auto order : input0_order) - seed = hash_combine(seed, order); - for (auto order : input1_order) - seed = hash_combine(seed, order); - for (auto order : output_order) - seed = hash_combine(seed, order); + seed = hash_range(seed, input0_broadcast_target_shape.begin(), input0_broadcast_target_shape.end()); + seed = hash_range(seed, input1_broadcast_target_shape.begin(), input1_broadcast_target_shape.end()); + seed = hash_range(seed, input0_reshape_pattern.begin(), input0_reshape_pattern.end()); + seed = hash_range(seed, input1_reshape_pattern.begin(), input1_reshape_pattern.end()); + seed = hash_range(seed, input0_transpose_order.begin(), input0_transpose_order.end()); + seed = hash_range(seed, input1_transpose_order.begin(), input1_transpose_order.end()); + seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end()); seed = hash_combine(seed, alpha); seed = hash_combine(seed, beta); return seed; @@ -200,9 +225,13 @@ struct gemm : public primitive_base { primitive_base::save(ob); ob << transpose_input0; ob << transpose_input1; - ob << input0_order; - ob << input1_order; - ob << output_order; + ob << input0_broadcast_target_shape; + ob << input1_broadcast_target_shape; + ob << input0_reshape_pattern; + ob << input1_reshape_pattern; + ob << input0_transpose_order; + ob << input1_transpose_order; + ob << output_transpose_order; ob << alpha; ob << 
beta; ob << input_rank; @@ -217,9 +246,13 @@ struct gemm : public primitive_base { primitive_base::load(ib); ib >> transpose_input0; ib >> transpose_input1; - ib >> input0_order; - ib >> input1_order; - ib >> output_order; + ib >> input0_broadcast_target_shape; + ib >> input1_broadcast_target_shape; + ib >> input0_reshape_pattern; + ib >> input1_reshape_pattern; + ib >> input0_transpose_order; + ib >> input1_transpose_order; + ib >> output_transpose_order; ib >> alpha; ib >> beta; ib >> input_rank; diff --git a/src/plugins/intel_gpu/src/graph/gemm.cpp b/src/plugins/intel_gpu/src/graph/gemm.cpp index a587b514fc0207..49f0fefd0f1ced 100644 --- a/src/plugins/intel_gpu/src/graph/gemm.cpp +++ b/src/plugins/intel_gpu/src/graph/gemm.cpp @@ -10,6 +10,18 @@ #include "intel_gpu/op/gemm.hpp" +namespace { +template ::value>::type> +int find_index_from_vec(const std::vector& vec, const DT value) { + int idx = 0; + for (auto v : vec) { + if (v != static_cast(value)) + break; + idx += 1; + } + return idx; +} +} // namespace namespace cldnn { GPU_DEFINE_PRIMITIVE_TYPE_ID(gemm) @@ -22,8 +34,8 @@ layout gemm_inst::calc_output_layout(gemm_node const& node, kernel_impl_params c auto input0_shape = input0_layout.get_shape(); auto input1_shape = input1_layout.get_shape(); - auto input0_order = prim->input0_order; - auto input1_order = prim->input1_order; + auto input0_transpose_order = prim->input0_transpose_order; + auto input1_transpose_order = prim->input1_transpose_order; bool reordered = prim->input_rank > 4 || prim->weight_rank > 4; size_t output_rank = std::max(prim->input_rank, prim->weight_rank); @@ -60,13 +72,13 @@ layout gemm_inst::calc_output_layout(gemm_node const& node, kernel_impl_params c return shape_transposed; }; - auto input0_shape_update = update_input_shape(input0_shape, input_rank, input0_order, true); - auto input1_shape_update = update_input_shape(input1_shape, weight_rank, input1_order, false); + auto input0_shape_update = update_input_shape(input0_shape, 
input_rank, input0_transpose_order, true); + auto input1_shape_update = update_input_shape(input1_shape, weight_rank, input1_transpose_order, false); ov::Shape bias_shape(output_rank); if (prim->input_size() == 3) { bias_shape = impl_param.get_input_layout(2).get_shape(); - bias_shape = update_input_shape(bias_shape, weight_rank, input1_order, false); + bias_shape = update_input_shape(bias_shape, weight_rank, input1_transpose_order, false); } auto output_shape = input0_shape_update; @@ -83,8 +95,8 @@ layout gemm_inst::calc_output_layout(gemm_node const& node, kernel_impl_params c size_t ones_to_add = 4 - std::min(output_shape.size(), static_cast(4)); output_shape.insert(output_shape.begin(), ones_to_add, 1); - if (prim->output_order.size() > 0) - output_shape = transpose_shape(output_shape, prim->output_order); + if (prim->output_transpose_order.size() > 0) + output_shape = transpose_shape(output_shape, prim->output_transpose_order); auto output_type = input0_layout.data_type; if ((output_type == data_types::u8 || output_type == data_types::i8) && prim->output_data_types[0]) @@ -125,8 +137,15 @@ std::vector gemm_inst::calc_output_layouts(gemm_node const& node, const input1_layout.get() }; - std::vector output_shapes = ov::intel_gpu::op::shape_infer(&op, input_shapes, - prim->input0_order, prim->input1_order, prim->output_order); + std::vector output_shapes = ov::intel_gpu::op::shape_infer(&op, + input_shapes, + prim->input0_broadcast_target_shape, + prim->input1_broadcast_target_shape, + prim->input0_reshape_pattern, + prim->input1_reshape_pattern, + prim->input0_transpose_order, + prim->input1_transpose_order, + prim->output_transpose_order); cldnn::format output_format = input0_layout.format; if (node.get_preferred_output_fmt() != format::any) @@ -139,58 +158,90 @@ template std::vector gemm_inst::calc_output_layouts(ge std::vector gemm_inst::transform_input_layouts(const std::shared_ptr primitive, const std::vector& input_layouts) { - auto get_updated_input_shape 
= [&](const ov::PartialShape& input_pshape, size_t input_rank, size_t output_rank, bool transpose, bool first_input) { - ov::PartialShape updated_input_pshape; + auto get_reshaped_input_shape = [&](const ov::PartialShape& input_pshape, + const std::vector& broadcast_target_shape, + const std::vector& reshape_pattern) { + ov::PartialShape reshaped_input_pshape; + + if (broadcast_target_shape.size() > 0 && reshape_pattern.size() > 0) { + std::vector dims(input_pshape); + int idx_recalc = find_index_from_vec(broadcast_target_shape, 1); + int idx_target = find_index_from_vec(reshape_pattern, 0); + if (dims[idx_recalc].is_static() && dims[idx_target].is_static()) { + dims[idx_recalc] *= dims[idx_target]; + } else { + dims[idx_recalc] = ov::Dimension::dynamic(); + } + dims.erase(dims.begin() + idx_target); + reshaped_input_pshape = ov::PartialShape(dims); + } else { + reshaped_input_pshape = input_pshape; + } + return reshaped_input_pshape; + }; + + auto get_transposed_input_shape = [&](const ov::PartialShape& input_pshape, size_t input_rank, size_t output_rank, bool transpose, bool first_input) { + ov::PartialShape transposed_input_pshape; if (input_rank == 1) { if (input_pshape.is_static()) { auto input_shape = input_pshape.to_shape(); - updated_input_pshape = ov::PartialShape{ static_cast(*std::max_element(input_shape.begin(), input_shape.end())) }; + transposed_input_pshape = ov::PartialShape{ static_cast(*std::max_element(input_shape.begin(), input_shape.end())) }; } else { - updated_input_pshape = ov::PartialShape::dynamic(input_rank); + transposed_input_pshape = ov::PartialShape::dynamic(input_rank); } } else { if (input_pshape.is_static()) { OPENVINO_ASSERT(input_pshape.size() >= input_rank, "[GPU] Requested input rank in gemm primitive is greater than actual shape"); std::vector dims(input_pshape.begin(), input_pshape.begin() + input_rank); - updated_input_pshape = ov::PartialShape(dims); + transposed_input_pshape = ov::PartialShape(dims); } else { - 
updated_input_pshape = input_pshape; + transposed_input_pshape = input_pshape; } } - if (updated_input_pshape.size() == 1) { - first_input ? updated_input_pshape.insert(updated_input_pshape.begin(), 1) - : updated_input_pshape.insert(updated_input_pshape.end(), 1); + if (transposed_input_pshape.size() == 1) { + first_input ? transposed_input_pshape.insert(transposed_input_pshape.begin(), 1) + : transposed_input_pshape.insert(transposed_input_pshape.end(), 1); if (transpose) { - std::swap(updated_input_pshape[0], updated_input_pshape[1]); + std::swap(transposed_input_pshape[0], transposed_input_pshape[1]); } } - size_t ones_to_add = std::max(output_rank, static_cast(4)) - updated_input_pshape.size(); - updated_input_pshape.insert(updated_input_pshape.begin(), ones_to_add, 1ul); + size_t ones_to_add = std::max(output_rank, static_cast(4)) - transposed_input_pshape.size(); + transposed_input_pshape.insert(transposed_input_pshape.begin(), ones_to_add, 1ul); - return updated_input_pshape; + return transposed_input_pshape; }; - auto input0_pshape = input_layouts[0].get_partial_shape(); - auto input1_pshape = input_layouts[1].get_partial_shape(); + auto reshaped_input0_pshape = get_reshaped_input_shape(input_layouts[0].get_partial_shape(), + primitive->input0_broadcast_target_shape, + primitive->input0_reshape_pattern); + auto reshaped_input1_pshape = get_reshaped_input_shape(input_layouts[1].get_partial_shape(), + primitive->input1_broadcast_target_shape, + primitive->input1_reshape_pattern); bool reordered = primitive->input_rank > 4 || primitive->weight_rank > 4; size_t output_rank = std::max(primitive->input_rank, primitive->weight_rank); size_t input_rank = reordered ? output_rank : primitive->input_rank; size_t weight_rank = reordered ? 
output_rank : primitive->weight_rank; - auto updated_input0_pshape = get_updated_input_shape(input0_pshape, input_rank, output_rank, primitive->transpose_input0, true); - auto updated_input1_pshape = get_updated_input_shape(input1_pshape, weight_rank, output_rank, primitive->transpose_input1, false); + auto transposed_input0_pshape = get_transposed_input_shape(reshaped_input0_pshape, input_rank, output_rank, primitive->transpose_input0, true); + auto transposed_input1_pshape = get_transposed_input_shape(reshaped_input1_pshape, weight_rank, output_rank, primitive->transpose_input1, false); std::vector layouts = input_layouts; - layouts[0].set_partial_shape(updated_input0_pshape); - layouts[1].set_partial_shape(updated_input1_pshape); + layouts[0].set_partial_shape(transposed_input0_pshape); + if (primitive->input0_broadcast_target_shape.size() > input_rank) { + layouts[0].format = format::adjust_to_rank(layouts[0].format, input_rank); + } + layouts[1].set_partial_shape(transposed_input1_pshape); + if (primitive->input1_broadcast_target_shape.size() > weight_rank) { + layouts[1].format = format::adjust_to_rank(layouts[1].format, weight_rank); + } if (primitive->input_size() == 3) { auto bias_pshape = input_layouts[2].get_partial_shape(); - auto updated_bias_pshape = get_updated_input_shape(bias_pshape, weight_rank, output_rank, primitive->transpose_input1, false); + auto updated_bias_pshape = get_transposed_input_shape(bias_pshape, weight_rank, output_rank, primitive->transpose_input1, false); layouts[2].set_partial_shape(updated_bias_pshape); } @@ -213,8 +264,8 @@ layout gemm_inst::transform_output_layout(const std::shared_ptr prim auto updated_output_layout = output_layout; auto output_rank = output_layout.get_partial_shape().size(); if (output_rank < 4) { - ov::PartialShape transposed_input0_pshape = transpose_pshape(input_layouts[0].get_partial_shape(), primitive->input0_order); - ov::PartialShape transposed_input1_pshape = 
transpose_pshape(input_layouts[1].get_partial_shape(), primitive->input1_order); + ov::PartialShape transposed_input0_pshape = transpose_pshape(input_layouts[0].get_partial_shape(), primitive->input0_transpose_order); + ov::PartialShape transposed_input1_pshape = transpose_pshape(input_layouts[1].get_partial_shape(), primitive->input1_transpose_order); auto M = (transposed_input0_pshape.size() > 1) ? transposed_input0_pshape[transposed_input0_pshape.size() - 2] : transposed_input0_pshape[0]; @@ -238,8 +289,8 @@ layout gemm_inst::transform_output_layout(const std::shared_ptr prim output_pshape[get_spatial_idx(updated_output_layout.format, 0)] = std::move(N); output_pshape[get_spatial_idx(updated_output_layout.format, 1)] = std::move(M); - if (primitive->output_order.size() > 0) { - output_pshape = transpose_pshape(output_pshape, primitive->output_order); + if (primitive->output_transpose_order.size() > 0) { + output_pshape = transpose_pshape(output_pshape, primitive->output_transpose_order); } updated_output_layout.set_partial_shape(output_pshape); diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp index 5e44c78cbb4724..eb38677d03bf95 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp @@ -514,8 +514,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) { } auto gemm_prim = node.get_primitive(); - for (size_t idx = 0; idx < gemm_prim->output_order.size(); ++idx) { - size_t output_order_idx = static_cast(gemm_prim->output_order[idx]); + for (size_t idx = 0; idx < gemm_prim->output_transpose_order.size(); ++idx) { + size_t output_order_idx = static_cast(gemm_prim->output_transpose_order[idx]); if (idx != output_order_idx) { does_support_fusings = false; break; diff --git 
a/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp index 2ed48b659d3c38..03124262072955 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp @@ -173,9 +173,13 @@ struct gemm_impl : multi_stage_primitive { params.beta = primitive->beta; params.transpose_input0 = primitive->transpose_input0; params.transpose_input1 = primitive->transpose_input1; - params.input0_order = primitive->input0_order; - params.input1_order = primitive->input1_order; - params.output_order = primitive->output_order; + params.input0_target_shape = primitive->input0_broadcast_target_shape; + params.input1_target_shape = primitive->input1_broadcast_target_shape; + params.input0_output_pattern = primitive->input0_reshape_pattern; + params.input1_output_pattern = primitive->input1_reshape_pattern; + params.input0_order = primitive->input0_transpose_order; + params.input1_order = primitive->input1_transpose_order; + params.output_order = primitive->output_transpose_order; params.indirect_input0 = primitive->indirect_a && indirect; params.indirect_input1 = primitive->indirect_b && indirect; diff --git a/src/plugins/intel_gpu/src/graph/layout_optimizer.cpp b/src/plugins/intel_gpu/src/graph/layout_optimizer.cpp index 0ce52b7e1a3d36..8f2021d39a2b7c 100644 --- a/src/plugins/intel_gpu/src/graph/layout_optimizer.cpp +++ b/src/plugins/intel_gpu/src/graph/layout_optimizer.cpp @@ -937,21 +937,21 @@ static bool is_node_for_onednn(gemm_node const& node) { auto gemm_prim = node.get_primitive(); - for (size_t idx = 0; idx < gemm_prim->output_order.size(); idx++) { - if (idx != static_cast(gemm_prim->output_order[idx])) + for (size_t idx = 0; idx < gemm_prim->output_transpose_order.size(); idx++) { + if (idx != static_cast(gemm_prim->output_transpose_order[idx])) return false; } if (gemm_prim->transpose_input0 > 1 || gemm_prim->transpose_input0 > 1) return false; - for (size_t idx = 0; idx < 
(gemm_prim->input0_order.size() - 2); idx++) { - if (idx != static_cast(gemm_prim->input0_order[idx])) + for (size_t idx = 0; idx < (gemm_prim->input0_transpose_order.size() - 2); idx++) { + if (idx != static_cast(gemm_prim->input0_transpose_order[idx])) return false; } - for (size_t idx = 0; idx < (gemm_prim->input1_order.size() - 2); idx++) { - if (idx != static_cast(gemm_prim->input1_order[idx])) + for (size_t idx = 0; idx < (gemm_prim->input1_transpose_order.size() - 2); idx++) { + if (idx != static_cast(gemm_prim->input1_transpose_order[idx])) return false; } diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_ref.cl index cfb9a7b4749cb1..e90841d56fd33d 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_ref.cl @@ -10,6 +10,9 @@ // ACCUMULATOR_TYPE [DataType] - type used for intermediate results accumulation. 
inline uint FUNC(get_input0_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#if BROADCAST_INPUT0 + DO_BROADCAST_INPUT0 +#endif #if INPUT0_SIMPLE return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, y, x); #else @@ -30,6 +33,9 @@ inline uint FUNC(get_input0_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint } inline uint FUNC(get_input1_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#if BROADCAST_INPUT1 + DO_BROADCAST_INPUT1 +#endif #if INPUT1_SIMPLE return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, y, x); #else diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl index e9079c6fb395f3..13dab0314ddf23 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl @@ -29,6 +29,9 @@ #endif // TILE_N > SIMD_WIDTH inline uint FUNC(get_input0_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#if BROADCAST_INPUT0 + DO_BROADCAST_INPUT0 +#endif #if INPUT0_SIMPLE return GET_DATA_INDEX_6D_SAFE(INPUT0, b, f, w, z, y, x); #else @@ -41,6 +44,9 @@ inline uint FUNC(get_input0_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint } inline uint FUNC(get_input1_index_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x) { +#if BROADCAST_INPUT1 + DO_BROADCAST_INPUT1 +#endif #if INPUT1_SIMPLE return GET_DATA_INDEX_6D_SAFE(INPUT1, b, f, w, z, y, x); #else diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.cpp index 87a3d3bb9d03e6..cb59cfee015e96 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.cpp @@ -211,6 +211,44 @@ JitConstants 
GemmKernelBase::GetJitConstants(const gemm_params& params) const { jit.AddConstant(MakeJitConstant("BIAS_TERM", 1)); } + auto get_broadcast_input_str = [](const std::vector& target_shape) { + const size_t target_rank = target_shape.size(); + std::vector dims; + if (target_rank == 1) { + dims = {"x"}; + } else if (target_rank == 2) { + dims = {"y", "x"}; + } else if (target_rank == 3) { + dims = {"f", "y", "x"}; + } else if (target_rank == 4) { + dims = {"b", "f", "y", "x"}; + } else if (target_rank == 5) { + dims = {"b", "f", "z", "y", "x"}; + } else if (target_rank == 6) { + dims = {"b", "f", "w", "z", "y", "x"}; + } + size_t pos = 0; + for (auto ts : target_shape) { + if (ts != 1) + break; + pos += 1; + } + std::string str = pos < dims.size() ? dims[pos] + " /= " + std::to_string(target_shape[pos]) + ";" : std::string(); + return str; + }; + if (params.input0_target_shape.size() > 1) { + jit.AddConstants({ + MakeJitConstant("BROADCAST_INPUT0", true), + MakeJitConstant("DO_BROADCAST_INPUT0", get_broadcast_input_str(params.input0_target_shape)), + }); + } + if (params.input1_target_shape.size() > 1) { + jit.AddConstants({ + MakeJitConstant("BROADCAST_INPUT1", true), + MakeJitConstant("DO_BROADCAST_INPUT1", get_broadcast_input_str(params.input1_target_shape)), + }); + } + jit.AddConstants({ MakeJitConstant("TRANSPOSE_X_LAST", 0), MakeJitConstant("TRANSPOSE_Y_LAST", 1), diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.h index afb02169226eb4..633c8171c99ec8 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.h @@ -19,6 +19,10 @@ struct gemm_params : public base_params { float beta; uint32_t transpose_input0; uint32_t transpose_input1; + std::vector input0_target_shape; + std::vector input1_target_shape; + std::vector input0_output_pattern; + std::vector input1_output_pattern; 
std::vector input0_order; std::vector input1_order; std::vector output_order; diff --git a/src/plugins/intel_gpu/src/plugin/ops/matmul.cpp b/src/plugins/intel_gpu/src/plugin/ops/matmul.cpp index 6398635fdb2147..d455c1fa839b89 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/matmul.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/matmul.cpp @@ -156,20 +156,26 @@ static void CreateGemmOp(ProgramBuilder& p, const std::shared_ptrget_input_partial_shape(1); auto out_shape = op->get_output_partial_shape(0); - size_t rank_a = shape_a.rank().get_length(); - size_t rank_b = shape_b.rank().get_length(); + size_t rank_a = op->get_input0_reshape_pattern().size() > 0 ? op->get_input0_reshape_pattern().size() + : shape_a.rank().get_length(); + size_t rank_b = op->get_input1_reshape_pattern().size() > 0 ? op->get_input1_reshape_pattern().size() + :shape_b.rank().get_length(); size_t output_rank = out_shape.rank().get_length(); - OPENVINO_ASSERT(rank_a == op->get_input0_order().size(), "[GPU] Length of input0_order is not same as rank of input0"); - OPENVINO_ASSERT(rank_b == op->get_input1_order().size(), "[GPU] Length of input1_order is not same as rank of input1"); - OPENVINO_ASSERT(output_rank == op->get_output_order().size(), "[GPU] Length of output_order is not same as rank of output"); + OPENVINO_ASSERT(rank_a == op->get_input0_transpose_order().size(), "[GPU] Length of input0_order is not same as rank of input0"); + OPENVINO_ASSERT(rank_b == op->get_input1_transpose_order().size(), "[GPU] Length of input1_order is not same as rank of input1"); + OPENVINO_ASSERT(output_rank == op->get_output_transpose_order().size(), "[GPU] Length of output_order is not same as rank of output"); auto gemmPrim = cldnn::gemm(layerName, inputs, cldnn::element_type_to_data_type(op->get_output_element_type(0)), - op->get_input0_order(), - op->get_input1_order(), - op->get_output_order(), + op->get_input0_broadcast_target_shape(), + op->get_input1_broadcast_target_shape(), + 
op->get_input0_reshape_pattern(), + op->get_input1_reshape_pattern(), + op->get_input0_transpose_order(), + op->get_input1_transpose_order(), + op->get_output_transpose_order(), alpha, beta); @@ -200,9 +206,9 @@ static void CreateIndirectGemmOp(ProgramBuilder& p, const std::shared_ptr{ inputs[0], inputs[1] }, inputs[2], cldnn::element_type_to_data_type(op->get_output_element_type(0)), - op->get_input0_order(), - op->get_input1_order(), - op->get_output_order(), + op->get_input0_transpose_order(), + op->get_input1_transpose_order(), + op->get_output_transpose_order(), op->get_indirect_a(), op->get_indirect_b(), alpha, diff --git a/src/plugins/intel_gpu/src/plugin/transformations/broadcast_reshape_matmul_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/broadcast_reshape_matmul_fusion.cpp new file mode 100644 index 00000000000000..17df3d3d1a7294 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/broadcast_reshape_matmul_fusion.cpp @@ -0,0 +1,145 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "broadcast_reshape_matmul_fusion.hpp" + +#include "intel_gpu/op/gemm.hpp" + +#include "openvino/core/rt_info.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "transformations/utils/utils.hpp" + +namespace ov { +namespace intel_gpu { + +BroadcastReshapeMatmulFusion::BroadcastReshapeMatmulFusion() { + using namespace ov::pass::pattern; + + auto not_reshape = [](const ov::Output& output) -> bool { + return std::dynamic_pointer_cast(output.get_node_shared_ptr()) == nullptr; + }; + + auto broadcast_rank_equals_and_has_static_dims = [](const ov::Output& output) -> bool { + return rank_equals(5)(output) && has_static_dims({2, 3})(output) && consumers_count(1)(output); + }; + + auto reshape_rank_equals_and_has_static_dim = [](const ov::Output& output) 
-> bool { + return rank_equals(4)(output) && has_static_dim(2) && consumers_count(1); + }; + + auto input_a_m = any_input(not_reshape); + auto input_b_m = any_input(not_reshape); + + auto broadcast_a_target_shape_m = wrap_type(); + auto broadcast_a_m = wrap_type({input_a_m, broadcast_a_target_shape_m}, broadcast_rank_equals_and_has_static_dims); + auto broadcast_b_target_shape_m = wrap_type(); + auto broadcast_b_m = wrap_type({input_b_m, broadcast_b_target_shape_m}, broadcast_rank_equals_and_has_static_dims); + + auto reshape_a_pattern_m = wrap_type(); + auto reshape_a_m = wrap_type({broadcast_a_m, reshape_a_pattern_m}, reshape_rank_equals_and_has_static_dim); + auto reshape_b_pattern_m = wrap_type(); + auto reshape_b_m = wrap_type({broadcast_b_m, reshape_b_pattern_m}, reshape_rank_equals_and_has_static_dim); + + auto matmul_in_a = std::make_shared(OutputVector{input_a_m, reshape_a_m}); + auto matmul_in_b = std::make_shared(OutputVector{input_b_m, reshape_b_m}); + + auto matmul_m = wrap_type({matmul_in_a, matmul_in_b}); + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto matmul = std::dynamic_pointer_cast(pattern_map.at(matmul_m).get_node_shared_ptr()); + if (!matmul || transformation_callback(m.get_match_root())) { + return false; + } + + auto target_shape_a = std::vector(); + auto target_shape_b = std::vector(); + size_t input_a_output_idx = matmul->get_input_source_output(0).get_index(); + size_t input_b_output_idx = matmul->get_input_source_output(1).get_index(); + auto order_a = matmul->get_input0_transpose_order(); + auto order_b = matmul->get_input1_transpose_order(); + + auto valid_transpose_order = [](const std::vector& order) { + return order.size() == 4 && order[1] == 2; + }; + + auto valid_broadcast_target_shape = [](const std::vector& target_shape) { + return std::count_if(target_shape.begin(), target_shape.end(), [](int32_t s) { return s != 1; }) == 1; + }; + + if 
(pattern_map.count(broadcast_a_m) > 0) { + if (!valid_transpose_order(order_a)) + return false; + auto broadcast_a = std::dynamic_pointer_cast(pattern_map.at(broadcast_a_m).get_node_shared_ptr()); + if (!broadcast_a || broadcast_a->get_broadcast_spec().m_type != ov::op::BroadcastType::BIDIRECTIONAL) + return false; + auto broadcast_a_target_shape = std::dynamic_pointer_cast(pattern_map.at(broadcast_a_target_shape_m).get_node_shared_ptr()); + target_shape_a = broadcast_a_target_shape->cast_vector(); + if (!valid_broadcast_target_shape(target_shape_a)) + return false; + input_a_output_idx = broadcast_a->get_input_source_output(0).get_index(); + } + if (pattern_map.count(broadcast_b_m) > 0) { + if (!valid_transpose_order(order_b)) + return false; + auto broadcast_b = std::dynamic_pointer_cast(pattern_map.at(broadcast_b_m).get_node_shared_ptr()); + if (!broadcast_b || broadcast_b->get_broadcast_spec().m_type != ov::op::BroadcastType::BIDIRECTIONAL) + return false; + auto broadcast_b_target_shape = std::dynamic_pointer_cast(pattern_map.at(broadcast_b_target_shape_m).get_node_shared_ptr()); + target_shape_b = broadcast_b_target_shape->cast_vector(); + if (!valid_broadcast_target_shape(target_shape_b)) + return false; + input_b_output_idx = broadcast_b->get_input_source_output(0).get_index(); + } + + auto pattern_a = std::vector(); + auto pattern_b = std::vector(); + + auto valid_reshape_pattern = [](const std::vector& pattern) { + return std::count_if(pattern.begin(), pattern.end(), [](int64_t p) { return p == -1; }) == 0; + }; + + if (pattern_map.count(reshape_a_m) > 0) { + auto reshape_a_pattern = std::dynamic_pointer_cast(pattern_map.at(reshape_a_pattern_m).get_node_shared_ptr()); + pattern_a = reshape_a_pattern->cast_vector(); + if (!valid_reshape_pattern(pattern_a)) + return false; + } + if (pattern_map.count(reshape_b_m) > 0) { + auto reshape_b_pattern = std::dynamic_pointer_cast(pattern_map.at(reshape_b_pattern_m).get_node_shared_ptr()); + pattern_b = 
reshape_b_pattern->cast_vector(); + if (!valid_reshape_pattern(pattern_b)) + return false; + } + + auto input_a = ov::Output(pattern_map.at(input_a_m).get_node_shared_ptr(), input_a_output_idx); + auto input_b = ov::Output(pattern_map.at(input_b_m).get_node_shared_ptr(), input_b_output_idx); + auto order_c = matmul->get_output_transpose_order(); + + auto gemm = std::make_shared(input_a, + input_b, + target_shape_a, + target_shape_b, + pattern_a, + pattern_b, + order_a, + order_b, + order_c); + gemm->set_friendly_name(matmul->get_friendly_name()); + ov::copy_runtime_info(m.get_matched_nodes(), gemm); + ov::replace_node(matmul, gemm); + + return true; + }; + + auto m = std::make_shared(matmul_m, "BroadcastReshapeMatmulFusion"); + this->register_matcher(m, callback); +} + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/broadcast_reshape_matmul_fusion.hpp b/src/plugins/intel_gpu/src/plugin/transformations/broadcast_reshape_matmul_fusion.hpp new file mode 100644 index 00000000000000..e3ad540e0a4692 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/broadcast_reshape_matmul_fusion.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" + +namespace ov { +namespace intel_gpu { + +class BroadcastReshapeMatmulFusion : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("BroadcastReshapeMatmulFusion", "0"); + BroadcastReshapeMatmulFusion(); +}; + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp b/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp index 2a0bac302956c2..14b58e642a8116 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp @@ -85,9 +85,9 @@ 
IndirectKVCache::IndirectKVCache() { auto matmul_kv_cache_index = kv_cache_users.begin()->get_index(); auto gemm_node = std::dynamic_pointer_cast(m.get_match_root()); - auto order_in0 = gemm_node->get_input0_order(); - auto order_in1 = gemm_node->get_input1_order(); - auto order_out = gemm_node->get_output_order(); + auto order_in0 = gemm_node->get_input0_transpose_order(); + auto order_in1 = gemm_node->get_input1_transpose_order(); + auto order_out = gemm_node->get_output_transpose_order(); auto indirect_gemm = std::make_shared(gemm_node->get_input_node_shared_ptr(0), gemm_node->get_input_node_shared_ptr(1), diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/gemm.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/gemm.cpp index 92f23f9b6b6663..45b200baba4ce9 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/op/gemm.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/gemm.cpp @@ -4,8 +4,12 @@ #include "intel_gpu/op/gemm.hpp" #include "matmul_shape_inference.hpp" +#include "broadcast_shape_inference.hpp" +#include "reshape_shape_inference.hpp" #include "openvino/core/partial_shape.hpp" #include "openvino/op/matmul.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/reshape.hpp" namespace ov { namespace intel_gpu { @@ -18,6 +22,35 @@ Gemm::Gemm(const ov::Output& A, const std::vector& order_c, const ov::element::Type output_type) : ov::op::v0::MatMul() + , m_target_shape_a({}) + , m_target_shape_b({}) + , m_output_pattern_a({}) + , m_output_pattern_b({}) + , m_order_a(order_a) + , m_order_b(order_b) + , m_order_c(order_c) + , m_output_type(output_type) { + set_arguments({A, B}); + set_transpose_a(false); + set_transpose_b(false); + validate_and_infer_types(); +} + +Gemm::Gemm(const ov::Output& A, + const ov::Output& B, + const std::vector& target_shape_a, + const std::vector& target_shape_b, + const std::vector& output_pattern_a, + const std::vector& output_pattern_b, + const std::vector& order_a, + const 
std::vector& order_b, + const std::vector& order_c, + const ov::element::Type output_type) + : ov::op::v0::MatMul() + , m_target_shape_a(target_shape_a) + , m_target_shape_b(target_shape_b) + , m_output_pattern_a(output_pattern_a) + , m_output_pattern_b(output_pattern_b) , m_order_a(order_a) , m_order_b(order_b) , m_order_c(order_c) @@ -31,7 +64,16 @@ Gemm::Gemm(const ov::Output& A, std::shared_ptr Gemm::clone_with_new_inputs(const ov::OutputVector& new_args) const { check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), new_args.at(1), m_order_a, m_order_b, m_order_c, m_output_type); + return std::make_shared(new_args.at(0), + new_args.at(1), + m_target_shape_a, + m_target_shape_b, + m_output_pattern_a, + m_output_pattern_b, + m_order_a, + m_order_b, + m_order_c, + m_output_type); } void Gemm::validate_and_infer_types() { @@ -42,7 +84,15 @@ void Gemm::validate_and_infer_types() { input_size, ", expected 2."); - auto out_shapes = shape_infer(this, std::vector{get_input_partial_shape(0), get_input_partial_shape(1)}, m_order_a, m_order_b, m_order_c); + auto out_shapes = shape_infer(this, + std::vector{get_input_partial_shape(0), get_input_partial_shape(1)}, + m_target_shape_a, + m_target_shape_b, + m_output_pattern_a, + m_output_pattern_b, + m_order_a, + m_order_b, + m_order_c); auto output_type = m_output_type == ov::element::undefined ? 
get_input_element_type(0) : m_output_type; set_output_type(0, output_type, out_shapes[0]); @@ -58,9 +108,45 @@ bool Gemm::visit_attributes(ov::AttributeVisitor &visitor) { std::vector shape_infer(const Gemm* op, std::vector input_shapes, + const std::vector& target_shape_a, + const std::vector& target_shape_b, + const std::vector& output_pattern_a, + const std::vector& output_pattern_b, const std::vector& order_a, const std::vector& order_b, const std::vector& order_c) { + auto shape_a = input_shapes[0]; + auto shape_b = input_shapes[1]; + + // broadcasted shapes + auto broadcast_shape = [](const ov::PartialShape shape, const std::vector& target_shape) { + ov::op::v3::Broadcast broadcast; + auto tshape = target_shape; + broadcast.set_broadcast_spec(ov::op::BroadcastType::BIDIRECTIONAL); + std::unordered_map const_data; + const_data.emplace(1, ov::Tensor(ov::element::i32, ov::Shape{tshape.size()}, static_cast(tshape.data()))); + return ov::op::v3::shape_infer(&broadcast, + std::vector{shape, ov::PartialShape(ov::Shape{tshape.size()})}, + ov::make_tensor_accessor(const_data)); + }; + auto shape_a_b = (target_shape_a.size() > 1) ? broadcast_shape(shape_a, target_shape_a)[0] : shape_a; + auto shape_b_b = (target_shape_b.size() > 1) ? broadcast_shape(shape_b, target_shape_b)[0] : shape_b; + + // reshaped shapes + auto reshape_shape = [](const ov::PartialShape shape, const std::vector& output_pattern) { + ov::op::v1::Reshape reshape; + auto opattern = output_pattern; + reshape.set_special_zero(true); + std::unordered_map const_data; + const_data.emplace(1, ov::Tensor(ov::element::i64, ov::Shape{opattern.size()}, static_cast(opattern.data()))); + return ov::op::v1::shape_infer(&reshape, + std::vector{shape, ov::PartialShape(ov::Shape{opattern.size()})}, + ov::make_tensor_accessor(const_data)); + }; + auto shape_a_r = (output_pattern_a.size() > 1) ? reshape_shape(shape_a_b, output_pattern_a)[0] : shape_a_b; + auto shape_b_r = (output_pattern_b.size() > 1) ? 
reshape_shape(shape_b_b, output_pattern_b)[0] : shape_b_b; + + // transposed shapes auto transpose_shape = [](const ov::PartialShape shape, const std::vector& order) { auto shape_transposed = ov::PartialShape::dynamic(shape.rank()); for (size_t i = 0; i < order.size(); i++) { @@ -69,11 +155,8 @@ std::vector shape_infer(const Gemm* op, return shape_transposed; }; - auto shape_a = input_shapes[0]; - auto shape_b = input_shapes[1]; - - auto shape_a_t = (order_a.size() > 1) ? transpose_shape(shape_a, order_a) : shape_a; - auto shape_b_t = (order_b.size() > 1) ? transpose_shape(shape_b, order_b) : shape_b; + auto shape_a_t = (order_a.size() > 1) ? transpose_shape(shape_a_r, order_a) : shape_a_r; + auto shape_b_t = (order_b.size() > 1) ? transpose_shape(shape_b_r, order_b) : shape_b_r; auto out_shapes = ov::op::v0::shape_infer(dynamic_cast(op), std::vector{shape_a_t, shape_b_t}); if (order_c.size() > 0) { diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_gemm.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_gemm.cpp index 5e35d5cd1fc177..bd557f811e6951 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_gemm.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_gemm.cpp @@ -48,7 +48,15 @@ void IndirectGemm::validate_and_infer_types() { input_size, ", expected 3."); - auto out_shapes = shape_infer(this, std::vector{get_input_partial_shape(0), get_input_partial_shape(1)}, m_order_a, m_order_b, m_order_c); + auto out_shapes = shape_infer(this, + std::vector{get_input_partial_shape(0), get_input_partial_shape(1)}, + m_target_shape_a, + m_target_shape_b, + m_output_pattern_a, + m_output_pattern_b, + m_order_a, + m_order_b, + m_order_c); auto output_type = m_output_type == ov::element::undefined ? 
get_input_element_type(0) : m_output_type; set_output_type(0, output_type, out_shapes[0]); diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 03c5118a7a8861..dfd01834eb710b 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -61,6 +61,7 @@ #include "plugin/transformations/transpose_matmul_fusion.hpp" #include "plugin/transformations/indirect_kv_cache.hpp" #include "plugin/transformations/convert_convolution.hpp" +#include "plugin/transformations/broadcast_reshape_matmul_fusion.hpp" #include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp" #include "transformations/common_optimizations/broadcast_transition.hpp" #include "transformations/common_optimizations/common_optimizations.hpp" @@ -723,6 +724,8 @@ void TransformationsPipeline::apply(std::shared_ptr func) { manager.register_pass(); manager.register_pass(); + if (!device_info.supports_immad) + manager.register_pass(); const size_t zp_pad_size = 32; manager.register_pass(zp_pad_size); diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp index dc7ee7e120b586..a5b524e507f40e 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp @@ -10,6 +10,7 @@ #include #include "openvino/reference/matmul.hpp" #include "openvino/reference/transpose.hpp" +#include "openvino/reference/reshape.hpp" #include "intel_gpu/runtime/compilation_context.hpp" #include "gemm_inst.h" @@ -837,6 +838,285 @@ class gemm_gpu_tests: public ::testing::Test { } } + void test_broadcast_transpose_matmul(bool is_caching_test) { + tests::random_generator rg; + rg.set_seed(GET_SUITE_NAME); + + const unsigned long BATCH_SIZE = 1; + const unsigned long M_SIZE = 1; + const unsigned 
long K_SIZE = 32; + const unsigned long N_SIZE = 21; + + auto fill_mem = [&](cldnn::memory_ptr mem, std::vector& data) { + cldnn::mem_lock mem_ptr(mem, get_test_stream()); + auto&& l = mem->get_layout(); + auto data_idx = 0; + for (cldnn::tensor::value_type b = 0; b < l.batch(); ++b) { + for (cldnn::tensor::value_type f = 0; f < l.feature(); ++f) { + for (cldnn::tensor::value_type y = 0; y < l.spatial(1); ++y) { + for (cldnn::tensor::value_type x = 0; x < l.spatial(0); ++x) { + auto tensor_coord = cldnn::tensor{{b, f, x, y}, 0}; + auto buffer_idx = l.get_linear_offset(tensor_coord); + mem_ptr[buffer_idx] = data[data_idx++]; + } + } + } + } + }; + + auto& engine = get_test_engine(); + ov::Shape input0_shape; + ov::Shape input1_shape; + std::vector input1_target_shape; + std::vector input0_order; + std::vector input1_order; + ov::Shape beam_table_shape; + cldnn::layout input0_layout; + cldnn::layout input1_layout; + + input0_shape = { BATCH_SIZE, 16, M_SIZE, K_SIZE }; + input1_shape = { N_SIZE, BATCH_SIZE, 1, K_SIZE }; + input1_target_shape = { 1, 1, 16, 1 }; + input0_order = { 0, 1, 2, 3 }; + input1_order = { 1, 2, 3, 0 }; + + input0_layout = layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f32, format::bfyx}; + input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f32, format::bfyx}; + + auto input0_mem = engine.allocate_memory(layout{ov::PartialShape(input0_shape), data_types::f32, format::bfyx}); + auto input1_mem = engine.allocate_memory(layout{ov::PartialShape(input1_shape), data_types::f32, format::bfyx}); + + auto input_0_data = rg.generate_random_1d(ov::shape_size(input0_shape), -2, 2); + auto input_1_data = rg.generate_random_1d(ov::shape_size(input1_shape), -2, 2); + + fill_mem(input0_mem, input_0_data); + fill_mem(input1_mem, input_1_data); + + topology topology; + topology.add(input_layout("input0", input0_layout), + input_layout("input1", input1_layout), + gemm("gemm", { input_info("input0"), 
input_info("input1") }, data_types::f32, {}, input1_target_shape, {}, {}, input0_order, input1_order) + ); + + ExecutionConfig config = get_test_default_config(engine); + config.set_property(ov::intel_gpu::optimize_data(true)); + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + network::ptr network = get_network(engine, topology, config, get_test_stream_ptr(), is_caching_test); + network->set_input_data("input0", input0_mem); + network->set_input_data("input1", input1_mem); + + auto inst = network->get_primitive("gemm"); + auto impl = inst->get_impl(); + ASSERT_TRUE(impl != nullptr); + + auto outputs = network->execute(); + + auto output_mem = outputs.at("gemm").get_memory(); + cldnn::mem_lock output_ptr(output_mem, get_test_stream()); + + ov::Shape ref_input0_shape; + ov::Shape ref_input1_broadcasted_shape; + ov::Shape ref_input1_shape; + ov::Shape ref_output_shape; + + ref_input0_shape = { BATCH_SIZE, 16, M_SIZE, K_SIZE }; + ref_input1_broadcasted_shape = { N_SIZE, BATCH_SIZE, 16, K_SIZE }; + ref_input1_shape = { BATCH_SIZE, 16, K_SIZE, N_SIZE }; + ref_output_shape = { BATCH_SIZE, 16, M_SIZE, N_SIZE }; + + std::vector ref_out_data; + ref_out_data.resize(ov::shape_size(ref_output_shape)); + + std::vector ref_input_0_data(input_0_data.size()); + std::vector ref_input_1_broadcasted_data(ov::shape_size(ref_input1_broadcasted_shape)); + std::vector ref_input_1_data(ref_input_1_broadcasted_data.size()); + + ov::reference::transpose((const char *)(input_0_data.data()), + (char *)(ref_input_0_data.data()), + input0_shape, + sizeof(float), + input0_order, + ref_input0_shape); + + ov::reference::broadcast(reinterpret_cast(input_1_data.data()), + reinterpret_cast(ref_input_1_broadcasted_data.data()), + input1_shape, + ref_input1_broadcasted_shape, + ov::AxisSet({}), + sizeof(float)); + + ov::reference::transpose((const char *)(ref_input_1_broadcasted_data.data()), + (char *)(ref_input_1_data.data()), + ref_input1_broadcasted_shape, + sizeof(float), + 
input1_order, + ref_input1_shape); + + ov::reference::matmul(ref_input_0_data.data(), + ref_input_1_data.data(), + ref_out_data.data(), + ref_input0_shape, + ref_input1_shape, + ref_output_shape, + false, + false); + + ASSERT_EQ(output_ptr.size(), ref_out_data.size()); + + const auto abs_error = 0.0001; + for (uint32_t i = 0; i < ref_out_data.size(); ++i) { + ASSERT_NEAR(output_ptr[i], ref_out_data[i], abs_error) << "at " << i; + } + } + + void test_broadcast_reshape_transpose_matmul(bool is_caching_test) { + tests::random_generator rg; + rg.set_seed(GET_SUITE_NAME); + + const unsigned long BATCH_SIZE = 1; + const unsigned long M_SIZE = 1; + const unsigned long K_SIZE = 32; + const unsigned long N_SIZE = 21; + + auto fill_mem = [&](cldnn::memory_ptr mem, std::vector& data) { + cldnn::mem_lock mem_ptr(mem, get_test_stream()); + auto&& l = mem->get_layout(); + auto data_idx = 0; + for (cldnn::tensor::value_type b = 0; b < l.batch(); ++b) { + for (cldnn::tensor::value_type f = 0; f < l.feature(); ++f) { + for (cldnn::tensor::value_type z = 0; z < l.spatial(2); ++z) { + for (cldnn::tensor::value_type y = 0; y < l.spatial(1); ++y) { + for (cldnn::tensor::value_type x = 0; x < l.spatial(0); ++x) { + auto tensor_coord = cldnn::tensor{{b, f, x, y, z}, 0}; + auto buffer_idx = l.get_linear_offset(tensor_coord); + mem_ptr[buffer_idx] = data[data_idx++]; + } + } + } + } + } + }; + + auto& engine = get_test_engine(); + ov::Shape input0_shape; + ov::Shape input1_shape; + std::vector input1_target_shape; + std::vector input1_output_pattern; + std::vector input0_order; + std::vector input1_order; + ov::Shape beam_table_shape; + cldnn::layout input0_layout; + cldnn::layout input1_layout; + + input0_shape = { BATCH_SIZE, 32, M_SIZE, K_SIZE }; + input1_shape = { N_SIZE, BATCH_SIZE, 2, 1, K_SIZE }; + input1_target_shape = { 1, 1, 1, 16, 1 }; + input1_output_pattern = { 0, 0, 32, K_SIZE }; + input0_order = { 0, 1, 2, 3 }; + input1_order = { 1, 2, 3, 0 }; + + input0_layout = 
layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f32, format::bfyx}; + input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f32, format::bfzyx}; + + auto input0_mem = engine.allocate_memory(layout{ov::PartialShape(input0_shape), data_types::f32, format::bfyx}); + auto input1_mem = engine.allocate_memory(layout{ov::PartialShape(input1_shape), data_types::f32, format::bfzyx}); + + auto input_0_data = rg.generate_random_1d(ov::shape_size(input0_shape), -2, 2); + auto input_1_data = rg.generate_random_1d(ov::shape_size(input1_shape), -2, 2); + + fill_mem(input0_mem, input_0_data); + fill_mem(input1_mem, input_1_data); + + topology topology; + topology.add(input_layout("input0", input0_layout), + input_layout("input1", input1_layout), + gemm("gemm", { input_info("input0"), input_info("input1") }, data_types::f32, {}, input1_target_shape, {}, input1_output_pattern, input0_order, input1_order) + ); + + ExecutionConfig config = get_test_default_config(engine); + config.set_property(ov::intel_gpu::optimize_data(true)); + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + network::ptr network = get_network(engine, topology, config, get_test_stream_ptr(), is_caching_test); + network->set_input_data("input0", input0_mem); + network->set_input_data("input1", input1_mem); + + auto inst = network->get_primitive("gemm"); + auto impl = inst->get_impl(); + ASSERT_TRUE(impl != nullptr); + + auto outputs = network->execute(); + + auto output_mem = outputs.at("gemm").get_memory(); + cldnn::mem_lock output_ptr(output_mem, get_test_stream()); + + ov::Shape ref_input0_shape; + ov::Shape ref_input1_broadcasted_shape; + ov::Shape ref_input1_reshaped_shape; + ov::Shape ref_input1_shape; + ov::Shape ref_output_shape; + + ref_input0_shape = { BATCH_SIZE, 32, M_SIZE, K_SIZE }; + ref_input1_broadcasted_shape = { N_SIZE, BATCH_SIZE, 2, 16, K_SIZE }; + ref_input1_reshaped_shape = { N_SIZE, BATCH_SIZE, 32, K_SIZE }; + ref_input1_shape 
= { BATCH_SIZE, 32, K_SIZE, N_SIZE }; + ref_output_shape = { BATCH_SIZE, 32, M_SIZE, N_SIZE }; + + std::vector ref_out_data; + ref_out_data.resize(ov::shape_size(ref_output_shape)); + + std::vector ref_input_0_data(input_0_data.size()); + std::vector ref_input_1_broadcasted_data(ov::shape_size(ref_input1_broadcasted_shape)); + std::vector ref_input_1_reshaped_data(ov::shape_size(ref_input1_reshaped_shape)); + std::vector ref_input_1_data(ref_input_1_broadcasted_data.size()); + + ov::reference::transpose((const char *)(input_0_data.data()), + (char *)(ref_input_0_data.data()), + input0_shape, + sizeof(float), + input0_order, + ref_input0_shape); + + ov::reference::broadcast(reinterpret_cast(input_1_data.data()), + reinterpret_cast(ref_input_1_broadcasted_data.data()), + input1_shape, + ref_input1_broadcasted_shape, + ov::AxisSet({}), + sizeof(float)); + + std::vector axes_order(ov::shape_size(ref_input1_broadcasted_shape)); + std::iota(axes_order.begin(), axes_order.end(), 0); + + ov::reference::reshape(reinterpret_cast(ref_input_1_broadcasted_data.data()), + reinterpret_cast(ref_input_1_reshaped_data.data()), + ref_input1_broadcasted_shape, + axes_order, + ref_input1_reshaped_shape, + sizeof(float)); + + ov::reference::transpose((const char *)(ref_input_1_reshaped_data.data()), + (char *)(ref_input_1_data.data()), + ref_input1_reshaped_shape, + sizeof(float), + input1_order, + ref_input1_shape); + + ov::reference::matmul(ref_input_0_data.data(), + ref_input_1_data.data(), + ref_out_data.data(), + ref_input0_shape, + ref_input1_shape, + ref_output_shape, + false, + false); + + ASSERT_EQ(output_ptr.size(), ref_out_data.size()); + + const auto abs_error = 0.0001; + for (uint32_t i = 0; i < ref_out_data.size(); ++i) { + ASSERT_NEAR(output_ptr[i], ref_out_data[i], abs_error) << "at " << i; + } + } + void test_transpose_matmul(size_t num_dims, bool is_input_dynamic, bool is_caching_test) { tests::random_generator rg; rg.set_seed(GET_SUITE_NAME); @@ -914,7 +1194,7 @@ 
class gemm_gpu_tests: public ::testing::Test { topology topology; topology.add(input_layout("input0", input0_layout), input_layout("input1", input1_layout), - gemm("gemm", { input_info("input0"), input_info("input1") }, data_types::f32, input0_order, input1_order) + gemm("gemm", { input_info("input0"), input_info("input1") }, data_types::f32, {}, {}, {}, {}, input0_order, input1_order) ); ExecutionConfig config = get_test_default_config(engine); @@ -1072,7 +1352,7 @@ class gemm_gpu_tests: public ::testing::Test { topology topology; topology.add(input_layout("input0", input0_layout), input_layout("input1", input1_layout), - gemm("gemm", { input_info("input0"), input_info("input1") }, data_types::f16, input0_order, input1_order, output_order) + gemm("gemm", { input_info("input0"), input_info("input1") }, data_types::f16, {}, {}, {}, {}, input0_order, input1_order, output_order) ); ExecutionConfig config = get_test_default_config(engine); @@ -1225,6 +1505,14 @@ TEST_F(gemm_gpu_tests, transpose_matmul_in1_indirect) { this->test_transpose_indirect(false, false, true); } +TEST_F(gemm_gpu_tests, broadcast_transpose_matmul) { + this->test_broadcast_transpose_matmul(false); +} + +TEST_F(gemm_gpu_tests, broadcast_reshape_transpose_matmul) { + this->test_broadcast_reshape_transpose_matmul(false); +} + TEST_F(gemm_gpu_tests, transpose_matmul_transpose_dynamic_1d) { this->test_transpose_matmul_transpose(1, true, false); } diff --git a/src/plugins/intel_gpu/tests/unit/transformations/broadcast_reshape_matmul_fusion_test.cpp b/src/plugins/intel_gpu/tests/unit/transformations/broadcast_reshape_matmul_fusion_test.cpp new file mode 100644 index 00000000000000..8c25f415967dd5 --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/transformations/broadcast_reshape_matmul_fusion_test.cpp @@ -0,0 +1,99 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_test_utils/ov_test_utils.hpp" + +#include "openvino/core/model.hpp" +#include 
"openvino/pass/manager.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/reshape.hpp" +#include "intel_gpu/op/gemm.hpp" + +#include "plugin/transformations/broadcast_reshape_matmul_fusion.hpp" + +#include + +using namespace testing; +using namespace ov::intel_gpu; + +namespace ov { +namespace test { +namespace intel_gpu { + +TEST_F(TransformationTestsF, BroadReshapeMatmulFusion1) { + std::vector order_a = {0, 1, 2, 3}; + std::vector order_b = {1, 2, 3, 0}; + std::vector order_c = {0, 1, 2, 3}; + std::vector target_shape_b = {1, 1, 1, 16, 1}; + std::vector pattern_b = {0, 0, 32, 32}; + { + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1, 2, 1, 32}); + auto broadcast_b_const = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{5}, target_shape_b); + auto broadcast_b = std::make_shared(input_b, broadcast_b_const, ov::op::BroadcastType::BIDIRECTIONAL); + auto reshape_b_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, pattern_b); + auto reshape_b = std::make_shared(broadcast_b, reshape_b_const, true); + auto gemm = std::make_shared(input_a, reshape_b, order_a, order_b, order_c, ov::element::undefined); + + model = std::make_shared(ov::NodeVector{ gemm }, ov::ParameterVector{ input_a, input_b }); + manager.register_pass(); + } + { + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1, 2, 1, 32}); + auto gemm = std::make_shared(input_a, input_b, std::vector{}, target_shape_b, std::vector{}, pattern_b, order_a, order_b, order_c, ov::element::undefined); + + model_ref = std::make_shared(ov::NodeVector{ gemm }, ov::ParameterVector{ input_a, input_b }); + comparator.enable(FunctionsComparator::ATTRIBUTES); + } +} + +TEST_F(TransformationTestsF, BroadReshapeMatmulFusion2) { + std::vector 
order_a = {0, 1, 2, 3}; + std::vector order_b = {1, 2, 3, 0}; + std::vector order_c = {0, 1, 2, 3}; + std::vector target_shape_b = {1, 1, 1, 16, 1}; + std::vector pattern_b = {0, 0, -1, 32}; + { + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1, -1, 1, 32}); + auto broadcast_b_const = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{5}, target_shape_b); + auto broadcast_b = std::make_shared(input_b, broadcast_b_const, ov::op::BroadcastType::BIDIRECTIONAL); + auto reshape_b_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, pattern_b); + auto reshape_b = std::make_shared(broadcast_b, reshape_b_const, true); + auto gemm = std::make_shared(input_a, reshape_b, order_a, order_b, order_c, ov::element::undefined); + + model = std::make_shared(ov::NodeVector{ gemm }, ov::ParameterVector{ input_a, input_b }); + manager.register_pass(); + } + { + model_ref = model->clone(); + comparator.enable(FunctionsComparator::ATTRIBUTES); + } +} + +TEST_F(TransformationTestsF, BroadReshapeMatmulFusion3) { + std::vector order_a = {0, 1, 2, 3}; + std::vector order_b = {0, 1, 2, 3}; + std::vector order_c = {0, 1, 2, 3}; + std::vector target_shape_b = {1, 1, 16, 1, 1}; + std::vector pattern_b = {0, 32, 32, 0}; + { + auto input_a = std::make_shared(ov::element::f32, ov::PartialShape::dynamic(4)); + auto input_b = std::make_shared(ov::element::f32, ov::PartialShape{-1, 2, 1, 32, -1}); + auto broadcast_b_const = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{5}, target_shape_b); + auto broadcast_b = std::make_shared(input_b, broadcast_b_const, ov::op::BroadcastType::BIDIRECTIONAL); + auto reshape_b_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, pattern_b); + auto reshape_b = std::make_shared(broadcast_b, reshape_b_const, true); + auto gemm = std::make_shared(input_a, reshape_b, order_a, order_b, order_c, ov::element::undefined); 
+ + model = std::make_shared(ov::NodeVector{ gemm }, ov::ParameterVector{ input_a, input_b }); + manager.register_pass(); + } +} + +} // namespace intel_gpu +} // namespace test +} // namespace ov