Skip to content

Commit

Permalink
[GPU] MLP : 2fcs + swiglu fusion (openvinotoolkit#27831)
Browse files Browse the repository at this point in the history
### Details:
 - 2 FCs + swiglu in MLP pattern are fused
 - Only applied to cldnn && #EUs > 128 && glu type with swiglu 

### Tickets:
 - 152163
  • Loading branch information
yeonbok authored Dec 6, 2024
1 parent c3b014c commit bf62609
Show file tree
Hide file tree
Showing 22 changed files with 469 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class debug_configuration {
int use_kv_cache_compression; // Enable KV-cache compression
int dynamic_quantize_group_size; // Enable Dynamic quantization for fully connected primitive by specified group size
int disable_horizontal_fc_fusion; // Disable fc horizontal fusion
int disable_fc_swiglu_fusion; // Disable swiglu fusion to fc
std::set<int64_t> dump_iteration; // Dump n-th execution of network.
std::vector<std::string> load_layers_raw_dump; // List of layers to load dumped raw binary and filenames
static const debug_configuration *get_instance();
Expand Down
5 changes: 5 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/runtime/layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ struct data_type_traits {
return et.is_quantized() && et.bitwidth() == 8;
}

static bool is_i4_u4(data_types data_type) {
auto et = ov::element::Type(data_type);
return et.bitwidth() == 4;
}

static ov::element::Type max_type(ov::element::Type t1, ov::element::Type t2) {
if (t1.bitwidth() < t2.bitwidth())
return t2;
Expand Down
26 changes: 23 additions & 3 deletions src/plugins/intel_gpu/src/graph/fully_connected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
#include <string>
#include <algorithm>
#include "utils.hpp"
#include "swiglu_inst.h"

#include "matmul_shape_inference.hpp"
#include "glu_shape_inference.hpp"

namespace cldnn {
GPU_DEFINE_PRIMITIVE_TYPE_ID(fully_connected)
Expand Down Expand Up @@ -171,14 +173,32 @@ std::vector<layout> fully_connected_inst::calc_output_layouts(fully_connected_no
output_type = impl_param.get_output_element_type();
}

ov::op::v0::MatMul op;
op.set_transpose_b(true);
ov::op::v0::MatMul matmul_op;
matmul_op.set_transpose_b(true);
std::vector<ShapeType> input_shapes = {
input_layout.get<ShapeType>(),
weights_layout.get<ShapeType>()
};

std::vector<ShapeType> output_shapes = ov::op::v0::shape_infer(&op, input_shapes);
std::vector<ShapeType> output_shapes = ov::op::v0::shape_infer(&matmul_op, input_shapes);
bool has_swiglu = false;
auto& fused_prims = node.get_fused_primitives();
for (auto f : fused_prims) {
if (f.is_type<swiglu>()) {
has_swiglu = true;
OPENVINO_ASSERT(fused_prims.size() == 1, "Other operation is fused in addition to swiglu!");
}
}
if (has_swiglu) {
ov::op::internal::GLU swiglu_op;
OPENVINO_ASSERT(fused_prims.size() == 1);
OPENVINO_ASSERT(fused_prims[0].typed_desc<swiglu>()->glu_type == ov::op::internal::GLU::GluType::Swish);
swiglu_op.set_axis(fused_prims[0].typed_desc<swiglu>()->axis);
swiglu_op.set_split_lengths(fused_prims[0].typed_desc<swiglu>()->split_lengths);
swiglu_op.set_glu_type(fused_prims[0].typed_desc<swiglu>()->glu_type);
std::vector<ShapeType> input_shapes = { output_shapes[0] };
output_shapes = shape_infer(&swiglu_op, input_shapes);
}

bool is_static = input_layout.is_static() && weights_layout.is_static();
bool allow_new_shape_infer = impl_param.get_program().is_new_shape_infer();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "intel_gpu/runtime/debug_configuration.hpp"
#include "program_helpers.h"
#include "pass_manager.h"

Expand Down Expand Up @@ -37,6 +37,7 @@
#include "strided_slice_inst.h"
#include "cum_sum_inst.h"
#include "embedding_bag_inst.h"
#include "swiglu_inst.h"
#include "extract_image_patches_inst.h"
#include "reduce_inst.h"
#include "group_normalization_inst.h"
Expand All @@ -56,6 +57,7 @@ using namespace cldnn;
void prepare_primitive_fusing::run(program& p) {
fuse_reorders(p);
remove_redundant_reshape(p);
fuse_swiglu(p);
fuse_bias(p);
fuse_simple_primitives(p);
fuse_constant_transposes(p);
Expand Down Expand Up @@ -161,6 +163,46 @@ void prepare_primitive_fusing::fuse_reorders(program &p) {
}
}

void prepare_primitive_fusing::fuse_swiglu(program &p) {
GPU_DEBUG_GET_INSTANCE(debug_config);
bool disable_fc_swiglu_fusion = false;
GPU_DEBUG_IF(debug_config->disable_fc_swiglu_fusion == 1)
disable_fc_swiglu_fusion = true;
// Apply only for high performant GPU
if (disable_fc_swiglu_fusion || p.get_engine().get_device_info().execution_units_count < 128)
return;
// TODO: to support other glu types && other weight data types
auto itr = p.get_processing_order().begin();
std::map<primitive_id, std::vector<std::pair<primitive_id, size_t>>> fusing_history;
while (itr != p.get_processing_order().end()) {
auto node_itr = itr++;
auto& node = (*node_itr);
if (node->is_type<swiglu>()) {
if (!node->get_dependency(0).is_type<fully_connected>())
continue;
auto swiglu_prim = node->get_kernel_impl_params()->typed_desc<swiglu>();
auto& fc_node = node->get_dependency(0);
if (node->get_dependencies().size() > 1)
continue;
if (!node->get_dependency(0).get_fused_primitives().empty())
continue;
auto in_dt = fc_node.get_input_layout(0).data_type;
if (in_dt != data_types::f16)
continue;
auto wt_dt = fc_node.get_input_layout(1).data_type;
if (!data_type_traits::is_i4_u4(wt_dt))
continue;
if (swiglu_prim->glu_type != ov::op::internal::GLU::GluType::Swish ||
!(swiglu_prim->axis == -1 || swiglu_prim->axis == static_cast<int64_t>(node->get_output_layout(0).get_partial_shape().size()) - 1))
continue;
GPU_DEBUG_TRACE_DETAIL << node->id() << " : fuse swiglu to " << fc_node.id() << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - split axis : " << swiglu_prim->axis << std::endl;
GPU_DEBUG_TRACE_DETAIL << " - split length : " << swiglu_prim->split_lengths << std::endl;
p.fuse_nodes(fc_node, *node, &fusing_history);
}
}
}

void prepare_primitive_fusing::fuse_bias(program &p) {
auto itr = p.get_processing_order().begin();
while (itr != p.get_processing_order().end()) {
Expand Down Expand Up @@ -188,6 +230,17 @@ void prepare_primitive_fusing::fuse_bias(program &p) {
if (!is_bias_add)
continue;

for (auto& dep : eltw_node.get_dependencies()) {
auto& fused_prims = dep.first->get_fused_primitives();
if (std::any_of(fused_prims.begin(), fused_prims.end(), [](const fused_primitive_desc& f_desc) {
return f_desc.is_type<swiglu>();
})) {
GPU_DEBUG_TRACE_DETAIL << "Skip fusing " << eltw_node.id() << " to " << dep.first->id() << " because "
<< dep.first->id() << " has fused swiglu." << std::endl;
continue;
}
}

auto is_3d_fully_connected = [](program_node& node) {
if (!node.is_type<fully_connected>())
return false;
Expand Down Expand Up @@ -491,6 +544,13 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
};

auto fc_supports_fusings = [&](fully_connected_node& node) -> bool {
auto& fused_prims = node.get_fused_primitives();
if (std::any_of(fused_prims.begin(), fused_prims.end(), [](const fused_primitive_desc& f_desc) {
return f_desc.is_type<swiglu>();
})) {
GPU_DEBUG_TRACE_DETAIL << node.id() << " has fused swiglu. Skip fusing more primitives" << std::endl;
return false;
}
if (lo.has_all_enabled_onednn_impls_optimization_attribute() &&
lo.get_preferred_impl_type(node, format::any /*dummy*/) == impl_types::onednn) {
return true;
Expand Down
14 changes: 11 additions & 3 deletions src/plugins/intel_gpu/src/graph/impls/ocl/fully_connected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,16 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {
return layouts;
};

auto get_fc_output_layout = [primitive](const std::vector<layout>& input_layouts, const layout& output_layout) {
auto get_fc_output_layout = [primitive](const std::vector<layout>& input_layouts, const layout& output_layout, bool swiglu_fused) {
auto updated_out_layout = output_layout;

auto input0_pshape = input_layouts[0].get_partial_shape();
auto input1_pshape = input_layouts[1].get_partial_shape();
ov::PartialShape updated_out_pshape {input0_pshape[0], input1_pshape[0]};
const auto output_feature_size = swiglu_fused ? input1_pshape[0] / 2 : input1_pshape[0];

if (primitive->input_size == 3) {
updated_out_pshape = { input0_pshape[0], input0_pshape[1], input1_pshape[0] };
updated_out_pshape = { input0_pshape[0], input0_pshape[1], output_feature_size};
}
updated_out_layout.set_partial_shape(updated_out_pshape);

Expand All @@ -149,14 +150,21 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {

bool allow_new_shape_infer = impl_param.get_program().is_new_shape_infer();
auto updated_impl_param = impl_param;
bool swiglu_fused = false;
if (updated_impl_param.fused_desc.size() > 0) {
for (const auto& f : updated_impl_param.fused_desc) {
if (f.is_type<swiglu>())
swiglu_fused = true;
}
}

const auto input_layouts = get_fc_input_layouts(impl_param.input_layouts, allow_new_shape_infer);
for (size_t i = 0; i < input_layouts.size(); ++i) {
updated_impl_param.input_layouts[i] = input_layouts[i];
}
updated_impl_param.weights_layout = input_layouts[1];

updated_impl_param.output_layouts[0] = get_fc_output_layout(input_layouts, impl_param.get_output_layout());
updated_impl_param.output_layouts[0] = get_fc_output_layout(input_layouts, impl_param.get_output_layout(), swiglu_fused);

return updated_impl_param;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
#include "intel_gpu/primitives/embedding_bag.hpp"
#include "intel_gpu/primitives/extract_image_patches.hpp"

#include "swiglu_inst.h"
#include "activation_inst.h"
#include "eltwise_inst.h"
#include "quantize_inst.h"
#include "reorder_inst.h"

#include "kernel_selector/kernels/swiglu/swiglu_kernel_base.h"
#include "kernel_selector/kernels/activation/activation_kernel_base.h"
#include "kernel_selector/kernels/depth_to_space/depth_to_space_kernel_base.h"
#include "kernel_selector/kernels/eltwise/eltwise_kernel_base.h"
Expand Down Expand Up @@ -1009,7 +1011,13 @@ kernel_selector::activation_function get_kernel_selector_activation_param(activa
}

std::shared_ptr<kernel_selector::fuse_params> convert_fuse_params(std::shared_ptr<NodeFuseParams> p) {
if (p->type() == activation::type_id()) {
if (p->type() == swiglu::type_id()) {
auto casted = std::dynamic_pointer_cast<SwigluFuseParams>(p);
auto axis = casted->_desc->axis;
auto split_length = casted->_desc->split_lengths;
auto split_to_glu_idx = casted->_desc->split_to_glu_idx;
return std::make_shared<kernel_selector::swiglu_fuse_params>(axis, split_length, split_to_glu_idx);
} else if (p->type() == activation::type_id()) {
auto casted = std::dynamic_pointer_cast<ActivationFuseParams>(p);
auto desc = casted->_desc;
kernel_selector::base_activation_params p;
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/include/pass_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class prepare_primitive_fusing : public base_pass {
private:
void run(program& p) override;
void fuse_bias(program &p);
void fuse_swiglu(program &p);
void fuse_reorders(program& p);
void fuse_simple_primitives(program &p);
void fuse_constant_transposes(program &p);
Expand Down
9 changes: 9 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/swiglu_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@

namespace cldnn {

class SwigluFuseParams : public NodeFuseParams {
public:
SwigluFuseParams(std::shared_ptr<swiglu> desc) : NodeFuseParams(swiglu::type_id()), _desc(std::move(desc)) {}
std::shared_ptr<swiglu> _desc;
};
template <>
struct typed_program_node<swiglu> : public typed_program_node_base<swiglu> {
using parent = typed_program_node_base<swiglu>;
Expand All @@ -19,6 +24,10 @@ struct typed_program_node<swiglu> : public typed_program_node_base<swiglu> {

program_node& input(size_t index = 0) const { return get_dependency(index); }
std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }

std::shared_ptr<NodeFuseParams> get_fuse_params() const override {
return std::make_shared<SwigluFuseParams>(typed_desc());
}
};

using swiglu_node = typed_program_node<swiglu>;
Expand Down
11 changes: 11 additions & 0 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "gather_inst.h"
#include "broadcast_inst.h"
#include "dynamic_quantize_inst.h"
#include "swiglu_inst.h"
#include "experimental_detectron_roi_feature_extractor_inst.hpp"
#include "impls/registry/implementation_manager.hpp"
#include "impls/registry/registry.hpp"
Expand Down Expand Up @@ -2606,6 +2607,16 @@ bool primitive_inst::is_valid_fusion() const {
} else {
if (fd.is_type<reorder>() || fd.is_type<quantize>())
continue;
if (fd.is_type<swiglu>()) {
OPENVINO_ASSERT(_node->is_type<fully_connected>() && _node->get_preferred_impl_type() == impl_types::ocl);
if (!_node->get_selected_impl())
return false;
// TODO : support ref kernel too
if (_node->get_selected_impl()->get_kernel_name().find("fully_connected_gpu_bf_tiled") != std::string::npos)
return true;
else
return false;
}

OPENVINO_THROW("[GPU] Unsupported fused operation in dynamic shape: type=", fd.desc->type_string(), ", id=", fd.desc->id);
}
Expand Down
22 changes: 22 additions & 0 deletions src/plugins/intel_gpu/src/graph/program_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "activation_inst.h"
#include "reorder_inst.h"
#include "quantize_inst.h"
#include "swiglu_inst.h"
#include "intel_gpu/runtime/debug_configuration.hpp"
#ifdef ENABLE_ONEDNN_FOR_GPU
#include "convolution_inst.h"
Expand Down Expand Up @@ -770,6 +771,15 @@ void program_node::save(cldnn::BinaryOutputBuffer& ob) const {
ob << casted->_out_hi;
ob << casted->_out_scale;
ob << casted->_out_shift;
} else if (f_desc.f_param->type() == swiglu::type_id()) {
auto casted = std::dynamic_pointer_cast<SwigluFuseParams>(f_desc.f_param);
if (get_program().has_node(casted->_desc->id)) {
ob << true;
ob << casted->_desc->id;
} else {
ob << false;
ob << casted->_desc;
}
}

ob << f_desc.deps.size();
Expand Down Expand Up @@ -975,6 +985,18 @@ void program_node::load(cldnn::BinaryInputBuffer& ib) {
need_pre_shift, need_clamp, need_min_clamp, need_max_clamp, per_tensor_input_range,
per_tensor_input_scale, per_tensor_input_shift, per_tensor_output_range, per_tensor_output_scale,
per_tensor_output_shift, in_lo, in_hi, in_scale, in_shift, out_lo, out_hi, out_scale, out_shift);
} else if (f_param_type == swiglu::type_id()) {
ib >> exist_prim;
std::shared_ptr<swiglu> param_desc;
if (exist_prim) {
primitive_id desc_id;
ib >> desc_id;
param_desc = std::dynamic_pointer_cast<swiglu>(get_program().get_node_ptr(desc_id)->desc);
} else {
ib >> param_desc;
}
f_desc.f_param = std::make_shared<SwigluFuseParams>(param_desc);

} else {
f_desc.f_param = std::make_shared<NodeFuseParams>(f_param_type);
}
Expand Down
Loading

0 comments on commit bf62609

Please sign in to comment.