Skip to content

Commit

Permalink
[GPU] Grouped decompression scale/zp support
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-paramuzov committed Oct 16, 2023
1 parent 755651c commit cd4cec3
Show file tree
Hide file tree
Showing 9 changed files with 353 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "fully_connected_inst.h"
#include "pooling_inst.h"
#include "quantize_inst.h"
#include "reorder_inst.h"
Expand Down Expand Up @@ -847,6 +848,42 @@ bool prepare_quantization::optimize_quantize(program &p, quantize_node& quantize
return true;
}

static void optimize_weights_decompression_parameters(fully_connected_node& fc_node, program& p) {
auto fc_prim = fc_node.get_primitive();
if (!fc_prim->compressed_weights)
return;

auto reorder_bfyx_to_fbyx = [&](size_t dep_id) {
auto& dep = fc_node.get_dependency(dep_id);
auto target_layout = dep.get_output_layout();
target_layout.format = format::fbyx;
auto reorder_prim = std::make_shared<reorder>(dep.id() + "_reorder", dep.id(), target_layout);
p.add_intermediate(reorder_prim, fc_node, dep_id, true);
fc_node.get_dependency(dep_id).recalc_output_layout(false);
};

auto need_reorder = [&](size_t dep_id) {
auto dep_layout = fc_node.get_input_layout(dep_id);
auto dep_pshape = dep_layout.get_partial_shape();

auto groups_count = dep_pshape[dep_pshape.size() - 1].get_length();

return groups_count > 1;
};

auto decompression_scale_idx = !fc_node.bias_term() ? 2 : 3;
if (need_reorder(decompression_scale_idx)) {
reorder_bfyx_to_fbyx(decompression_scale_idx);
}

if (!fc_prim->decompression_zero_point.empty()) {
auto decompression_zp_idx = decompression_scale_idx + 1;
if (need_reorder(decompression_zp_idx)) {
reorder_bfyx_to_fbyx(decompression_zp_idx);
}
}
}

void prepare_quantization::run(program& p) {
auto itr = p.get_processing_order().begin();
while (itr != p.get_processing_order().end()) {
Expand All @@ -859,6 +896,8 @@ void prepare_quantization::run(program& p) {
remove_fake_reorders(p, node->as<reorder>());
} else if (node->is_type<convolution>()) {
prepare_asymmetric_quantization(p, node->as<convolution>());
} else if (node->is_type<fully_connected>()) {
optimize_weights_decompression_parameters(node->as<fully_connected>(), p);
}
}
}
7 changes: 0 additions & 7 deletions src/plugins/intel_gpu/src/graph/impls/ocl/fully_connected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,13 @@ struct fully_connected_impl : typed_primitive_impl_ocl<fully_connected> {
bool has_scale = !primitive->decompression_scale.empty();

size_t offset = primitive->bias.empty() ? 2 : 3;
const auto& weights_pshape = input1_layout.get_partial_shape();
if (has_scale) {
auto scale_layout = input_layouts[offset++];
if (input1_pshape.size() != 2) {
scale_layout.set_partial_shape(reshape_to_2d(scale_layout.get_partial_shape(), weights_pshape[0], primitive->weights_rank));
}
layouts.push_back(scale_layout);
}

if (has_zp) {
auto zp_layout = input_layouts[offset];
if (input1_pshape.size() != 2) {
zp_layout.set_partial_shape(reshape_to_2d(zp_layout.get_partial_shape(), weights_pshape[0], primitive->weights_rank));
}
layouts.push_back(zp_layout);
}

Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/layout_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ bool layout_optimizer::can_fuse_reorder(program_node& prev, program_node& next,
(fmt_prev == format::b_fs_yx_fsv4 &&
prev_output_layout.feature() % 32 == 0 &&
prev_output_layout.spatial(0) == 1 &&
prev_output_layout.spatial(1) == 1)))
prev_output_layout.spatial(1) == 1)) && is_input_reorder(prev, next))
return true;

