
Commit c332243
[GPU] Enable 8bit compression support on dGPU via oneDNN (openvinotoolkit#22740)

### Details:
 - Enable 8bit compression support on dGPU via oneDNN
 - Update oneDNN version
 - Enable oneDNN primitives cache

Ticket: 124115
sshlyapn authored and alvoron committed Apr 29, 2024
1 parent 1ff5508 commit c332243
Showing 11 changed files with 325 additions and 26 deletions.
@@ -42,7 +42,17 @@ void compile_graph::run(program& p) {
auto& node = *(std::next(proc_order.begin(), idx));
const bool use_shape_agnostic_impl = !p.get_config().get_property(ov::intel_gpu::use_only_static_kernels_for_dynamic_shape);
const impl_types original_impl_type = node->get_preferred_impl_type();
const bool change_initial_impl = node->is_dynamic() && original_impl_type == impl_types::onednn;
bool change_initial_impl = node->is_dynamic() && original_impl_type == impl_types::onednn;

if (node->is_type<fully_connected>() && change_initial_impl) {
const auto fc_prim = node->as<fully_connected>().get_primitive();
const auto weights_dt = node->get_input_layout(1).data_type;

// Do not change impl (i.e. do not use ocl shape-agnostic kernels) in case of FC and 8bit compressed weights,
// since oneDNN primitives/kernels caching mechanism will be used instead.
if (fc_prim->compressed_weights && ov::element::Type(weights_dt).bitwidth() == 8)
change_initial_impl = false;
}

if (change_initial_impl)
node->set_preferred_impl_type(impl_types::ocl);
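
For context, what a fully connected layer with 8-bit compressed weights has to compute is a per-output-channel decompression of the weights followed by the usual matrix product. The sketch below is a plain reference implementation, not part of this change; all names and shapes are illustrative. oneDNN performs the equivalent decompression on the fly inside its matmul primitive once the scale/zero-point attributes shown later in this diff are attached.

```cpp
// Reference decompression + FC for 8-bit compressed weights (illustrative only).
// w = (w_u8 - zero_point[oc]) * scale[oc], then output[oc] = dot(w_row, input).
#include <cstdint>
#include <vector>

std::vector<float> fc_compressed_ref(const std::vector<float>& input,        // [K]
                                     const std::vector<uint8_t>& weights_u8, // [N x K], row-major
                                     const std::vector<float>& scale,        // [N], per output channel
                                     const std::vector<float>& zero_point,   // [N], per output channel
                                     size_t N, size_t K) {
    std::vector<float> output(N, 0.0f);
    for (size_t n = 0; n < N; ++n) {
        for (size_t k = 0; k < K; ++k) {
            // Decompress a single 8-bit weight before using it.
            const float w = (static_cast<float>(weights_u8[n * K + k]) - zero_point[n]) * scale[n];
            output[n] += w * input[k];
        }
    }
    return output;
}
```
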
122 changes: 109 additions & 13 deletions src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.cpp
@@ -50,6 +50,26 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
args.insert({DNNL_ARG_BIAS, bias->get_onednn_memory(_pd.weights_desc(1), offset)});
}

const auto& prim = instance.get_impl_params()->typed_desc<fully_connected>();
if (prim->compressed_weights) {
const auto weights_dt = instance.get_input_layout(1).data_type;
OPENVINO_ASSERT(ov::element::Type(weights_dt).bitwidth() == 8, "[GPU] oneDNN supports only 8bit compressed weights");

if (!prim->decompression_scale.empty()) {
auto decompression_scale_idx = prim->bias.empty() ? 2 : 3;
auto scale_mem = instance.dep_memory_ptr(decompression_scale_idx);
dnnl::memory::desc desc = onednn::layout_to_memory_desc(scale_mem->get_layout(), dnnl::memory::format_tag::a, true);
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scale_mem->get_onednn_memory(desc)});
}

if (!prim->decompression_zero_point.empty()) {
auto decompression_zp_idx = prim->bias.empty() ? 3 : 4;
auto zp_mem = instance.dep_memory_ptr(decompression_zp_idx);
dnnl::memory::desc desc = onednn::layout_to_memory_desc(zp_mem->get_layout(), dnnl::memory::format_tag::a, true);
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp_mem->get_onednn_memory(desc)});
}
}

return args;
}

@@ -91,13 +111,7 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
false);
}

static std::shared_ptr<dnnl::inner_product_forward::primitive_desc> get_fully_connected_primitive_descriptor(const kernel_impl_params& impl_params,
cldnn::engine& engine, size_t prim_input_size, bool has_bias,
const dnnl::primitive_attr& attr = dnnl::primitive_attr()) {
auto input_layout = impl_params.get_input_layout(0);
auto weights_layout = impl_params.get_input_layout(1);
auto output_layout = impl_params.get_output_layout();

static void transform_layouts(layout& input_layout, layout& weights_layout, layout& output_layout, size_t prim_input_size) {
auto input_pshape = input_layout.get_partial_shape();
auto weights_pshape = weights_layout.get_partial_shape();

@@ -108,7 +122,7 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
}

if (input_size > 3) {
input_layout.set_partial_shape(reshape_to_2d(input_pshape, feature));
}
if (weights_pshape.size() != 2) {
weights_layout.set_partial_shape(reshape_to_2d(weights_pshape, feature));
@@ -123,6 +137,19 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
combine_bf_with_first_spatial_dim(input_layout);
combine_bf_with_first_spatial_dim(output_layout);
}
}

static std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
get_inner_product_primitive_descriptor(const kernel_impl_params& impl_params,
cldnn::engine& engine,
size_t prim_input_size,
bool has_bias,
const dnnl::primitive_attr& attr = dnnl::primitive_attr()) {
auto input_layout = impl_params.get_input_layout(0);
auto weights_layout = impl_params.get_input_layout(1);
auto output_layout = impl_params.get_output_layout();

transform_layouts(input_layout, weights_layout, output_layout, prim_input_size);

auto input_md = onednn::layout_to_memory_desc(input_layout, dnnl::memory::format_tag::undef, false);
auto weights_md = onednn::layout_to_memory_desc(weights_layout, dnnl::memory::format_tag::any);
@@ -149,6 +176,41 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
}
}

static std::shared_ptr<dnnl::matmul::primitive_desc>
get_matmul_primitive_descriptor(const kernel_impl_params& impl_params,
cldnn::engine& engine,
size_t prim_input_size,
bool has_bias,
const dnnl::primitive_attr& attr = dnnl::primitive_attr()) {
auto input_layout = impl_params.get_input_layout(0);
auto weights_layout = impl_params.get_input_layout(1);
auto output_layout = impl_params.get_output_layout();

transform_layouts(input_layout, weights_layout, output_layout, prim_input_size);

auto input_md = onednn::layout_to_memory_desc(input_layout, dnnl::memory::format_tag::ab, false);
auto weights_md = onednn::layout_to_memory_desc(weights_layout, dnnl::memory::format_tag::ba);
auto output_md = onednn::layout_to_memory_desc(output_layout, dnnl::memory::format_tag::ab, false);

if (has_bias) {
auto bias_md = onednn::layout_to_memory_desc(impl_params.get_input_layout(2), dnnl::memory::format_tag::ab, false);
return std::make_shared<dnnl::matmul::primitive_desc>(
engine.get_onednn_engine(),
input_md,
weights_md,
bias_md,
output_md,
attr);
} else {
return std::make_shared<dnnl::matmul::primitive_desc>(
engine.get_onednn_engine(),
input_md,
weights_md,
output_md,
attr);
}
}

public:
void save(BinaryOutputBuffer& ob) const override {
#ifdef ONEDNN_PRIMITIVE_SERIALIZATION
@@ -158,8 +220,10 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
auto prim = impl_params->typed_desc<fully_connected>();
size_t input_size = prim->input_size;
bool has_bias = !prim->bias.empty();
bool is_compressed = prim->compressed_weights;
ob << input_size;
ob << has_bias;
ob << is_compressed;

std::vector<uint8_t> prim_cache;
prim_cache = _prim.get_cache_blob();
@@ -173,12 +237,19 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {

size_t input_size = 2;
bool has_bias = false;
bool is_compressed = false;
ib >> input_size;
ib >> has_bias;
ib >> is_compressed;

const kernel_impl_params* impl_params = reinterpret_cast<kernel_impl_params*>(ib.getKernelImplParams());
auto prim_desc = get_fully_connected_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
_pd = *prim_desc;
if (is_compressed) {
auto prim_desc = get_matmul_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
_pd = *prim_desc;
} else {
auto prim_desc = get_inner_product_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
_pd = *prim_desc;
}

std::vector<uint8_t> prim_cache;
ib >> prim_cache;
@@ -194,10 +265,35 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
auto& config = impl_params.prog->get_config();
auto attr = arg.get_onednn_primitive_attributes();
auto prim = impl_params.typed_desc<fully_connected>();
auto prim_desc = get_fully_connected_primitive_descriptor(impl_params, impl_params.prog->get_engine(),
prim->input_size, !prim->bias.empty(), *attr);

return cldnn::make_unique<fully_connected_onednn>(engine, config, attr, *prim_desc, get_weights_reorder(impl_params, *prim_desc));
// There may be a performance difference between InnerProduct and MatMul primitives in oneDNN,
// so use MatMul only for weights compression and IP for all other cases.
if (prim->compressed_weights) {
attr->set_fpmath_mode(dnnl::fpmath_mode::f16, true);
if (!prim->decompression_scale.empty()) {
auto decompression_scale_idx = !arg.bias_term() ? 2 : 3;
auto data_type = convert_data_type(arg.get_dependency(decompression_scale_idx).get_output_layout().data_type);
attr->set_scales(DNNL_ARG_WEIGHTS, 1 << 1, dnnl::memory::dims{}, data_type);
}

if (prim->decompression_zero_point_scalar.has_value()) {
OPENVINO_ASSERT(!prim->decompression_zero_point_scalar.has_value(), "[GPU] OneDNN can't use scalar as a zero point value\n");
} else if (!prim->decompression_zero_point.empty()) {
auto decompression_zp_idx = !arg.bias_term() ? 3 : 4;
auto data_type = convert_data_type(arg.get_dependency(decompression_zp_idx).get_output_layout().data_type);
attr->set_zero_points(DNNL_ARG_WEIGHTS, 1 << 1, dnnl::memory::dims{}, data_type);
}

auto prim_desc = get_matmul_primitive_descriptor(impl_params, impl_params.prog->get_engine(),
prim->input_size, !prim->bias.empty(), *attr);

return cldnn::make_unique<fully_connected_onednn>(engine, config, attr, *prim_desc);
} else {
auto prim_desc = get_inner_product_primitive_descriptor(impl_params, impl_params.prog->get_engine(),
prim->input_size, !prim->bias.empty(), *attr);

return cldnn::make_unique<fully_connected_onednn>(engine, config, attr, *prim_desc, get_weights_reorder(impl_params, *prim_desc));
}
}
};

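The attribute and descriptor setup added above can be reproduced in isolation. Below is a minimal sketch, assuming a oneDNN 3.x build that exposes `set_fpmath_mode(..., apply_to_int)` and the grouped `set_scales`/`set_zero_points` overloads used in this diff; the engine, shapes, and data types are illustrative rather than taken from the plugin.

```cpp
// Standalone sketch of a oneDNN matmul configured for 8-bit compressed weights,
// mirroring the attribute calls used by fully_connected_onednn above (illustrative).
#include <oneapi/dnnl/dnnl.hpp>

dnnl::matmul::primitive_desc make_compressed_fc_pd(dnnl::engine& eng,
                                                   dnnl::memory::dim M,
                                                   dnnl::memory::dim K,
                                                   dnnl::memory::dim N) {
    using tag = dnnl::memory::format_tag;
    using dt = dnnl::memory::data_type;

    dnnl::memory::desc src_md({M, K}, dt::f16, tag::ab);
    dnnl::memory::desc wei_md({K, N}, dt::u8, tag::ba);  // 8-bit compressed weights
    dnnl::memory::desc dst_md({M, N}, dt::f16, tag::ab);

    dnnl::primitive_attr attr;
    // Keep f16 math while the weights stay in an integer type.
    attr.set_fpmath_mode(dnnl::fpmath_mode::f16, /*apply_to_int=*/true);
    // Per-output-channel decompression scales and zero points on the weights;
    // mask 1 << 1 selects the N dimension of the {K, N} weights tensor.
    attr.set_scales(DNNL_ARG_WEIGHTS, 1 << 1, dnnl::memory::dims{}, dt::f16);
    attr.set_zero_points(DNNL_ARG_WEIGHTS, 1 << 1, dnnl::memory::dims{}, dt::u8);

    return dnnl::matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr);
}
```

The `ba` weights tag corresponds to the branch added to `layout_to_memory_desc` in utils.cpp below, which lets the existing row-major weight buffer be described as a logically transposed matmul input without a reorder.
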
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/onednn/utils.cpp
@@ -242,6 +242,9 @@ dnnl::memory::desc layout_to_memory_desc(cldnn::layout l, dnnl::memory::format_t
} else if (target_fmt == dnnl::memory::format_tag::ab) {
dims.push_back(l.batch());
dims.push_back(l.get_tensor().count() / l.batch());
} else if (target_fmt == dnnl::memory::format_tag::ba) {
dims.push_back(l.feature());
dims.push_back(l.get_tensor().count() / l.feature());
} else if (flatten) {
dims = flatten_tensor(l.get_tensor());
} else {
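
For reference, `ab` is the plain row-major layout of the logical {a, b} dims, while `ba` keeps the same logical dims but swaps the physical order; this is what allows the FC weight buffer to be presented to the matmul as a transposed matrix. A small illustrative sketch (the strides noted in the comments follow from the tags; dims are made up):

```cpp
// Same logical dims, two physical layouts (illustrative).
#include <oneapi/dnnl/dnnl.hpp>

void format_tag_example() {
    using tag = dnnl::memory::format_tag;
    using dt  = dnnl::memory::data_type;

    dnnl::memory::dims dims = {4, 8};  // logical {a, b}

    dnnl::memory::desc md_ab(dims, dt::f16, tag::ab);  // row-major: strides {8, 1}
    dnnl::memory::desc md_ba(dims, dt::f16, tag::ba);  // transposed: strides {1, 4}
}
```
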
35 changes: 28 additions & 7 deletions src/plugins/intel_gpu/src/graph/layout_optimizer.cpp
@@ -895,12 +895,25 @@ static bool is_node_for_onednn(deconvolution_node const& node) {


static bool is_node_for_onednn(fully_connected_node const& node) {
if (!layout_optimizer::are_data_types_suitable_for_onednn((program_node&)node))
return false;

auto fc_prim = node.get_primitive();
// onednn impl doesn't support compressed weights for now
if (fc_prim->compressed_weights)

if (fc_prim->compressed_weights) {
auto weights_dt = node.weights().get_output_layout().data_type;
if (ov::element::Type(weights_dt).bitwidth() != 8)
return false;

if (fc_prim->decompression_zero_point_scalar.has_value())
return false;

if (!fc_prim->decompression_zero_point.empty()) {
auto decompression_zp_idx = fc_prim->bias.empty() ? 3 : 4;
auto decompression_zp_dt = node.get_input_layout(decompression_zp_idx).data_type;
if (weights_dt != decompression_zp_dt)
return false;
}
}

if (!layout_optimizer::are_data_types_suitable_for_onednn((program_node&)node))
return false;

auto output_layout = node.get_output_layout();
@@ -1332,8 +1345,16 @@ bool layout_optimizer::are_data_types_suitable_for_onednn(program_node& node) {
return onednn_check_data_types_for_deconvolution(in_dt, wei_dt, out_dt);
} else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
bool is_fc = node.is_type<fully_connected>();
auto wei_dt = is_fc ? node.as<fully_connected>().weights().get_output_layout(false).data_type :
node.as<gemm>().get_input_layout(1).data_type;
data_types wei_dt;
if (is_fc) {
const auto& fc_node = node.as<fully_connected>();
const auto fc_prim = fc_node.get_primitive();
wei_dt = fc_node.weights().get_output_layout(false).data_type;
if (fc_prim->compressed_weights && ov::element::Type(wei_dt).bitwidth() == 8)
return true;
} else {
wei_dt = node.as<gemm>().get_input_layout(1).data_type;
}
return onednn_check_data_types_for_fc_gemm(in_dt, wei_dt, out_dt);
} else if (node.is_type<reorder>()) {
auto input_fmt = node.get_input_layout(0).format;
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -1676,6 +1676,8 @@ event::ptr primitive_inst::update_weights() {
// incorrect memory buffer may be assigned, so reset cached weights for such case
_reordered_weights_cache.add(original_layout, original_weights_memory);
_impl_params->weights_layout = optional_layout(original_layout);
GPU_DEBUG_TRACE_DETAIL << id() << ": add original weights memory " << original_layout.to_short_string() << " to weights cache; "
<< "cache_size=" << _reordered_weights_cache.size() << "/" << _reordered_weights_cache.capacity() << std::endl;
} else {
auto expected_layout = reorder_kernel_params->get_output_layout();
// Set original partial shape, because it may be lost during kernel_selector::weights_tensor -> layout conversion
@@ -44,7 +44,10 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
auto convert_m = wrap_type<ov::op::v0::Convert>({weights_m});

auto sub_const_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
auto subtract_m = wrap_type<ov::op::v1::Subtract>({convert_m, sub_const_m});
auto sub_convert_const_m = wrap_type<ov::op::v0::Convert>({sub_const_m});
auto sub_with_convert_m = wrap_type<ov::op::v1::Subtract>({convert_m, sub_convert_const_m});
auto sub_no_convert_m = wrap_type<ov::op::v1::Subtract>({convert_m, sub_const_m});
auto subtract_m = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{sub_with_convert_m, sub_no_convert_m});

auto mul_const_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
auto mul_with_sub_m = wrap_type<ov::op::v1::Multiply>({subtract_m, mul_const_m});
@@ -97,7 +100,7 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
const auto& scale = reshape_const_to_2d(pattern_map.at(mul_const_m).get_node_shared_ptr());
std::shared_ptr<ov::Node> optional_zero_point = nullptr;

const bool with_zero_point = pattern_map.count(subtract_m) > 0;
const bool with_zero_point = pattern_map.count(sub_no_convert_m) > 0 || pattern_map.count(sub_with_convert_m) > 0;
if (with_zero_point) {
optional_zero_point = reshape_const_to_2d(pattern_map.at(sub_const_m).get_node_shared_ptr());
}
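
The additional `sub_convert_const_m`/`sub_with_convert_m` branch lets the matcher also accept graphs where the zero point stays in the weights' integer type and is converted in the graph instead of being constant-folded, which is what the transformations_pipeline change below produces for u8 weights on dGPU. A minimal sketch of such a decompression subgraph, with illustrative shapes and values:

```cpp
// Weight-decompression subgraph with an unfolded zero point (illustrative shapes/values):
// Constant(u8) -> Convert(f16) -> Subtract(Convert(Constant(u8))) -> Multiply(Constant(f16)).
#include <cstdint>
#include <memory>
#include <vector>

#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/subtract.hpp>

std::shared_ptr<ov::Node> make_decompressed_weights() {
    auto weights = ov::op::v0::Constant::create(ov::element::u8, ov::Shape{128, 64},
                                                std::vector<uint8_t>(128 * 64, 0));
    auto weights_f16 = std::make_shared<ov::op::v0::Convert>(weights, ov::element::f16);

    auto zp = ov::op::v0::Constant::create(ov::element::u8, ov::Shape{128, 1},
                                           std::vector<uint8_t>(128, 128));
    auto zp_f16 = std::make_shared<ov::op::v0::Convert>(zp, ov::element::f16);  // sub_convert_const_m

    auto sub = std::make_shared<ov::op::v1::Subtract>(weights_f16, zp_f16);     // sub_with_convert_m

    auto scale = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{128, 1},
                                              std::vector<float>(128, 0.01f));
    return std::make_shared<ov::op::v1::Multiply>(sub, scale);                  // mul_with_sub_m
}
```
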
11 changes: 10 additions & 1 deletion src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
@@ -276,7 +276,16 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
return !is_type<ov::op::v0::MatMul>(next_node);
});

manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8, ov::element::u4, ov::element::i4}, true);
// Disable subtract folding only for the dGPUs to meet the requirements of oneDNN:
// it expects to have the same data type for weights and zero points (apply it only for u8 data type, since other compression
// types are not supported by oneDNN)
if (device_info.supports_immad) {
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8}, false);
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u4, ov::element::i4}, true);
} else {
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8, ov::element::u4, ov::element::i4}, true);
}

// Need to check if transformations work correctly for mixed models with both compression and quantization at the same time.
if (!is_model_quantized)
pass_config->set_callback<ov::pass::MarkDequantizationSubgraph>(is_non_supported_decompression_op);
