Skip to content

Commit

Permalink
[GPU] Integrate dynamic quantization for onednn (#26940)
Browse files Browse the repository at this point in the history
### Details:
 - Integrated grouped dynamic quantization from onednn
 - Integrated asymmetric per-token dynamic quantization from onednn
 - Those are not enabled by default, yet

### Tickets:
 - 148732, 157869, 157589
  • Loading branch information
isanghao authored Dec 9, 2024
1 parent e8fa9f7 commit 179e1e0
Show file tree
Hide file tree
Showing 25 changed files with 420 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class FullyConnectedCompressed : public FullyConnected {
const ov::Output<Node> &w_decompression_scale,
const ov::Output<Node> &w_decompression_zero_point,
const ov::Output<Node> &a_decompression_scale,
const ov::Output<Node> &a_decompression_zero_point,
const ov::element::Type output_type = ov::element::undefined);


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@ struct dynamic_quantize : public primitive_base<dynamic_quantize> {
/// @param output_size Output data size of the primitive
dynamic_quantize(const primitive_id& id,
const input_info& input,
const Attributes& attrs)
const Attributes& attrs,
const size_t input_size = 3)
: primitive_base(id, {input})
, attrs(attrs) {
, attrs(attrs)
, input_size(input_size) {
num_outputs = 2;
if (attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric &&
attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar)
num_outputs++;
}

Attributes attrs;
size_t input_size;

size_t hash() const override {
size_t seed = primitive::hash();
Expand All @@ -46,6 +49,7 @@ struct dynamic_quantize : public primitive_base<dynamic_quantize> {
seed = hash_combine(seed, attrs.scale_dt.hash());
seed = hash_combine(seed, attrs.zp_dt.hash());
seed = hash_combine(seed, attrs.output_storage_type);
seed = hash_combine(seed, input_size);

return seed;
}
Expand All @@ -62,7 +66,8 @@ struct dynamic_quantize : public primitive_base<dynamic_quantize> {
attrs.quantization_dt == rhs_casted.attrs.quantization_dt &&
attrs.scale_dt == rhs_casted.attrs.scale_dt &&
attrs.zp_dt == rhs_casted.attrs.zp_dt &&
attrs.quantization_type == rhs_casted.attrs.quantization_type;;
attrs.quantization_type == rhs_casted.attrs.quantization_type &&
input_size == rhs_casted.input_size;
}

void save(BinaryOutputBuffer& ob) const override {
Expand All @@ -75,6 +80,7 @@ struct dynamic_quantize : public primitive_base<dynamic_quantize> {
ob << make_data(&attrs.output_storage_type, sizeof(attrs.output_storage_type));
ob << attrs.scales_zp_output_order;
ob << attrs.group_sizes;
ob << input_size;
}

void load(BinaryInputBuffer& ib) override {
Expand All @@ -87,6 +93,7 @@ struct dynamic_quantize : public primitive_base<dynamic_quantize> {
ib >> make_data(&attrs.output_storage_type, sizeof(attrs.output_storage_type));
ib >> attrs.scales_zp_output_order;
ib >> attrs.group_sizes;
ib >> input_size;
}
};
} // namespace cldnn
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ struct fully_connected : public primitive_base<fully_connected> {
decompression_scale(decompression_scale),
decompression_zero_point(decompression_zero_point),
dynamic_quantized_activation(false),
dynamic_quantized_activation_zp(false),
input_size(input_size),
weights_rank(weights_rank) {
OPENVINO_ASSERT(!decompression_scale.empty(), "[GPU] Compressed fully connected requires at least decompression scale input");
Expand All @@ -109,13 +110,15 @@ struct fully_connected : public primitive_base<fully_connected> {
/// @param compression_scale Primitive id containing scale factors for weights decompression.
/// @param compression_zero_point Primitive id containing zero points for weights decompression.
/// @param activation_scale Primitive id containing scale factor for activation.
/// @param activation_zero_point Primitive id containing zero point for activation.
fully_connected(const primitive_id& id,
const input_info& input,
const primitive_id& weights,
const primitive_id& bias,
const primitive_id& decompression_scale,
const primitive_id& decompression_zero_point,
const input_info& activation_scale,
const input_info& activation_zero_point,
const data_types data_type,
const size_t input_size = 2,
const size_t weights_rank = 2)
Expand All @@ -126,11 +129,15 @@ struct fully_connected : public primitive_base<fully_connected> {
decompression_scale(decompression_scale),
decompression_zero_point(decompression_zero_point),
dynamic_quantized_activation(false),
dynamic_quantized_activation_zp(false),
activation_scale(activation_scale),
activation_zero_point(activation_zero_point),
input_size(input_size),
weights_rank(weights_rank) {
if (activation_scale.is_valid())
dynamic_quantized_activation = true;
if (activation_zero_point.is_valid())
dynamic_quantized_activation_zp = true;

OPENVINO_ASSERT(!decompression_scale.empty(), "[GPU] Compressed fully connected requires at least decompression scale input");
}
Expand All @@ -144,7 +151,9 @@ struct fully_connected : public primitive_base<fully_connected> {
primitive_id decompression_scale = "";
primitive_id decompression_zero_point = "";
bool dynamic_quantized_activation = false;
bool dynamic_quantized_activation_zp = false;
input_info activation_scale = {"", 0};
input_info activation_zero_point = {"", 0};
optional_value<float> decompression_zero_point_scalar = optional_value<float>();

/// @brief Primitive dimension size.
Expand All @@ -161,6 +170,7 @@ struct fully_connected : public primitive_base<fully_connected> {
seed = hash_combine(seed, !decompression_scale.empty());
seed = hash_combine(seed, !decompression_zero_point.empty());
seed = hash_combine(seed, activation_scale.is_valid());
seed = hash_combine(seed, activation_zero_point.is_valid());
seed = hash_combine(seed, decompression_zero_point_scalar.has_value());
seed = hash_combine(seed, decompression_zero_point_scalar.value_or(0.0f));
return seed;
Expand All @@ -179,6 +189,7 @@ struct fully_connected : public primitive_base<fully_connected> {
decompression_scale.empty() == rhs_casted.decompression_scale.empty() &&
decompression_zero_point.empty() == rhs_casted.decompression_zero_point.empty() &&
activation_scale.is_valid() == rhs_casted.activation_scale.is_valid() &&
activation_zero_point.is_valid() == rhs_casted.activation_zero_point.is_valid() &&
decompression_zero_point_scalar.value_or(0.0f) == rhs_casted.decompression_zero_point_scalar.value_or(0.0f);
}

Expand All @@ -190,9 +201,11 @@ struct fully_connected : public primitive_base<fully_connected> {
ob << decompression_scale;
ob << decompression_zero_point;
ob << activation_scale;
ob << activation_zero_point;
ob << input_size;
ob << weights_rank;
ob << dynamic_quantized_activation;
ob << dynamic_quantized_activation_zp;

if (decompression_zero_point_scalar.has_value()) {
ob << true;
Expand All @@ -211,9 +224,11 @@ struct fully_connected : public primitive_base<fully_connected> {
ib >> decompression_scale;
ib >> decompression_zero_point;
ib >> activation_scale;
ib >> activation_zero_point;
ib >> input_size;
ib >> weights_rank;
ib >> dynamic_quantized_activation;
ib >> dynamic_quantized_activation_zp;

bool has_value;
ib >> has_value;
Expand Down Expand Up @@ -243,6 +258,9 @@ struct fully_connected : public primitive_base<fully_connected> {
if (activation_scale.is_valid())
ret.push_back(activation_scale);

if (activation_zero_point.is_valid())
ret.push_back(activation_zero_point);

return ret;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class debug_configuration {
std::vector<std::string> dynamic_quantize_layers_without_onednn; // Specify Fully-connected layers which enable Dynamic quantization
int use_kv_cache_compression; // Enable KV-cache compression
int dynamic_quantize_group_size; // Enable Dynamic quantization for fully connected primitive by specified group size
int dynamic_quantize_asym; // Use asymmetric dynamic quantization
int disable_horizontal_fc_fusion; // Disable fc horizontal fusion
int disable_fc_swiglu_fusion; // Disable swiglu fusion to fc
std::set<int64_t> dump_iteration; // Dump n-th execution of network.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ void prepare_primitive_fusing::fuse_bias(program &p) {
if (desc->decompression_zero_point_scalar.has_value())
fc_with_bias_prim->decompression_zero_point_scalar = desc->decompression_zero_point_scalar.value();
fc_with_bias_prim->activation_scale = desc->activation_scale;
fc_with_bias_prim->activation_zero_point = desc->activation_zero_point;
fc_with_bias_prim->dynamic_quantized_activation = desc->dynamic_quantized_activation;
fc_with_bias_prim->dynamic_quantized_activation_zp = desc->dynamic_quantized_activation_zp;
}
auto& new_fc_node = p.get_or_create(fc_with_bias_prim);
fuse_bias_f(fc, new_fc_node, bias_node, eltw_node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct dynamic_quantize_impl : typed_primitive_impl_ocl<dynamic_quantize> {

static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
auto params = get_default_params<kernel_selector::dynamic_quantize_params>(impl_param, is_shape_agnostic);
const auto& primitive = impl_param.typed_desc<dynamic_quantize>();
params.outputs.push_back(convert_data_tensor(impl_param.get_output_layout(1)));

// In Some model, the feature size could be dynamic in input0.
Expand All @@ -48,6 +49,10 @@ struct dynamic_quantize_impl : typed_primitive_impl_ocl<dynamic_quantize> {
if (impl_param.output_layouts.size() > 2)
params.outputs.push_back(convert_data_tensor(impl_param.get_output_layout(2)));

// Keep 2d data as bf layout
if (primitive->input_size == 2)
params.outputs[0] = params.outputs[0].FlattenFeatureAndSpatials();

const auto& desc = impl_param.typed_desc<dynamic_quantize>();
params.group_sizes = desc->attrs.group_sizes;
params.scales_output_order = desc->attrs.scales_zp_output_order;
Expand All @@ -68,7 +73,8 @@ namespace detail {
attach_dynamic_quantize_impl::attach_dynamic_quantize_impl() {
auto types = {
data_types::f16,
data_types::i8
data_types::i8,
data_types::u8
};

auto formats = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,16 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
if (prim->activation_scale.is_valid()) {
auto activation_scale_idx = idx++;
auto act_scale_mem = instance.dep_memory_ptr(activation_scale_idx);
// TODO: handle group_size here
dnnl::memory::desc desc = onednn::layout_to_memory_desc(act_scale_mem->get_layout(), dnnl::memory::format_tag::a, true);
dnnl::memory::desc desc = onednn::layout_to_memory_desc(act_scale_mem->get_layout(), dnnl::memory::format_tag::ab, true);
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, act_scale_mem->get_onednn_memory(desc)});
}

if (prim->activation_zero_point.is_valid()) {
auto activation_zp_idx = idx++;
auto act_zp_mem = instance.dep_memory_ptr(activation_zp_idx);
dnnl::memory::desc desc = onednn::layout_to_memory_desc(act_zp_mem->get_layout(), dnnl::memory::format_tag::ab, true);
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC_0, act_zp_mem->get_onednn_memory(desc)});
}
}

return args;
Expand Down Expand Up @@ -245,6 +251,7 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
ob << has_bias;
ob << is_compressed;
ob << prim->dynamic_quantized_activation;
ob << prim->dynamic_quantized_activation_zp;

bool has_decompression_scale = !prim->decompression_scale.empty();
if (has_decompression_scale) {
Expand All @@ -271,10 +278,12 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
bool has_bias = false;
bool is_compressed = false;
bool dynamic_quantized_activation;
bool dynamic_quantized_activation_zp;
ib >> input_size;
ib >> has_bias;
ib >> is_compressed;
ib >> dynamic_quantized_activation;
ib >> dynamic_quantized_activation_zp;

const kernel_impl_params* impl_params = reinterpret_cast<kernel_impl_params*>(ib.getKernelImplParams());
auto prim = impl_params->typed_desc<fully_connected>();
Expand All @@ -293,11 +302,12 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {

bool has_decompression_zp = !prim->decompression_zero_point.empty() || prim->decompression_zero_point_scalar.has_value();
auto& arg = impl_params->get_program().get_node(impl_params->desc->id).as<fully_connected>();
int idx = !arg.bias_term() ? 3 : 4;
int idx = !arg.bias_term() ? 2 : 3;

if (has_decompression_zp) {
ib >> make_data(&_dzp_data_type, sizeof(dnnl::memory::data_type));
auto dzp_layout = arg.get_dependency(idx++).get_output_layout();
auto decompression_zp_idx = ++idx;
auto dzp_layout = arg.get_dependency(decompression_zp_idx).get_output_layout();

if (dzp_layout.count() == 1) {
_attrs->set_zero_points(DNNL_ARG_WEIGHTS, COMMON, dnnl::memory::dims{}, _dzp_data_type);
Expand All @@ -312,12 +322,17 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
}

if (dynamic_quantized_activation) {
// TODO: it supports per-token activation scale only
auto src_scale_idx = ++idx;
auto partial_shape = impl_params->get_input_layout(0).get_partial_shape();
auto innermost_len = partial_shape[partial_shape.size() - 1].get_length();

auto act_scale_data_type = convert_data_type(impl_params->get_input_layout(idx).data_type);
_attrs->set_scales(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, innermost_len}, act_scale_data_type);
auto& src_scale_shape = impl_params->input_layouts[src_scale_idx].get_partial_shape();
int src_scale_ngroups = src_scale_shape[src_scale_shape.size() - 1].get_length();
int src_group_size = innermost_len / src_scale_ngroups;

auto act_scale_data_type = convert_data_type(impl_params->get_input_layout(src_scale_idx).data_type);
_attrs->set_scales(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, act_scale_data_type);
if (dynamic_quantized_activation_zp)
_attrs->set_zero_points(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, dnnl::memory::data_type::u8);
}

if (is_compressed) {
Expand Down Expand Up @@ -387,15 +402,21 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
}

if (prim->dynamic_quantized_activation) {
// Note: it supports per-token activation scale only
++idx;
auto partial_shape = impl_params.input_layouts[0].get_partial_shape();
auto src_scale_idx = ++idx;
auto& partial_shape = impl_params.input_layouts[0].get_partial_shape();
auto innermost_len = partial_shape[partial_shape.size() - 1].get_length();
auto& src_scale_shape = impl_params.input_layouts[src_scale_idx].get_partial_shape();
int src_scale_ngroups = src_scale_shape[src_scale_shape.size() - 1].get_length();
int src_group_size = innermost_len / src_scale_ngroups;

auto act_scale_data_type = convert_data_type(impl_params.input_layouts[idx].data_type);
attr->set_scales(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, innermost_len}, act_scale_data_type);
auto act_scale_data_type = convert_data_type(impl_params.input_layouts[src_scale_idx].data_type);
attr->set_scales(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, act_scale_data_type);

if (prim->activation_zero_point.is_valid())
attr->set_zero_points(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, dnnl::memory::data_type::u8);
}


auto prim_desc = get_matmul_primitive_descriptor(impl_params, impl_params.prog->get_engine(),
prim->input_size, !prim->bias.empty(), *attr);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct FullyConnectedImplementationManager : public ImplementationManager {
one_of(wei_dt, {data_types::i8, data_types::u8}) &&
one_of(out_dt, {data_types::f16, data_types::f32, data_types::i32, data_types::i8, data_types::u8});
bool compressed_case = fc_prim->compressed_weights &&
one_of(in0_dt, {data_types::f16, data_types::f32, data_types::i8}) &&
one_of(in0_dt, {data_types::f16, data_types::f32, data_types::i8, data_types::u8}) &&
one_of(wei_dt, {data_types::u8, data_types::i8, data_types::u4, data_types::i4}) &&
one_of(out_dt, {data_types::f16, data_types::f32, data_types::u8, data_types::i8});
if (!f16f16_case && !f32f32_case && !u8s8_case && !compressed_case)
Expand Down
Loading

0 comments on commit 179e1e0

Please sign in to comment.