if (next.is_type<convolution>() && fmt_prev == format::b_fs_yx_fsv16 && fmt_next == format::b_fs_yx_fsv4 && is_input_idx(0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ KERNEL(fc)(
uint input_offset = out_b * TILE_IN_B_PITCH + INPUT0_OFFSET;
uint weights_offset = out_f * INPUT_ELEMENTS_COUNT;

#if COMPRESSED_WEIGHTS
#if COMPRESSED_WEIGHTS && DECOMPRESSION_SCALE_GROUPS_NUM == 1
#if DECOMPRESSION_SCALE_LENGTH > 1 && DECOMPRESSION_SCALE_LENGTH % SIMD == 0
ACCUMULATOR_VEC_TYPE d_scale = BLOCK_READN(ACCUMULATOR_TYPE, TILE_OFM, decompression_scale, out_f);
#elif DECOMPRESSION_SCALE_LENGTH > 1 && DECOMPRESSION_SCALE_LENGTH % SIMD != 0
Expand All @@ -134,9 +134,11 @@ KERNEL(fc)(
ACCUMULATOR_VEC_TYPE d_scale = decompression_scale[0];
#endif

#if !DECOMPRESSION_ZP_TERM
ACCUMULATOR_VEC_TYPE d_zp = 0;
#elif DECOMPRESSION_ZP_LENGTH > 1 && DECOMPRESSION_ZP_LENGTH % SIMD == 0
ACCUMULATOR_TYPE* d_scales = (ACCUMULATOR_TYPE*)(&d_scale);
#endif

#if COMPRESSED_WEIGHTS && DECOMPRESSION_ZP_TERM && DECOMPRESSION_ZP_GROUPS_NUM == 1
#if DECOMPRESSION_ZP_LENGTH > 1 && DECOMPRESSION_ZP_LENGTH % SIMD == 0
ACCUMULATOR_VEC_TYPE d_zp = BLOCK_READN(ACCUMULATOR_TYPE, TILE_OFM, decompression_zp, out_f);
#elif DECOMPRESSION_ZP_LENGTH > 1 && DECOMPRESSION_ZP_LENGTH % SIMD != 0
ACCUMULATOR_VEC_TYPE d_zp = 0;
Expand All @@ -148,9 +150,7 @@ KERNEL(fc)(
#else
ACCUMULATOR_VEC_TYPE d_zp = decompression_zp[0];
#endif

ACCUMULATOR_TYPE* ds = (ACCUMULATOR_TYPE*)(&d_scale);
ACCUMULATOR_TYPE* dzp = (ACCUMULATOR_TYPE*)(&d_zp);
ACCUMULATOR_TYPE* d_zps = (ACCUMULATOR_TYPE*)(&d_zp);
#endif

#if REALIGN_FP16_OFFSET
Expand Down Expand Up @@ -193,7 +193,28 @@ KERNEL(fc)(
ACCUMULATOR_TYPE* w = (ACCUMULATOR_TYPE*)(&wei);
unroll_for(uint kii = 0; kii < TILE_K; ++kii) {
unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
w[kii * TILE_OFM + fi] = (w[kii * TILE_OFM + fi] - dzp[fi]) * ds[fi];
const uint w_idx = kii * TILE_OFM + fi;
uint offset_ofm = out_f + fi*SIMD + get_sub_group_local_id();
#if DECOMPRESSION_SCALE_GROUPS_NUM > 1
const uint scale_offset = (offset_ofm % DECOMPRESSION_SCALE_BATCH_NUM) * DECOMPRESSION_SCALE_BATCH_PITCH +
((kii + ki*TILE_K + ni*TILE_IFM*SIMD) / DECOMPRESSION_SCALE_GROUP_SIZE)*DECOMPRESSION_SCALE_FEATURE_PITCH;
ACCUMULATOR_TYPE ds = decompression_scale[scale_offset];
#else
ACCUMULATOR_TYPE ds = d_scales[fi];
#endif

#if DECOMPRESSION_ZP_TERM
#if DECOMPRESSION_ZP_GROUPS_NUM > 1
const uint zp_offset = (offset_ofm % DECOMPRESSION_ZP_BATCH_NUM) * DECOMPRESSION_ZP_BATCH_PITCH +
((kii + ki*TILE_K + ni*TILE_IFM*SIMD) / DECOMPRESSION_ZP_GROUP_SIZE) * DECOMPRESSION_ZP_FEATURE_PITCH;
ACCUMULATOR_TYPE dzp = decompression_zp[zp_offset];
#else
ACCUMULATOR_TYPE dzp = d_zps[fi];
#endif
#else
ACCUMULATOR_TYPE dzp = ACCUMULATOR_VAL_ZERO;
#endif
w[w_idx] = (w[w_idx] - dzp) * ds;
}
}
#endif
Expand Down Expand Up @@ -230,7 +251,28 @@ KERNEL(fc)(
ACCUMULATOR_TYPE* w = (ACCUMULATOR_TYPE*)(&wei);
unroll_for(uint kii = 0; kii < TILE_K; ++kii) {
unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
w[kii * TILE_OFM + fi] = (w[kii * TILE_OFM + fi] - dzp[fi]) * ds[fi];
const uint w_idx = kii * TILE_OFM + fi;
uint offset_ofm = out_f + fi*SIMD + get_sub_group_local_id();
#if DECOMPRESSION_SCALE_GROUPS_NUM > 1
const uint scale_offset = (offset_ofm % DECOMPRESSION_SCALE_BATCH_NUM) * DECOMPRESSION_SCALE_BATCH_PITCH +
((kii + ki*TILE_K + ni*TILE_IFM*SIMD) / DECOMPRESSION_SCALE_GROUP_SIZE)*DECOMPRESSION_SCALE_FEATURE_PITCH;
ACCUMULATOR_TYPE ds = decompression_scale[scale_offset];
#else
ACCUMULATOR_TYPE ds = d_scales[fi];
#endif

#if DECOMPRESSION_ZP_TERM
#if DECOMPRESSION_ZP_GROUPS_NUM > 1
const uint zp_offset = (offset_ofm % DECOMPRESSION_ZP_BATCH_NUM) * DECOMPRESSION_ZP_BATCH_PITCH +
((kii + ki*TILE_K + ni*TILE_IFM*SIMD) / DECOMPRESSION_ZP_GROUP_SIZE) * DECOMPRESSION_ZP_FEATURE_PITCH;
ACCUMULATOR_TYPE dzp = decompression_zp[zp_offset];
#else
ACCUMULATOR_TYPE dzp = d_zps[fi];
#endif
#else
ACCUMULATOR_TYPE dzp = ACCUMULATOR_VAL_ZERO;
#endif
w[w_idx] = (w[w_idx] - dzp) * ds;
}
}
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,23 @@ KERNEL(fc)(
for (uint x = 0; x < INPUT0_SIZE_X; ++x)
{
const uint input0_idx = INPUT0_GET_INDEX(b, ofm, y, x);
const uint filter_idx = GET_FILTER_INDEX(FILTER, 0, oym, y, 0, 0);
#if COMPRESSED_WEIGHTS
ACCUMULATOR_TYPE filter_compressed = TO_ACCUMULATOR_TYPE(weights[filter_idx]);
#if DECOMPRESSION_ZP_TERM
ACCUMULATOR_TYPE zp = TO_ACCUMULATOR_TYPE(decompression_zp[DECOMPRESSION_ZP_GET_INDEX_SAFE(0, oym, 0, 0)]);
ACCUMULATOR_TYPE zp = TO_ACCUMULATOR_TYPE(decompression_zp[DECOMPRESSION_ZP_GET_INDEX_SAFE(oym, y / DECOMPRESSION_SCALE_GROUP_SIZE, 0, 0)]);
#else
ACCUMULATOR_TYPE zp = ACCUMULATOR_VAL_ZERO;
#endif
DECOMPRESSION_SCALE_TYPE scale = decompression_scale[DECOMPRESSION_SCALE_GET_INDEX_SAFE(0, oym, 0, 0)];
ACCUMULATOR_TYPE filter_val = (TO_ACCUMULATOR_TYPE(filter_compressed) - TO_ACCUMULATOR_TYPE(zp)) * scale;
const uint decomp_offset = DECOMPRESSION_SCALE_GET_INDEX_SAFE(oym, y / DECOMPRESSION_SCALE_GROUP_SIZE, 0, 0);
DECOMPRESSION_SCALE_TYPE scale = decompression_scale[decomp_offset];
#endif

#if COMPRESSED_WEIGHTS_INT8
const uint filter_idx = GET_FILTER_INDEX(FILTER, 0, oym, y, 0, 0);
ACCUMULATOR_TYPE filter_compressed = TO_ACCUMULATOR_TYPE(weights[filter_idx]);
ACCUMULATOR_TYPE filter_val = (filter_compressed - zp) * scale;
dotProd += (ACCUMULATOR_TYPE)(input[input0_idx]) * (ACCUMULATOR_TYPE)(filter_val);
#else
const uint filter_idx = GET_FILTER_INDEX(FILTER, 0, oym, y, 0, 0);
dotProd += (ACCUMULATOR_TYPE)(input[input0_idx]) * (ACCUMULATOR_TYPE)(weights[filter_idx]);
#endif
}
Expand All @@ -67,19 +72,24 @@ KERNEL(fc)(
for (uint x = 0; x < INPUT0_SIZE_X; ++x)
{
const uint input0_idx = INPUT0_GET_INDEX(b, ifm, y, x);
const uint filter_idx = GET_FILTER_INDEX(FILTER, 0, ofm, ifm, y, x);
#if COMPRESSED_WEIGHTS
FILTER_TYPE filter_compressed = weights[filter_idx];
#if DECOMPRESSION_ZP_TERM
ACCUMULATOR_TYPE zp = decompression_zp[DECOMPRESSION_ZP_GET_INDEX_SAFE(0, ofm, 0, 0)];
ACCUMULATOR_TYPE zp = TO_ACCUMULATOR_TYPE(decompression_zp[DECOMPRESSION_ZP_GET_INDEX_SAFE(ofm, ifm / DECOMPRESSION_SCALE_GROUP_SIZE, 0, 0)]);
#else
ACCUMULATOR_TYPE zp = ACCUMULATOR_VAL_ZERO;
#endif
const uint decomp_offset = DECOMPRESSION_SCALE_GET_INDEX_SAFE(ofm, ifm / DECOMPRESSION_SCALE_GROUP_SIZE, 0, 0);
DECOMPRESSION_SCALE_TYPE scale = decompression_scale[decomp_offset];
#endif

DECOMPRESSION_SCALE_TYPE scale = decompression_scale[DECOMPRESSION_SCALE_GET_INDEX_SAFE(0, ofm, 0, 0)];
ACCUMULATOR_TYPE filter_val = (TO_ACCUMULATOR_TYPE(filter_compressed) - TO_ACCUMULATOR_TYPE(zp)) * scale;

#if COMPRESSED_WEIGHTS_INT8
const uint filter_idx = GET_FILTER_INDEX(FILTER, 0, ofm, ifm, y, x);
FILTER_TYPE filter_compressed = weights[filter_idx];
ACCUMULATOR_TYPE filter_val = (TO_ACCUMULATOR_TYPE(filter_compressed) - zp) * scale;
dotProd += (ACCUMULATOR_TYPE)(input[input0_idx]) * (ACCUMULATOR_TYPE)(filter_val);
#else
const uint filter_idx = GET_FILTER_INDEX(FILTER, 0, ofm, ifm, y, x);
dotProd += (ACCUMULATOR_TYPE)(input[input0_idx]) * (ACCUMULATOR_TYPE)(weights[filter_idx]);
#endif
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,23 @@ JitConstants FullyConnectedKernelBase::GetJitConstants(const fully_connected_par

if (params.compressed) {
jit.AddConstants({MakeJitConstant("COMPRESSED_WEIGHTS", 1)});
if (params.weights.GetDType() == WeightsType::INT8 || params.weights.GetDType() == WeightsType::UINT8) {
jit.AddConstants({MakeJitConstant("COMPRESSED_WEIGHTS_INT8", 1)});
}

const size_t scale_groups_num = params.decompression_scale.Feature().v;
const size_t scale_group_size = params.weights.IFM().v / params.decompression_scale.Feature().v;
jit.AddConstants({MakeJitConstant("DECOMPRESSION_SCALE_TERM", 1)});
jit.AddConstants({MakeJitConstant("DECOMPRESSION_SCALE", params.decompression_scale)});
jit.AddConstants({MakeJitConstant("DECOMPRESSION_SCALE_GROUPS_NUM", scale_groups_num)});
jit.AddConstants({MakeJitConstant("DECOMPRESSION_SCALE_GROUP_SIZE", scale_group_size)});
if (params.has_decompression_zp) {
const size_t zp_groups_num = params.decompression_zero_point.Feature().v;
const size_t zp_group_size = params.weights.IFM().v / params.decompression_zero_point.Feature().v;
jit.AddConstants({MakeJitConstant("DECOMPRESSION_ZP_TERM", 1)});
jit.AddConstants({MakeJitConstant("DECOMPRESSION_ZP", params.decompression_zero_point)});
jit.AddConstants({MakeJitConstant("DECOMPRESSION_ZP_GROUPS_NUM", zp_groups_num)});
jit.AddConstants({MakeJitConstant("DECOMPRESSION_ZP_GROUP_SIZE", zp_group_size)});
}
}

Expand Down
Loading

0 comments on commit cd4cec3

Please sign in to comment.