[TEMP]
Signed-off-by: Min, Byungil <[email protected]>
byungilm committed Dec 2, 2024
1 parent ae38233 commit 73dfeff
Showing 2 changed files with 68 additions and 36 deletions.
@@ -42,10 +42,13 @@ KERNEL(quantize_input)(
const uint quantize_block = QUANTIZE_GROUP_SIZE / 4;
#endif

if (get_global_id(0) == 1 && get_global_id(2) == 0) {
printf("gid(%u) input_offset(%u) offset(%u) IFM_SIZE(%u) QUANTIZE_GROUP_SIZE(%u)\n",
gid, input_offset, offset, (uint)IFM_SIZE, (uint)QUANTIZE_GROUP_SIZE);
}
#if PER_TOKEN_QUANTIZE_SIZE
if (get_global_id(0) == 1 && get_global_id(2) == 0) {
printf("gid(%u) input_offset(%u) offset(%u) quantize_block(%u) : PER_TOKEN_QUANTIZE_SIZE(%d) IFM_SIZE(%u) QUANTIZE_GROUP_SIZE(%u) QUANTIZE_GROUP_BLOCKS_PER_TOKEN(%u)\n",
gid, input_offset, offset, quantize_block, (int)PER_TOKEN_QUANTIZE_SIZE,
(uint)IFM_SIZE, (uint)QUANTIZE_GROUP_SIZE, (uint)QUANTIZE_GROUP_BLOCKS_PER_TOKEN);
}
#endif

// const uint input_offset = offset * INPUT_ELEMENTS_COUNT;
// const uint quantize_block = INPUT_ELEMENTS_COUNT / 4;
@@ -70,36 +73,45 @@ KERNEL(quantize_input)(
// printf("\n");
// }

INPUT0_TYPE max_value = 0.001;
float max_value = 0.01f;
for (uint i = 0 ; i < quantize_block ; i+=8) {
INPUT0_TYPE temp = fmax(fmax(fmax(max[i], max[i+1]), fmax(max[i+2], max[i+3])),
fmax(fmax(max[i+4], max[i+5]), fmax(max[i+6], max[i+7])));
// if (get_global_id(0) == 0 && get_global_id(2) == 0)
// printf(" (%.3f)", temp);

max_value = fmax(max_value, temp);
//#if COMPRESSED_WEIGHTS_INT8
// if (max_value < 1)
// printf("(%.3f,%.3f)", max_value, (float)temp);
//#endif
max_value = fmax(max_value, (float)temp);
}

// if (get_global_id(0) == 0 && get_global_id(2) == 0) {
// printf("\n");
// }
// if (max_value < 1)
// printf("X(%.3f)\n", max_value);

half quan_scale = (half)max_value / 127;
float quan_scale = max_value / 127.f;
#if COMPRESSED_WEIGHTS_INT8
#if PER_TOKEN_QUANTIZE_SIZE
int quantized_sum[QUANTIZE_GROUP_BLOCKS_PER_TOKEN] = { 0 }; // 1024 / 32 = 32
if (get_global_id(0) == 0 && get_global_id(2) == 0) {
printf("\n");
if (max_value < 0.1f) {
printf("(%u:%.3f:%.6f)", get_global_id(0), max_value, quan_scale);
}
#else
printf("Y");
int quantized_sum = 0;
#endif
#endif

// Store quantized input
// float de_quan_scale = 1.f / quan_scale;
for (uint i = 0 ; i < quantize_block ; ++i) {
half4 buff = input_0[i] / (half4)quan_scale;
// float4 buff = { 0, 0, 0, 0 };
// if (quan_scale != 0)
// buff = (convert_float4)(input_0[i]) / quan_scale;
float4 buff = (convert_float4)(input_0[i]) / (float4)quan_scale;
quantized_value[i] = CAT(CAT(convert_, MAKE_VECTOR_TYPE(DQ_TYPE, INPUT_LOAD_SIZE)), _rte)(buff);

#if COMPRESSED_WEIGHTS_INT8
#if PER_TOKEN_QUANTIZE_SIZE
uint index = quantize_block / QUANTIZE_GROUP_BLOCKS_PER_TOKEN;
@@ -112,6 +124,7 @@ KERNEL(quantize_input)(
vstore4(quantized_value[i], 0, &quantized_input[input_offset + i * 4]);
}

#if 0
// Store quantizing scale and activation sum(only if int8 asym)
// [TEST]
// if (get_global_id(0) < 8 && get_global_id(2) == 0) {
@@ -137,6 +150,8 @@ KERNEL(quantize_input)(
quan_var[(offset * 2) + 1] = CAT(CAT(convert_, float), _rte)(quantized_sum);
#endif
#endif

#endif
}
#else // !FC_KERNEL_DYNAMIC_QUANTIZE

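For context, the kernel changes above move the per-token dynamic quantization path to float arithmetic: each token's activations are reduced to a maximum value (floored at 0.01f), a scale of max_value / 127 is derived, the values are rounded to int8, and, on the INT8 asymmetric path, a per-group activation sum is accumulated for the post-operation. Below is a minimal host-side C++ sketch of that scheme, for illustration only; quantize_token and its types are not part of the patch, and the real kernel computes the per-group maxima with vectorized loads not shown here.

#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative scalar reference for the quantize_input kernel above (not part of the patch).
struct QuantizedToken {
    std::vector<int8_t> data;     // quantized activations (DQ_TYPE is char)
    float scale;                  // quan_scale = max_value / 127
    std::vector<int> group_sums;  // per-group activation sums (INT8 asym path)
};

QuantizedToken quantize_token(const std::vector<float>& token, size_t group_size) {
    float max_value = 0.01f;  // same floor as the kernel
    for (float v : token)
        max_value = std::fmax(max_value, std::fabs(v));

    QuantizedToken out;
    out.scale = max_value / 127.f;
    out.data.resize(token.size());
    out.group_sums.assign(token.size() / group_size, 0);

    for (size_t i = 0; i < token.size(); ++i) {
        int q = static_cast<int>(std::lround(token[i] / out.scale));
        out.data[i] = static_cast<int8_t>(q);
        out.group_sums[i / group_size] += q;  // activation sum used by the post-op
    }
    return out;
}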
@@ -73,7 +73,7 @@ static bool is_weight_dyn_quantizable(const fully_connected_params& params) {
static bool is_per_token_dynamic_quantize(const fully_connected_params& params) {
auto dynamic_quantization_group_size = params.dynamic_quantization_group_size;
if (dynamic_quantization_group_size == UINT64_MAX) {
// std::cout << "Set FC dynamic Quantize group size to per-token" << std::endl;
//std::cout << "Set FC dynamic Quantize group size to per-token" << std::endl;
return true;
}

@@ -107,14 +107,17 @@ static size_t get_dynamic_quantize_group_size(const fully_connected_params& para
const size_t scale_group_size = params.weights.IFM().v / params.decompression_scale.Feature().v;
// Per-token dyn-quan
if (dynamic_quantization_group_size != 0 && is_per_token_dynamic_quantize(params)) {
if (is_dyn_quan_8bit_asym(params)) {
// Should calculate activation sum by scale_group_size for post-operation
dynamic_quantization_group_size = scale_group_size;
} else {
// dynamic_quantization_group_size = get_input_bf_size(params).second;
dynamic_quantization_group_size = scale_group_size;
}

// if (is_dyn_quan_8bit_asym(params)) {
// // Should calculate activation sum by scale_group_size for post-operation
// dynamic_quantization_group_size = scale_group_size;
// // printf("!!!! per token dyn-quan(%s) : scale_group_size(%u) input_f(%d) get_input_bf_size(params).second(%u)\n",
// // ((is_dyn_quan_8bit_asym(params) == true) ? "Y" : "N"),
// // scale_group_size, (int)get_input_bf_size(params).second, dynamic_quantization_group_size);
// } else {
// dynamic_quantization_group_size = get_input_bf_size(params).second;
// }

dynamic_quantization_group_size = scale_group_size;
return (size_t)dynamic_quantization_group_size;
}

@@ -701,6 +704,9 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
if (is_per_token_dynamic_quantize(params)) {
jit.AddConstant(MakeJitConstant("PER_TOKEN_QUANTIZE_SIZE", 1));
jit.AddConstant(MakeJitConstant("QUANTIZE_GROUP_BLOCKS_PER_TOKEN", (get_input_bf_size(params).second / quantize_grp_size)));
// [TEST]
if ((get_input_bf_size(params).second / quantize_grp_size) < 1)
std::cout << " -- QUANTIZE_GROUP_BLOCKS_PER_TOKEN : " << (get_input_bf_size(params).second / quantize_grp_size) << std::endl;
}
} else {
if (add_decompress_scale_post_op)
@@ -710,6 +716,10 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
}
jit.AddConstant(MakeJitConstant("DQ_TYPE", "char"));

// if (is_dyn_quan_8bit_asym(params)) {
// std::cout << ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> INT8 ASYM [" << params.layerID << "]" << std::endl;
// }

jit.AddConstant(MakeJitConstant("IFM_SIZE", get_input_bf_size(params).second));
jit.AddConstant(MakeJitConstant("SIMD", simd));
jit.AddConstant(MakeJitConstant("TILE_B", dispatchData.tile_m));
@@ -729,10 +739,8 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
// // For decompression post operation, scale group size and dynamic quantizing group size should fit to each other.
// const size_t post_ops_size = (scale_group_size < quantize_grp_size) ? scale_group_size : quantize_grp_size;
jit.AddConstant(MakeJitConstant("NUM_LOOP_IN_DYN_QUAN_GROUP", quantize_grp_size / (dispatchData.tile_mk * simd)));
printf(" -- NUM_LOOP_IN_DYN_QUAN_GROUP(%d)\n", (int)(quantize_grp_size / (dispatchData.tile_mk * simd)));
} else {
jit.AddConstant(MakeJitConstant("NUM_LOOP_IN_DYN_QUAN_GROUP", 1));
printf(" -- NUM_LOOP_IN_DYN_QUAN_GROUP(%d)\n", 1);
}

auto max_tile_b_size = dispatchData.tile_m;
@@ -844,24 +852,30 @@ void FullyConnected_bf_tiled::GetUpdateDispatchDataFunc(KernelData& kd) const {
kd.kernels[0].skip_execution = false;
size_t input_f = get_input_bf_size(prim_params).second;
size_t input_size = input_f * dispatchData.tile_m * dispatchData.gws[2];
printf(">> Update Kernel!!!!! : input_size(%u), input_f(%u) quantize_grp_size(%u)\n", input_size, input_f, quantize_grp_size);

if (kd.internalBufferSizes[0] < input_size) {
if (kd.internalBufferSizes[0] < input_size || true) {
kd.internalBufferSizes.clear();
// quantized input is char type
kd.internalBufferSizes.push_back(input_size);
// half type of de_quan_scale and activation sum for each quantized group
// [TEST]
// kd.internalBufferSizes.push_back((input_size / quantize_grp_size) * 2 * 2);
kd.internalBufferSizes.push_back((input_size / quantize_grp_size) * 2 * 4);
printf(" ----------------- updated buffer\n");
}

if (is_per_token_dynamic_quantize(prim_params)) {
// Group size fit to the whole ifm size of each token
kd.kernels[0].params.workGroups.global = {std::max((input_size / input_f), (size_t)1), 1, 1};
} else {
kd.kernels[0].params.workGroups.global = {std::max((input_size / quantize_grp_size), (size_t)1), 1, 1};
}
// if (is_per_token_dynamic_quantize(prim_params)) {
// // Group size fit to the whole ifm size of each token
// kd.kernels[0].params.workGroups.global = {std::max((input_size / input_f), (size_t)1), 1, 1};
// } else {
// kd.kernels[0].params.workGroups.global = {std::max((input_size / quantize_grp_size), (size_t)1), 1, 1};
// }
kd.kernels[0].params.workGroups.global = {std::max((input_size / quantize_grp_size), (size_t)1), 1, 1};

// [TEST]
kd.kernels[0].params.workGroups.local = {16, 1, 1};
// kd.kernels[0].params.workGroups.local = {1, 1, 1};
}
}
};
@@ -904,7 +918,7 @@ KernelsData FullyConnected_bf_tiled::GetTunedKernelsDataByIndex(const Params &pa
}

KernelsData kernels_data;
if (should_dynamic_quantize(fc_params, true)) {
if (should_dynamic_quantize(fc_params, false)) {
// Use 2 separate kernels for dynamic quantizing : quantizing_kernel + fc_kernel
// 1st kernel : Dynamic quantizing by dynamic_quantize_grp_size
// 2nd kernel : fully connected kernel with KernelType::DEFAULT. Quantized inputs and scale values could be used.
@@ -1035,12 +1049,15 @@ KernelsData FullyConnected_bf_tiled::GetMultiKernelsData(const Params &params,
input_size = std::max(input_size, Align(get_input_bf_size(fc_params).first, lws_batches) * get_input_bf_size(fc_params).second);

// [TEST]
if (is_per_token_dynamic_quantize(fc_params))
dyn_quan_dispatch.gws = {input_size / get_input_bf_size(fc_params).second, 1, 1};
else
dyn_quan_dispatch.gws = {input_size / quantize_grp_size, 1, 1};
// if (is_per_token_dynamic_quantize(fc_params))
// dyn_quan_dispatch.gws = {input_size / get_input_bf_size(fc_params).second, 1, 1};
// else
// dyn_quan_dispatch.gws = {input_size / quantize_grp_size, 1, 1};
dyn_quan_dispatch.gws = {input_size / quantize_grp_size, 1, 1};

// [TEST]
dyn_quan_dispatch.lws = {16, 1, 1};
// dyn_quan_dispatch.lws = {1, 1, 1};
quan_kernel.params.workGroups.global = dyn_quan_dispatch.gws;
quan_kernel.params.workGroups.local = dyn_quan_dispatch.lws;
quan_kernel.skip_execution = false;
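As a rough reference for the sizing logic in GetUpdateDispatchDataFunc and GetMultiKernelsData above: the quantized-input buffer holds one char per input element, the quan_var buffer holds a float scale plus a float activation sum per quantization group, and the quantizing kernel's global size is one work-item per quantization group (with a local size of 16). The sketch below mirrors that arithmetic under those assumptions; the struct and helper names are illustrative, not the kernel-selector API.

#include <algorithm>
#include <cstddef>

// Illustrative only: mirrors the internal-buffer and dispatch sizing above.
struct DynQuanSizing {
    size_t quantized_input_bytes;  // one char per quantized element
    size_t quan_var_bytes;         // float scale + float activation sum per group
    size_t gws0;                   // one work-item per quantization group
};

DynQuanSizing compute_dyn_quan_sizing(size_t input_size, size_t quantize_grp_size) {
    DynQuanSizing s;
    s.quantized_input_bytes = input_size;                                     // char is 1 byte
    s.quan_var_bytes = (input_size / quantize_grp_size) * 2 * sizeof(float);  // matches "* 2 * 4"
    s.gws0 = std::max(input_size / quantize_grp_size, static_cast<size_t>(1));
    return s;
}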
