[GPU] Enable 8bit compression support on dGPU via oneDNN
sshlyapn committed Feb 8, 2024
1 parent 16e6aba commit 06d6aac
Showing 7 changed files with 281 additions and 22 deletions.
@@ -42,7 +42,17 @@ void compile_graph::run(program& p) {
auto& node = *(std::next(proc_order.begin(), idx));
const bool use_shape_agnostic_impl = !p.get_config().get_property(ov::intel_gpu::use_only_static_kernels_for_dynamic_shape);
const impl_types original_impl_type = node->get_preferred_impl_type();
const bool change_initial_impl = node->is_dynamic() && original_impl_type == impl_types::onednn;
bool change_initial_impl = node->is_dynamic() && original_impl_type == impl_types::onednn;

if (node->is_type<fully_connected>() && change_initial_impl) {
const auto fc_prim = node->as<fully_connected>().get_primitive();
const auto weights_dt = node->get_input_layout(1).data_type;

// Do not change impl (i.e. do not use ocl shape-agnostic kernels) for FC with 8-bit compressed weights,
// since the oneDNN primitive/kernel caching mechanism will be used instead.
if (fc_prim->compressed_weights && ov::element::Type(weights_dt).bitwidth() == 8)
change_initial_impl = false;
}

if (change_initial_impl)
node->set_preferred_impl_type(impl_types::ocl);
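For illustration, a minimal sketch of the per-shape primitive reuse the comment above alludes to. The cache structure and names below are hypothetical, not plugin code; in practice the implementation relies on oneDNN's own primitive cache.

```cpp
#include <string>
#include <unordered_map>
#include <oneapi/dnnl/dnnl.hpp>

// Purely illustrative: for a dynamic FC with 8-bit compressed weights, a oneDNN primitive is
// (re)built for each concrete shape and reused afterwards, so the ocl shape-agnostic fallback
// is not needed. This struct is hypothetical, not plugin code.
struct per_shape_matmul_cache {
    std::unordered_map<std::string, dnnl::matmul> cache;

    dnnl::matmul& get(const std::string& shape_key, const dnnl::matmul::primitive_desc& pd) {
        auto it = cache.find(shape_key);
        if (it == cache.end()) {
            // Even on a miss here, oneDNN's internal primitive cache makes re-creating an
            // identical primitive cheap, which is what the comment above relies on.
            it = cache.emplace(shape_key, dnnl::matmul(pd)).first;
        }
        return it->second;
    }
};
```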
122 changes: 109 additions & 13 deletions src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.cpp
@@ -50,6 +50,26 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
args.insert({DNNL_ARG_BIAS, bias->get_onednn_memory(_pd.weights_desc(1), offset)});
}

const auto& prim = instance.get_impl_params()->typed_desc<fully_connected>();
if (prim->compressed_weights) {
const auto weights_dt = instance.get_input_layout(1).data_type;
OPENVINO_ASSERT(ov::element::Type(weights_dt).bitwidth() == 8, "[GPU] oneDNN supports only 8bit compressed weights");

if (!prim->decompression_scale.empty()) {
auto decompression_scale_idx = prim->bias.empty() ? 2 : 3;
auto scale_mem = instance.dep_memory_ptr(decompression_scale_idx);
dnnl::memory::desc desc = onednn::layout_to_memory_desc(scale_mem->get_layout(), dnnl::memory::format_tag::a, true);
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scale_mem->get_onednn_memory(desc)});
}

if (!prim->decompression_zero_point.empty()) {
auto decompression_zp_idx = prim->bias.empty() ? 3 : 4;
auto zp_mem = instance.dep_memory_ptr(decompression_zp_idx);
dnnl::memory::desc desc = onednn::layout_to_memory_desc(zp_mem->get_layout(), dnnl::memory::format_tag::a, true);
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp_mem->get_onednn_memory(desc)});
}
}

return args;
}
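A side note on the hard-coded indices above: the 2/3 and 3/4 values reflect the assumed dependency order of a compressed fully_connected node — 0 = input, 1 = weights, 2 = bias (if present), then the decompression scale, then the decompression zero point. A hypothetical helper capturing that convention (not part of the patch):

```cpp
// Hypothetical helpers mirroring the ?: expressions above; they assume the dependency order
// input(0), weights(1), optional bias(2), decompression scale, decompression zero point.
inline int decompression_scale_idx(bool has_bias) { return has_bias ? 3 : 2; }
inline int decompression_zp_idx(bool has_bias)    { return has_bias ? 4 : 3; }
```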

@@ -76,13 +96,7 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
return std::make_shared<WeightsReorderParams>(weights_layout, output_weights_layout, false);
}

static std::shared_ptr<dnnl::inner_product_forward::primitive_desc> get_fully_connected_primitive_descriptor(const kernel_impl_params& impl_params,
cldnn::engine& engine, size_t prim_input_size, bool has_bias,
const dnnl::primitive_attr& attr = dnnl::primitive_attr()) {
auto input_layout = impl_params.get_input_layout(0);
auto weights_layout = impl_params.get_input_layout(1);
auto output_layout = impl_params.get_output_layout();

static void transform_layouts(layout& input_layout, layout& weights_layout, layout& output_layout, size_t prim_input_size) {
auto input_pshape = input_layout.get_partial_shape();
auto weights_pshape = weights_layout.get_partial_shape();

@@ -93,7 +107,7 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
}

if (input_size > 3) {
input_layout.set_partial_shape(reshape_to_2d(input_pshape, feature));
input_layout.set_partial_shape(reshape_to_2d(input_pshape, feature));
}
if (weights_pshape.size() != 2) {
weights_layout.set_partial_shape(reshape_to_2d(weights_pshape, feature));
@@ -108,6 +122,19 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
combine_bf_with_first_spatial_dim(input_layout);
combine_bf_with_first_spatial_dim(output_layout);
}
}

static std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
get_inner_product_primitive_descriptor(const kernel_impl_params& impl_params,
cldnn::engine& engine,
size_t prim_input_size,
bool has_bias,
const dnnl::primitive_attr& attr = dnnl::primitive_attr()) {
auto input_layout = impl_params.get_input_layout(0);
auto weights_layout = impl_params.get_input_layout(1);
auto output_layout = impl_params.get_output_layout();

transform_layouts(input_layout, weights_layout, output_layout, prim_input_size);

auto input_md = onednn::layout_to_memory_desc(input_layout, dnnl::memory::format_tag::undef, false);
auto weights_md = onednn::layout_to_memory_desc(weights_layout, dnnl::memory::format_tag::any);
@@ -134,6 +161,41 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
}
}

static std::shared_ptr<dnnl::matmul::primitive_desc>
get_matmul_primitive_descriptor(const kernel_impl_params& impl_params,
cldnn::engine& engine,
size_t prim_input_size,
bool has_bias,
const dnnl::primitive_attr& attr = dnnl::primitive_attr()) {
auto input_layout = impl_params.get_input_layout(0);
auto weights_layout = impl_params.get_input_layout(1);
auto output_layout = impl_params.get_output_layout();

transform_layouts(input_layout, weights_layout, output_layout, prim_input_size);

auto input_md = onednn::layout_to_memory_desc(input_layout, dnnl::memory::format_tag::ab, false);
auto weights_md = onednn::layout_to_memory_desc(weights_layout, dnnl::memory::format_tag::ba);
auto output_md = onednn::layout_to_memory_desc(output_layout, dnnl::memory::format_tag::ab, false);

if (has_bias) {
auto bias_md = onednn::layout_to_memory_desc(impl_params.get_input_layout(2), dnnl::memory::format_tag::ab, false);
return std::make_shared<dnnl::matmul::primitive_desc>(
engine.get_onednn_engine(),
input_md,
weights_md,
bias_md,
output_md,
attr);
} else {
return std::make_shared<dnnl::matmul::primitive_desc>(
engine.get_onednn_engine(),
input_md,
weights_md,
output_md,
attr);
}
}

public:
void save(BinaryOutputBuffer& ob) const override {
#ifdef ONEDNN_PRIMITIVE_SERIALIZATION
@@ -143,8 +205,10 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
auto prim = impl_params->typed_desc<fully_connected>();
size_t input_size = prim->input_size;
bool has_bias = !prim->bias.empty();
bool is_compressed = prim->compressed_weights;
ob << input_size;
ob << has_bias;
ob << is_compressed;

std::vector<uint8_t> prim_cache;
prim_cache = _prim.get_cache_blob();
@@ -158,12 +222,19 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {

size_t input_size = 2;
bool has_bias = false;
bool is_compressed = false;
ib >> input_size;
ib >> has_bias;
ib >> is_compressed;

const kernel_impl_params* impl_params = reinterpret_cast<kernel_impl_params*>(ib.getKernelImplParams());
auto prim_desc = get_fully_connected_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
_pd = *prim_desc;
if (is_compressed) {
auto prim_desc = get_matmul_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
_pd = *prim_desc;
} else {
auto prim_desc = get_inner_product_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
_pd = *prim_desc;
}

std::vector<uint8_t> prim_cache;
ib >> prim_cache;
@@ -179,10 +250,35 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
auto& config = impl_params.prog->get_config();
auto attr = arg.get_onednn_primitive_attributes();
auto prim = impl_params.typed_desc<fully_connected>();
auto prim_desc = get_fully_connected_primitive_descriptor(impl_params, impl_params.prog->get_engine(),
prim->input_size, !prim->bias.empty(), *attr);

return cldnn::make_unique<fully_connected_onednn>(engine, config, attr, *prim_desc, get_weights_reorder(impl_params, *prim_desc));
// There may be a performance difference between the InnerProduct and MatMul primitives in oneDNN,
// so use MatMul only for compressed weights and InnerProduct for all other cases.
if (prim->compressed_weights) {
attr->set_fpmath_mode(dnnl::fpmath_mode::f16, true);
if (!prim->decompression_scale.empty()) {
auto decompression_scale_idx = !arg.bias_term() ? 2 : 3;
auto data_type = convert_data_type(arg.get_dependency(decompression_scale_idx).get_output_layout().data_type);
attr->set_scales(DNNL_ARG_WEIGHTS, 1 << 1, dnnl::memory::dims{}, data_type);
}

if (prim->decompression_zero_point_scalar.has_value()) {
OPENVINO_ASSERT(!prim->decompression_zero_point_scalar.has_value(), "[GPU] OneDNN can't use scalar as a zero point value, need to broadcast it\n");
} else if (!prim->decompression_zero_point.empty()) {
auto decompression_zp_idx = !arg.bias_term() ? 3 : 4;
auto data_type = convert_data_type(arg.get_dependency(decompression_zp_idx).get_output_layout().data_type);
attr->set_zero_points(DNNL_ARG_WEIGHTS, 1 << 1, dnnl::memory::dims{}, data_type);
}

auto prim_desc = get_matmul_primitive_descriptor(impl_params, impl_params.prog->get_engine(),
prim->input_size, !prim->bias.empty(), *attr);

return cldnn::make_unique<fully_connected_onednn>(engine, config, attr, *prim_desc);
} else {
auto prim_desc = get_inner_product_primitive_descriptor(impl_params, impl_params.prog->get_engine(),
prim->input_size, !prim->bias.empty(), *attr);

return cldnn::make_unique<fully_connected_onednn>(engine, config, attr, *prim_desc, get_weights_reorder(impl_params, *prim_desc));
}
}
};
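To make the creation-time attributes and the execution-time argument keys easier to connect, here is a minimal, self-contained oneDNN sketch of the same configuration: f16 activations, u8 compressed weights consumed in place, and per-output-channel f16 decompression scales. It is not plugin code; the set_fpmath_mode/set_scales overloads are the ones used in this patch and assume a oneDNN build that provides them, and the sizes and engine kind are arbitrary.

```cpp
#include <unordered_map>
#include <oneapi/dnnl/dnnl.hpp>

int main() {
    using namespace dnnl;
    const memory::dim M = 4, K = 64, N = 128;       // arbitrary illustration sizes

    engine eng(engine::kind::gpu, 0);               // use engine::kind::cpu if no GPU is available
    stream strm(eng);

    // matmul wants weights with logical shape {K, N}; format_tag::ba keeps the buffer physically
    // N-major, i.e. the original [N, K] FC weights are consumed without a reorder.
    auto src_md = memory::desc({M, K}, memory::data_type::f16, memory::format_tag::ab);
    auto wei_md = memory::desc({K, N}, memory::data_type::u8,  memory::format_tag::ba);
    auto dst_md = memory::desc({M, N}, memory::data_type::f16, memory::format_tag::ab);
    auto scl_md = memory::desc({N},    memory::data_type::f16, memory::format_tag::a);

    primitive_attr attr;
    attr.set_fpmath_mode(fpmath_mode::f16, /*apply_to_int=*/true);   // decompress u8 weights for f16 math
    // mask 1 << 1: one scale per element of weights dim 1 (N), i.e. per output channel
    attr.set_scales(DNNL_ARG_WEIGHTS, 1 << 1, memory::dims{}, memory::data_type::f16);
    // zero points would be registered analogously via attr.set_zero_points(DNNL_ARG_WEIGHTS, ...)

    auto pd = matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr);
    matmul prim(pd);

    memory src(src_md, eng), wei(wei_md, eng), dst(dst_md, eng), scales(scl_md, eng);
    // filling src/wei/scales with real data is omitted here

    std::unordered_map<int, memory> args{
        {DNNL_ARG_SRC, src},
        {DNNL_ARG_WEIGHTS, wei},
        {DNNL_ARG_DST, dst},
        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scales},           // same key as in get_arguments()
    };
    prim.execute(strm, args);
    strm.wait();
    return 0;
}
```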

3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/onednn/utils.cpp
@@ -211,6 +211,9 @@ dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_t
} else if (target_fmt == dnnl::memory::format_tag::ab) {
dims.push_back(l.batch());
dims.push_back(l.get_tensor().count() / l.batch());
} else if (target_fmt == dnnl::memory::format_tag::ba) {
dims.push_back(l.feature());
dims.push_back(l.get_tensor().count() / l.feature());
} else if (flatten) {
dims = flatten_tensor(l.get_tensor());
} else {
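For reference, a small sketch of what the new format_tag::ba branch yields, assuming the usual cldnn FC weights layout where batch() is the output-feature count N and feature() is the input-feature count K: the ab view produces dims {N, K} (the OC×IC view inner_product uses), while the ba view produces dims {K, N} over the very same buffer, which is the logical shape dnnl::matmul expects — so no weight reorder is needed.

```cpp
#include <oneapi/dnnl/dnnl.hpp>

// Sketch only: describe an untouched, physically [N, K] (output-feature-major) FC weights buffer
// as the {K, N} weights that dnnl::matmul expects. format_tag::ba means dim 'b' (N) is outermost,
// so the memory layout matches the original buffer exactly.
dnnl::memory::desc fc_weights_as_matmul_view(dnnl::memory::dim K, dnnl::memory::dim N) {
    return dnnl::memory::desc({K, N}, dnnl::memory::data_type::u8, dnnl::memory::format_tag::ba);
}
```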
35 changes: 28 additions & 7 deletions src/plugins/intel_gpu/src/graph/layout_optimizer.cpp
@@ -894,12 +894,25 @@ static bool is_node_for_onednn(deconvolution_node const& node) {


static bool is_node_for_onednn(fully_connected_node const& node) {
if (!layout_optimizer::are_data_types_suitable_for_onednn((program_node&)node))
return false;

auto fc_prim = node.get_primitive();
// onednn impl doesn't support compressed weights for now
if (fc_prim->compressed_weights)

if (fc_prim->compressed_weights) {
auto weights_dt = node.weights().get_output_layout().data_type;
if (ov::element::Type(weights_dt).bitwidth() != 8)
return false;

if (fc_prim->decompression_zero_point_scalar.has_value())
return false;

if (!fc_prim->decompression_zero_point.empty()) {
auto decompression_zp_idx = fc_prim->bias.empty() ? 3 : 4;
auto decompression_zp_dt = node.get_input_layout(decompression_zp_idx).data_type;
if (weights_dt != decompression_zp_dt)
return false;
}
}

if (!layout_optimizer::are_data_types_suitable_for_onednn((program_node&)node))
return false;

auto output_layout = node.get_output_layout();
@@ -1331,8 +1344,16 @@ bool layout_optimizer::are_data_types_suitable_for_onednn(program_node& node) {
return onednn_check_data_types_for_deconvolution(in_dt, wei_dt, out_dt);
} else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
bool is_fc = node.is_type<fully_connected>();
auto wei_dt = is_fc ? node.as<fully_connected>().weights().get_output_layout(false).data_type :
node.as<gemm>().get_input_layout(1).data_type;
data_types wei_dt;
if (is_fc) {
const auto& fc_node = node.as<fully_connected>();
const auto fc_prim = fc_node.get_primitive();
wei_dt = fc_node.weights().get_output_layout(false).data_type;
if (fc_prim->compressed_weights && ov::element::Type(wei_dt).bitwidth() == 8)
return true;
} else {
wei_dt = node.as<gemm>().get_input_layout(1).data_type;
}
return onednn_check_data_types_for_fc_gemm(in_dt, wei_dt, out_dt);
} else if (node.is_type<reorder>()) {
auto input_fmt = node.get_input_layout(0).format;
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -1632,6 +1632,8 @@ event::ptr primitive_inst::update_weights() {
// incorrect memory buffer may be assigned, so reset cached weights for such case
_reordered_weights_cache.add(original_layout, original_weights_memory);
_impl_params->weights_layout = optional_layout(original_layout);
GPU_DEBUG_TRACE_DETAIL << id() << ": add original weights memory " << original_layout.to_short_string() << " to weights cache; "
<< "cache_size=" << _reordered_weights_cache.size() << "/" << _reordered_weights_cache.capacity() << std::endl;
} else {
auto expected_layout = reorder_kernel_params->get_output_layout();
// Set original partial shape, because it may be lost during kernel_selector::weights_tensor -> layout conversion