Skip to content

Commit

Permalink
[GPU] Do not use OneDNN for Compressed FC with zero_point_scalar and group size that is not a power of two
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Apr 29, 2024
1 parent abc40f3 commit a591997
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -385,14 +385,14 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
if (!prim->decompression_scale.empty()) {
auto decompression_scale_idx = !arg.bias_term() ? 2 : 3;
ds_data_type = convert_data_type(arg.get_dependency(decompression_scale_idx).get_output_layout().data_type);
auto ifm = arg.get_dependency(1).get_output_layout().get_dim(1);
auto ngroups = arg.get_dependency(decompression_scale_idx).get_output_layout().get_dim(1);
group_size = ifm / ngroups;
if (!is_four_bit_weight) {
// 8-bit quantized weight
attr->set_scales(DNNL_ARG_WEIGHTS, 1 << 1, dnnl::memory::dims{}, ds_data_type);
} else {
// OneDNN does not support scalar zero-point for s4 and u8 type. Need to broadcast it.
auto ifm = arg.get_dependency(1).get_output_layout().get_dim(1);
auto ngroups = arg.get_dependency(decompression_scale_idx).get_output_layout().get_dim(1);
group_size = ifm / ngroups;
attr->set_scales(DNNL_ARG_WEIGHTS, (1 << 1) + (1 << 0), {group_size, 1}, ds_data_type);
}
}
Expand Down
21 changes: 21 additions & 0 deletions src/plugins/intel_gpu/src/graph/layout_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,27 @@ static bool is_node_for_onednn(fully_connected_node const& node) {
}
}

if (fc_prim->decompression_zero_point_scalar.has_value() && !fc_prim->decompression_scale.empty()) {
size_t decompression_scale_idx = !node.bias_term() ? 2ul : 3ul;
const auto& weights_pshape = node.get_dependency(1).get_output_pshape();
const auto& decompression_scale_pshape = node.get_dependency(decompression_scale_idx).get_output_pshape();

const auto& ifm = weights_pshape[1];
const auto& ngroups = decompression_scale_pshape[1];

if (ifm.is_static() && ngroups.is_static()) {
auto is_power_of_two = [](int64_t x) {
return (x & (x - 1)) == 0;
};

int64_t group_size = ifm.get_length() / ngroups.get_length();

if (!is_power_of_two(group_size)) {
return false;
}
}
}

return true;
}

Expand Down

0 comments on commit a591997

Please sign in to comment.