diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp b/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp
index 348bfea1970a67..92392276f3f128 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp
@@ -140,8 +140,8 @@ class debug_configuration {
     int disable_runtime_skip_reorder;   // Disable runtime skip reorder
     int disable_primitive_fusing;       // Disable primitive fusing
     int disable_fake_alignment;         // Disable fake alignment
-    int enable_dynamic_quantize;        // Enable Dynamic quantization for Fully-connected primitive
     std::vector<std::string> dynamic_quantize_layers_without_onednn; // Specify Fully-connected layers which enable Dynamic quantization
+    int dynamic_quantize_group_size;    // Enable Dynamic quantization for the Fully-connected primitive with the specified group size
     int disable_horizontal_fc_fusion;   // Disable fc horizontal fusion
     std::set<int64_t> dump_iteration;   // Dump n-th execution of network.
     std::vector<std::string> load_layers_raw_dump;  // List of layers to load dumped raw binary and filenames
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl
index 92be3f31f97f3f..f71b51dfe24423 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl
@@ -26,22 +26,27 @@ KERNEL(quantize_input)(
     __global INPUT0_TYPE* de_quan_scale) {
     const uint offset = get_global_id(0);
 
-    uint input_offset = offset * QUANTIZE_GROUP_SIZE;
-    half4 input_0[8];
-    char4 quantized_value[8];
-    half max[8];
+    const uint input_offset = offset * QUANTIZE_GROUP_SIZE;
+    const uint quantize_block = QUANTIZE_GROUP_SIZE / 4;
+    half4 input_0[quantize_block];
+    char4 quantized_value[quantize_block];
+    half max[quantize_block];
 
-    unroll_for (uint i = 0 ; i < 8 ; ++i) {
+    unroll_for (uint i = 0 ; i < quantize_block ; ++i) {
         input_0[i] = vload4(0, &input[input_offset + i * 4]);
         max[i] = fmax(fmax(fabs(input_0[i][0]), fabs(input_0[i][1])), fmax(fabs(input_0[i][2]), fabs(input_0[i][3])));
     }
 
-    half max_value = fmax(fmax(fmax(max[0], max[1]), fmax(max[2], max[3])),
-                          fmax(fmax(max[4], max[5]), fmax(max[6], max[7])));
+    half max_value = 0.001;
+    for (uint i = 0 ; i < quantize_block; i += 8) {
+        half temp = fmax(fmax(fmax(max[i], max[i+1]), fmax(max[i+2], max[i+3])),
+                         fmax(fmax(max[i+4], max[i+5]), fmax(max[i+6], max[i+7])));
+        max_value = fmax(max_value, temp);
+    }
 
     half quan_scale = max_value / 128;
 
-    unroll_for (uint i = 0 ; i < 8 ; ++i) {
+    unroll_for (uint i = 0 ; i < quantize_block ; ++i) {
         quantized_value[i] = CAT(convert_, MAKE_VECTOR_TYPE(char, INPUT_LOAD_SIZE))(input_0[i] / (half4)quan_scale);
         vstore4(quantized_value[i], 0, &quantized_input[input_offset + i * 4]);
     }
@@ -715,7 +720,7 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
 #define PACKED_DQ_TYPE              int
 #define DQ_VEC_TYPE                 MAKE_VECTOR_TYPE(DQ_TYPE, TILE_IFM)
 #define DQ_SLM_FILTER_VEC           MAKE_VECTOR_TYPE(DQ_TYPE, 4)
-#define DQ_SLM_FILTER_PACKED_VEC    MAKE_VECTOR_TYPE(FILTER_TYPE, FILTER_LOAD_BLOCK_SIZE)
+#define DQ_SLM_FILTER_PACKED_VEC    MAKE_VECTOR_TYPE(FILTER_TYPE, FILTER_ACTUAL_LOAD_BLOCK_SIZE)
 #define DQ_SLM_FILTER_UNPACKED_VEC  MAKE_VECTOR_TYPE(DQ_TYPE, FILTER_ELEMENTS_PER_LOAD)
 #define DQ_FILTER_VEC_TYPE          MAKE_VECTOR_TYPE(DQ_TYPE, TILE_K_OFM)
@@ -820,11 +825,13 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
 
     // =====================================================================================================================================
     // Main computation loop
-    const uint iterations = MAIN_LOOP_ELEMENTS_COUNT / (TILE_IFM * SIMD);
+    const uint iterations = MAIN_LOOP_ELEMENTS_COUNT / TILE_IFM_ELEMENTS_SIZE;  // TILE_IFM_ELEMENTS_SIZE : (TILE_IFM * SIMD)
     // Each sub-group loads 2 Batch
-    uint idx_sglid = (sglid * TILE_K) % QUANTIZE_GROUP_SIZE;    // same index for sglid 0~7 : to tile_k direction
-    uint batch_sglid = (sglid * TILE_K) / QUANTIZE_GROUP_SIZE;  // 0 to 1 : to batch direction
+    uint idx_sglid = (sglid * TILE_K) % TILE_IFM_ELEMENTS_SIZE;    // same index for sglid 0~7 : to tile_k direction
+    uint batch_sglid = (sglid * TILE_K) / TILE_IFM_ELEMENTS_SIZE;  // 0 to 1 : to batch direction
+    const uint scale_pitch = TILE_IN_B_PITCH / QUANTIZE_GROUP_SIZE;
+    MAKE_VECTOR_TYPE(int, TILE_B) acc_tmp[TILE_OFM] = { };
 
     __attribute__((opencl_unroll_hint(1)))
     for (uint ni = 0; ni < iterations; ++ni) {
         uint in_offset = input_offset + (idx_sglid + batch_sglid * TILE_IN_B_PITCH);
@@ -832,40 +839,59 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
         for (uint bi = 0; bi < HALF_TILE_B; ++bi) {
             // Load quantizing info from pre-quantizing kernel
             tiled_input_0[bi] = vload4(0, &quantized_input[in_offset]);
-            de_quantize_scale[bi * 2] = scale[scale_offset];
-            de_quantize_scale[bi * 2 + 1] = scale[scale_offset+ (TILE_IN_B_PITCH/QUANTIZE_GROUP_SIZE)];
-
             // Packing : Get 4(B)x4(K) integer vector (packing to 4x1 vector)
             packed_in_0[bi] = as_int(tiled_input_0[bi]);
 
             // Next batch
             in_offset += (TILE_IN_B_PITCH * 2);
-            scale_offset += (TILE_IN_B_PITCH/QUANTIZE_GROUP_SIZE * 2);
+
+            #if NUM_LOOP_IN_DYN_QUAN_GROUP == 1
+                de_quantize_scale[bi * 2] = scale[scale_offset];
+                de_quantize_scale[bi * 2 + 1] = scale[scale_offset + scale_pitch];
+                scale_offset += (scale_pitch * 2);
+            #endif
         }
 
-        input_offset += TILE_IFM * SIMD;
+        #if NUM_LOOP_IN_DYN_QUAN_GROUP > 1
+            if (ni % NUM_LOOP_IN_DYN_QUAN_GROUP == 0) {
+                unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
+                    de_quantize_scale[bi] = scale[scale_offset];
+                    scale_offset += scale_pitch;
+                }
+            }
+        #endif
 
-        // Packing
-        MAKE_VECTOR_TYPE(int, TILE_B) acc_tmp[TILE_OFM] = { };
+        input_offset += TILE_IFM_ELEMENTS_SIZE;
 
         #if TILE_OFM != 2
         #error "FC bf_tiled kernel: can't use SLM optimization with TILE_OFM != 2"
         #endif
+        #if FILTER_LAYOUT_OS_IYX_OSV16 && TILE_K != 4
+        #error "FC bf_tiled kernel: can't use SLM optimization with TILE_K != 4 && OS_IYX_OSV16 layout"
+        #endif
 
         // Skip first barrier synchronization if there is only single outer loop iteration.
-        #if MAIN_LOOP_ELEMENTS_COUNT / (TILE_IFM * SIMD) > 1
+        #if MAIN_LOOP_ELEMENTS_COUNT / TILE_IFM_ELEMENTS_SIZE > 1
             barrier(CLK_LOCAL_MEM_FENCE);
         #endif
 
        __local int* char_slm_weight = (__local int*)wei_local_mem;
 
-        uint weights_idx = weights_offset + local_id * SIMD * FILTER_LOAD_ITERS * FILTER_LOAD_BLOCK_SIZE;
+        uint weights_idx = weights_offset + local_id * SIMD * FILTER_LOAD_ITERS * FILTER_ACTUAL_LOAD_BLOCK_SIZE;
        uint wei_local_idx = local_id * SIMD * FILTER_LOAD_ITERS * (FILTER_LOAD_BLOCK_SIZE/2) + sglid * 2;
 
        // DECOMPRESSION_SCALE_POST_OP SHOULD be enabled for dynamic quantize FC : scale is ACCUMULATOR_VAL_ONE
        unroll_for(uint load_iter = 0; load_iter < FILTER_LOAD_ITERS; ++load_iter) {
-            SLM_FILTER_PACKED_VEC wei_packed = BLOCK_READN(FILTER_TYPE, FILTER_LOAD_BLOCK_SIZE, weights, weights_idx);
-            DQ_SLM_FILTER_UNPACKED_VEC dq_wei_unpacked = UNPACK_TRANSPOSED_INT4(DQ_TYPE, *((INT4_PACKED_TYPE_PRELOAD *)&wei_packed));
+            #if FILTER_LAYOUT_OS_IYX_OSV16
+                SLM_FILTER_PACKED_VEC wei_packed0 = BLOCK_READN(FILTER_TYPE, FILTER_ACTUAL_LOAD_BLOCK_SIZE, weights, weights_idx);
+                SLM_FILTER_PACKED_VEC wei_packed1 = BLOCK_READN(FILTER_TYPE, FILTER_ACTUAL_LOAD_BLOCK_SIZE, weights, (weights_idx + ((IFM_SIZE / 2) * 16)));
+                DQ_SLM_FILTER_UNPACKED_VEC dq_wei_unpacked;
+                dq_wei_unpacked.s0123 = UNPACK_TRANSPOSED_INT4(DQ_TYPE, *((INT4_PACKED_TYPE_PRELOAD*)&wei_packed0));
+                dq_wei_unpacked.s4567 = UNPACK_TRANSPOSED_INT4(DQ_TYPE, *((INT4_PACKED_TYPE_PRELOAD*)&wei_packed1));
+            #else
+                SLM_FILTER_PACKED_VEC wei_packed = BLOCK_READN(FILTER_TYPE, FILTER_LOAD_BLOCK_SIZE, weights, weights_idx);
+                DQ_SLM_FILTER_UNPACKED_VEC dq_wei_unpacked = UNPACK_TRANSPOSED_INT4(DQ_TYPE, *((INT4_PACKED_TYPE_PRELOAD *)&wei_packed));
+            #endif
 
            // Calculate zero-point and scale only for DECOMPRESSION_SCALE_POST_OP enabled
            #if DECOMPRESSION_ZP_TERM
@@ -914,14 +940,14 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
            #endif
 
            wei_local_idx += SIMD * (FILTER_LOAD_BLOCK_SIZE/2);
-            weights_idx += SIMD * FILTER_LOAD_BLOCK_SIZE;
+            weights_idx += SIMD * FILTER_ACTUAL_LOAD_BLOCK_SIZE;
        }
 
        wei_local_idx = sglid * 2;
 
        barrier(CLK_LOCAL_MEM_FENCE);
 
-        unroll_for(uint ki = 0; ki < (TILE_IFM * SIMD) / TILE_K; ++ki) {
+        unroll_for(uint ki = 0; ki < TILE_IFM_ELEMENTS_SIZE / TILE_K; ++ki) {
            #if TILE_K != 4
                #error "FC bf_tiled kernel: unsupported TILE_K size for SLM kernel"
            #endif
@@ -936,9 +962,13 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
                acc_tmp[1][bi] = imad_SW(acc_tmp[1][bi], input_val, second_weight);
            }
 
-            weights_offset += TILE_K_OFM_PACKED * SIMD;
+            #if FILTER_LAYOUT_OS_IYX_OSV16 && TILE_OFM == 2
+                weights_offset += (TILE_K_OFM_PACKED/2) * SIMD;
+            #else
+                weights_offset += TILE_K_OFM_PACKED * SIMD;
+            #endif
 
-            #if DECOMPRESSION_SCALE_POST_OP && (TILE_IFM * SIMD > DECOMPRESSION_SCALE_GROUP_SIZE)
+            #if DECOMPRESSION_SCALE_POST_OP && (TILE_IFM_ELEMENTS_SIZE > DECOMPRESSION_SCALE_GROUP_SIZE)
                unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
                    unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
                        const uint offset_ofm = out_f + fi*SIMD + sglid;
@@ -958,20 +988,24 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
                #endif
            }
        }  // Whole tile_k elements of each iteration : ki
 
-        #if DECOMPRESSION_SCALE_POST_OP && (TILE_IFM * SIMD <= DECOMPRESSION_SCALE_GROUP_SIZE)
-            const uint ni_offset = ((ni*TILE_IFM*SIMD) / DECOMPRESSION_SCALE_GROUP_SIZE)*DECOMPRESSION_SCALE_FEATURE_PITCH;
-            unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
-                unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
-                    const uint offset_ofm = out_f + fi*SIMD + sglid;
+        #if DECOMPRESSION_SCALE_POST_OP && (TILE_IFM_ELEMENTS_SIZE <= DECOMPRESSION_SCALE_GROUP_SIZE)
+            // Dynamic-quantizing group size set to same or smaller than scale group size
+            if ((ni % NUM_LOOP_IN_DYN_QUAN_GROUP) == (NUM_LOOP_IN_DYN_QUAN_GROUP - 1)) {
+                const uint ni_offset = ((ni*TILE_IFM*SIMD) / DECOMPRESSION_SCALE_GROUP_SIZE)*DECOMPRESSION_SCALE_FEATURE_PITCH;
+                unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
+                    unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
+                        const uint offset_ofm = out_f + fi*SIMD + sglid;
 
-                    #if DECOMPRESSION_SCALE_GROUPS_NUM > 1
-                        const uint scale_offset = (offset_ofm % DECOMPRESSION_SCALE_BATCH_NUM) * DECOMPRESSION_SCALE_BATCH_PITCH + ni_offset;
-                        ACCUMULATOR_TYPE ds = decompression_scale[scale_offset];
-                    #else
-                        ACCUMULATOR_TYPE ds = d_scales[fi % DECOMPRESSION_SCALE_LENGTH];
-                    #endif
+                        #if DECOMPRESSION_SCALE_GROUPS_NUM > 1
+                            const uint scale_offset = (offset_ofm % DECOMPRESSION_SCALE_BATCH_NUM) * DECOMPRESSION_SCALE_BATCH_PITCH + ni_offset;
+                            ACCUMULATOR_TYPE ds = decompression_scale[scale_offset];
+                        #else
+                            ACCUMULATOR_TYPE ds = d_scales[fi % DECOMPRESSION_SCALE_LENGTH];
+                        #endif
 
-                    ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += convert_half(((int *)(&acc_tmp[fi]))[bi]) * ds * de_quantize_scale[bi];
+                        ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += convert_half(((int *)(&acc_tmp[fi]))[bi]) * ds * de_quantize_scale[bi];
+                        acc_tmp[fi][bi] = 0;
+                    }
                }
            }
        #endif
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/int4_utils.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/int4_utils.cl
index d919d1ce1104ab..68d778475f5601 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/int4_utils.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/int4_utils.cl
@@ -38,14 +38,6 @@ inline uchar2 unpack_to_uchar(uint4x2_t v) __attribute__((overloadable)) {
     return cvt_uint4x2_to_uint8x2(v);
 }
 
-inline uchar8 unpack_to_uchar(uint4x8_t v) __attribute__((overloadable)) {
-    uchar2 v0 = unpack_to_uchar(v.s0);
-    uchar2 v1 = unpack_to_uchar(v.s1);
-    uchar2 v2 = unpack_to_uchar(v.s2);
-    uchar2 v3 = unpack_to_uchar(v.s3);
-    return (uchar8)(v0.s0, v0.s1, v1.s0, v1.s1, v2.s0, v2.s1, v3.s0, v3.s1);
-}
-
 inline char2 unpack_to_char(int4x2_t v) __attribute__((overloadable)) {
     return cvt_int4x2_to_int8x2(v);
 }
@@ -54,12 +46,47 @@ inline char2 unpack_to_char(uint4x2_t v) __attribute__((overloadable)) {
     return convert_char2(cvt_uint4x2_to_uint8x2(v));
 }
 
+// 4bit x 4
 inline char4 unpack_to_char(int4x4_t v) __attribute__((overloadable)) {
     char2 v0 = unpack_to_char(v.s0);
     char2 v1 = unpack_to_char(v.s1);
     return (char4)(v0.s0, v0.s1, v1.s0, v1.s1);
 }
 
+inline char4 unpack_to_char(uint4x4_t v) __attribute__((overloadable)) {
+    char2 v0 = unpack_to_char(v.s0);
+    char2 v1 = unpack_to_char(v.s1);
+    return (char4)(v0.s0, v0.s1, v1.s0, v1.s1);
+}
+
+inline char4 unpack_transposed_to_char(int4x4_t v) __attribute__((overloadable)) {
+    char2 v0 = unpack_to_char(v.s0);
+    char2 v1 = unpack_to_char(v.s1);
+    return (char4)(v0.s0, v1.s0, v0.s1, v1.s1);
+}
+
+inline char4 unpack_transposed_to_char(uint4x4_t v) __attribute__((overloadable)) {
+    char2 v0 = unpack_to_char(v.s0);
+    char2 v1 = unpack_to_char(v.s1);
+    return (char4)(v0.s0, v1.s0, v0.s1, v1.s1);
+}
+
+inline uchar4 unpack_transposed_to_uchar(uint4x4_t v) __attribute__((overloadable)) {
+    uchar2 v0 = unpack_to_uchar(v.s0);
+    uchar2 v1 = unpack_to_uchar(v.s1);
+    return (uchar4)(v0.s0, v1.s0, v0.s1, v1.s1);
+}
+
+
+// 4bit x 8
+inline uchar8 unpack_to_uchar(uint4x8_t v) __attribute__((overloadable)) {
+    uchar2 v0 = unpack_to_uchar(v.s0);
+    uchar2 v1 = unpack_to_uchar(v.s1);
+    uchar2 v2 = unpack_to_uchar(v.s2);
+    uchar2 v3 = unpack_to_uchar(v.s3);
+    return (uchar8)(v0.s0, v0.s1, v1.s0, v1.s1, v2.s0, v2.s1, v3.s0, v3.s1);
+}
+
 inline char8 unpack_to_char(int4x8_t v) __attribute__((overloadable)) {
     char2 v0 = unpack_to_char(v.s0);
     char2 v1 = unpack_to_char(v.s1);
@@ -68,6 +95,14 @@ inline char8 unpack_to_char(int4x8_t v) __attribute__((overloadable)) {
     return (char8)(v0.s0, v0.s1, v1.s0, v1.s1, v2.s0, v2.s1, v3.s0, v3.s1);
 }
 
+inline char8 unpack_to_char(uint4x8_t v) __attribute__((overloadable)) {
+    char2 v0 = unpack_to_char(v.s0);
+    char2 v1 = unpack_to_char(v.s1);
+    char2 v2 = unpack_to_char(v.s2);
+    char2 v3 = unpack_to_char(v.s3);
+    return (char8)(v0.s0, v0.s1, v1.s0, v1.s1, v2.s0, v2.s1, v3.s0, v3.s1);
+}
+
 inline char8 unpack_transposed_to_char(int4x8_t v) __attribute__((overloadable)) {
     char2 v0 = unpack_to_char(v.s0);
     char2 v1 = unpack_to_char(v.s1);
@@ -92,6 +127,7 @@ inline uchar8 unpack_transposed_to_uchar(uint4x8_t v) __attribute__((overloadabl
     return (uchar8)(v0.s0, v1.s0, v2.s0, v3.s0, v0.s1, v1.s1, v2.s1, v3.s1);
 }
 
+// For float
 inline float2 unpack_to_float(uint4x2_t v) __attribute__((overloadable)) {
     return convert_float2(cvt_uint4x2_to_uint8x2(v));
 }
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
index db42bc969b4a32..7a3ea70f37d366 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
@@ -9,8 +9,9 @@
 #include "common_types.h"
 
 static constexpr size_t simd = 16;
-static constexpr size_t quantize_grp_size = 32;
+static constexpr size_t min_quantize_grp_size = 32;
 static constexpr size_t min_slm_size = 256;
+static std::vector<size_t> available_quantize_grp_size = {128, 64, 32};
 
 namespace kernel_selector {
@@ -50,13 +51,14 @@ static std::pair<size_t, size_t> get_output_aligned_bf_size(const fully_connecte
 }
 
 // DYNAMIC_QUANTIZE
-static bool should_dynamic_quantize(const fully_connected_params& params) {
+static size_t get_dynamic_quantize_group_size(const fully_connected_params& params) {
     auto dynamic_quantization_group_size = params.dynamic_quantization_group_size;
 
     GPU_DEBUG_GET_INSTANCE(debug_config);
-    GPU_DEBUG_IF(debug_config->enable_dynamic_quantize) {
-        dynamic_quantization_group_size = quantize_grp_size;
+    GPU_DEBUG_IF(debug_config->dynamic_quantize_group_size) {
+        dynamic_quantization_group_size = debug_config->dynamic_quantize_group_size;
 
+        // Specify which Fully-connected layers should be dynamic-quantized
         GPU_DEBUG_IF(!debug_config->dynamic_quantize_layers_without_onednn.empty()) {
             auto layers = debug_config->dynamic_quantize_layers_without_onednn;
             auto iter = std::find_if(layers.begin(), layers.end(), [&](const std::string& pattern){
@@ -64,7 +66,7 @@ static bool should_dynamic_quantize(const fully_connected_params& params) {
             });
 
             if (iter != layers.end()) {
-                dynamic_quantization_group_size = quantize_grp_size;
+                dynamic_quantization_group_size = debug_config->dynamic_quantize_group_size;
                 GPU_DEBUG_COUT << "Found specified Fully-connected layer [" << params.layerID << "]. Enable Dynamic-quantize."
                                << std::endl;
             } else {
                 dynamic_quantization_group_size = 0;
@@ -72,18 +74,41 @@ static bool should_dynamic_quantize(const fully_connected_params& params) {
         }
     }
 
+    const size_t scale_group_size = params.weights.IFM().v / params.decompression_scale.Feature().v;
+    for (auto group_size : available_quantize_grp_size) {
+        if (dynamic_quantization_group_size >= group_size) {
+            dynamic_quantization_group_size = group_size;
+
+            if (dynamic_quantization_group_size > scale_group_size) {
+                GPU_DEBUG_TRACE_DETAIL << " Scale group size " << scale_group_size << " is smaller than FC dyn-quan group size "
+                                       << dynamic_quantization_group_size << ". Reduce FC dyn-quan group size to scale size." << std::endl;
+                dynamic_quantization_group_size = scale_group_size;
+            }
+            return (size_t)dynamic_quantization_group_size;
+        }
+    }
+
+    return 0;
+}
+
+static bool should_dynamic_quantize(const fully_connected_params& params) {
+    size_t dynamic_quantization_group_size = get_dynamic_quantize_group_size(params);
+
     if (params.inputs[0].GetFirstElementOffset() != 0)
         return false;
 
-    if (dynamic_quantization_group_size < quantize_grp_size)
-        return false;
+    if (dynamic_quantization_group_size < min_quantize_grp_size) {
+        GPU_DEBUG_TRACE_DETAIL << "Requested dynamic_quantize_group_size " << dynamic_quantization_group_size
+                               << " is smaller than the minimum supported size 32" << std::endl;
+        return false;
+    }
 
     auto threads = get_input_bf_size(params);
     auto input_b = threads.first;
     auto input_f = threads.second;
 
     const size_t scale_group_size = params.weights.IFM().v / params.decompression_scale.Feature().v;
-    if ((scale_group_size % simd == 0) && (input_f % quantize_grp_size == 0) &&
+    if ((scale_group_size % simd == 0) && (input_f % dynamic_quantization_group_size == 0) &&
         (params.is_shape_agnostic || (params.inputs[0].Batch().v > 1 && input_b > min_slm_size)) &&
         params.inputs[0].GetDType() == Datatype::F16 &&
         (params.weights.GetDType() == WeightsType::INT4 || params.weights.GetDType() == WeightsType::UINT4) &&
@@ -487,6 +512,7 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
     JitConstants jit = Parent::GetJitConstants(params, dispatchData);
     size_t tile_k_ofm = dispatchData.tile_nk * dispatchData.tile_n;
     size_t tile_k_ofm_packed = tile_k_ofm;
+    size_t quantize_grp_size = get_dynamic_quantize_group_size(params);
 
     WeightsType weights_dt = params.weights.GetDType();
     if (weights_dt == WeightsType::UINT4 || weights_dt == WeightsType::INT4) {
@@ -557,6 +583,7 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
         jit.AddConstant(MakeJitConstant("QUANTIZE_GROUP_SIZE", quantize_grp_size));
     } else {
         jit.AddConstant(MakeJitConstant("DYNAMIC_QUANTIZE", 0));
+        jit.AddConstant(MakeJitConstant("QUANTIZE_GROUP_SIZE", -1));
     }
 
     jit.AddConstant(MakeJitConstant("IFM_SIZE", get_input_bf_size(params).second));
@@ -570,6 +597,13 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
     jit.AddConstant(MakeJitConstant("TILE_K_OFM_PACKED", tile_k_ofm_packed));
     jit.AddConstant(MakeJitConstant("DISPATCH_BSV", dispatchData.tile_ms));
     jit.AddConstant(MakeJitConstant("DISPATCH_FSV", dispatchData.tile_ns));
+    jit.AddConstant(MakeJitConstant("TILE_IFM_ELEMENTS_SIZE", (dispatchData.tile_mk * simd)));
+
+    if (quantize_grp_size / (dispatchData.tile_mk * simd) > 1 && quantize_grp_size % (dispatchData.tile_mk * simd) == 0) {
+        jit.AddConstant(MakeJitConstant("NUM_LOOP_IN_DYN_QUAN_GROUP", quantize_grp_size / (dispatchData.tile_mk * simd)));
+    } else {
+        jit.AddConstant(MakeJitConstant("NUM_LOOP_IN_DYN_QUAN_GROUP", 1));
+    }
 
     auto max_tile_b_size = dispatchData.tile_m;
     if (params.compressed &&
@@ -639,6 +673,7 @@ void FullyConnected_bf_tiled::GetUpdateDispatchDataFunc(KernelData& kd) const {
     kd.update_dispatch_data_func = [this](const Params& params, KernelData& kd) {
         const auto& prim_params = static_cast<const fully_connected_params&>(params);
+        size_t quantize_grp_size = get_dynamic_quantize_group_size(prim_params);
 
         size_t output_batch = get_output_aligned_bf_size(prim_params, false).first;
 
         // Get index of the added shape-agnostic kernel
@@ -728,7 +763,7 @@ KernelsData FullyConnected_bf_tiled::GetTunedKernelsDataByIndex(const Params &pa
     KernelsData kernels_data;
     if (should_dynamic_quantize(fc_params)) {
         // Use seperate 2 kernels for dynamic quantizing : quantizing_kernel + fc_kernel
-        // 1st kernel : Dynamic quantizing by quantize_grp_size
+        // 1st kernel : Dynamic quantizing by dynamic_quantize_grp_size
         // 2nd kernel : fully connected kernel with KernelType::DEFAULT. Quantized inputs and scale values could be used.
         // 3rd kernel : (optional) fully connected shape_agnostic kernel with KernelType::SLM. Quantized inputs and scale values would be used.
         kernels_data = GetMultiKernelsData(params,
@@ -813,6 +848,8 @@ KernelsData FullyConnected_bf_tiled::GetMultiKernelsData(const Params &params,
 
     const auto& fc_params = static_cast<const fully_connected_params&>(params);
 
+    size_t quantize_grp_size = get_dynamic_quantize_group_size(fc_params);
+
     bool bProperInput = fc_params.inputs[0].GetLayout() == dl;
     if (!bProperInput && !fc_params.inputs[0].PitchesDifferFromLogicalDims()) {
         bProperInput = (dl == DataLayout::fb && fc_params.inputs[0].GetLayout() == DataLayout::fyxb) ||
diff --git a/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp b/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp
index f0a31f81c2e2bd..f85295593416c8 100644
--- a/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp
+++ b/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp
@@ -181,9 +181,10 @@ static void print_help_messages() {
     message_list.emplace_back("OV_GPU_DisableRuntimeSkipReorder", "Disable runtime skip reorder.");
     message_list.emplace_back("OV_GPU_DisablePrimitiveFusing", "Disable primitive fusing");
     message_list.emplace_back("OV_GPU_DisableFakeAlignment", "Disable fake alignment");
-    message_list.emplace_back("OV_GPU_EnableDynamicQuantize", "Enable Dynamic quantization for Fully connected primitive");
     message_list.emplace_back("OV_GPU_DynamicQuantizeLayersWithoutOnednn", "Enable Dynamic quantization for specified Fully connected layers only, "
                               "separated by space. Support case-insensitive and regular expression. For example .*fully_connected.*");
+    message_list.emplace_back("OV_GPU_DynamicQuantizeGroupSize", "Specify the dynamic quantization group size to enable "
+                              "dynamic quantization for the Fully-connected primitive.");
     message_list.emplace_back("OV_GPU_DisableHorizontalFCFusion", "Disable horizontal fc fusion");
     message_list.emplace_back("OV_GPU_DumpIteration", "Dump n-th execution of network, separated by space.");
     message_list.emplace_back("OV_GPU_MemPreallocationOptions", "Controls buffer pre-allocation feature. Expects 4 values separated by space in "
@@ -250,7 +251,7 @@ debug_configuration::debug_configuration()
     , disable_runtime_skip_reorder(0)
     , disable_primitive_fusing(0)
     , disable_fake_alignment(0)
-    , enable_dynamic_quantize(0)
+    , dynamic_quantize_group_size(0)
     , disable_horizontal_fc_fusion(0) {
 #ifdef GPU_DEBUG_CONFIG
     get_gpu_debug_env_var("Help", help);
@@ -302,7 +303,7 @@ debug_configuration::debug_configuration()
     get_gpu_debug_env_var("DisableRuntimeSkipReorder", disable_runtime_skip_reorder);
     get_gpu_debug_env_var("DisablePrimitiveFusing", disable_primitive_fusing);
     get_gpu_debug_env_var("DisableFakeAlignment", disable_fake_alignment);
-    get_gpu_debug_env_var("EnableDynamicQuantize", enable_dynamic_quantize);
+    get_gpu_debug_env_var("DynamicQuantizeGroupSize", dynamic_quantize_group_size);
     get_gpu_debug_env_var("DisableHorizontalFCFusion", disable_horizontal_fc_fusion);
     std::string dump_iteration_str;
     get_gpu_debug_env_var("DumpIteration", dump_iteration_str);
diff --git a/src/plugins/intel_gpu/src/runtime/execution_config.cpp b/src/plugins/intel_gpu/src/runtime/execution_config.cpp
index bb6e15707c39cd..a498dad24aa2f5 100644
--- a/src/plugins/intel_gpu/src/runtime/execution_config.cpp
+++ b/src/plugins/intel_gpu/src/runtime/execution_config.cpp
@@ -203,8 +203,11 @@ void ExecutionConfig::apply_debug_options(const cldnn::device_info& info) {
         set_property(ov::intel_gpu::use_only_static_kernels_for_dynamic_shape(true));
     }
 
-    GPU_DEBUG_IF(debug_config->enable_dynamic_quantize) {
-        set_property(ov::hint::dynamic_quantization_group_size(UINT64_MAX));
+    GPU_DEBUG_IF(debug_config->dynamic_quantize_group_size) {
+        if (debug_config->dynamic_quantize_group_size == -1)
+            set_property(ov::hint::dynamic_quantization_group_size(UINT64_MAX));
+        else
+            set_property(ov::hint::dynamic_quantization_group_size(debug_config->dynamic_quantize_group_size));
     }
 }
diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp
index 7a9519ba48c05a..681a0ebbe9e05b 100644
--- a/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp
+++ b/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp
@@ -2537,7 +2537,8 @@ class fully_connected_gpu_tests: public ::testing::Test {
         ASSERT_EQ(3.0f, output_ptr[3]);
     }
 
-    void test_compressed_int4_scale_dyn_quan_weight_i4(bool is_dynamic, int batch = 1, int ifm = 512, int ofm = 2048) {
+    void test_compressed_int4_scale_dyn_quan_weight_i4(bool is_dynamic, int batch = 1, int ifm = 512, int ofm = 2048,
+                                                       int quantize_group_size = 32, int scales_group_size = 128) {
         tests::random_generator rg(GET_SUITE_NAME);
         auto& engine = get_test_engine();
 
@@ -2547,7 +2548,6 @@ class fully_connected_gpu_tests: public ::testing::Test {
         long int batch_num = batch;
         long int ifm_num = ifm;
         long int ofm_num = ofm;
-        int scales_group_size = 128;
 
         auto input_ps = ov::PartialShape{ batch_num, 1, ifm_num };
         auto input_mem = engine.allocate_memory({ input_ps, data_types::f16, format::bfyx });
@@ -2561,7 +2561,7 @@ class fully_connected_gpu_tests: public ::testing::Test {
         auto weigths_data = rg.generate_random_1d<uint8_t>(ofm_num * ifm_num / 2, 0, 4);
         set_values(weights_mem, weigths_data);
 
-        auto scale_data = rg.generate_random_1d<ov::float16>(ofm_num * ifm_num / scales_group_size, -4.f, 4.f);
+        auto scale_data = rg.generate_random_1d<ov::float16>(ofm_num * ifm_num / scales_group_size, -2.f, 2.f);
         set_values(scale_mem, scale_data);
 
         auto in_layout = is_dynamic ? layout{ ov::PartialShape{ -1, -1, -1 }, data_types::f16, format::bfyx }
@@ -2608,7 +2608,7 @@ class fully_connected_gpu_tests: public ::testing::Test {
         auto config = get_test_default_config(engine);
         config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
         config.set_property(ov::intel_gpu::optimize_data(true));
-        config.set_property(ov::hint::dynamic_quantization_group_size(32));
+        config.set_property(ov::hint::dynamic_quantization_group_size(quantize_group_size));
 
         network::ptr network = get_network(engine, topology, config, get_test_stream_ptr(), false);
@@ -2616,7 +2616,9 @@ class fully_connected_gpu_tests: public ::testing::Test {
             auto inst = network->get_primitive("fc_prim");
             auto impl = inst->get_impl();
             ASSERT_TRUE(impl != NULL);
-            ASSERT_EQ(impl->get_kernels().size(), size_t((is_dynamic ? 3 : 2)));
+            auto kernel_num = (is_dynamic) ? 3 : 2;
+            kernel_num = (quantize_group_size < 32) ? 2 : kernel_num;
+            ASSERT_EQ(impl->get_kernels().size(), size_t(kernel_num));
         }
 
         network->set_input_data("input", input_mem);
@@ -2640,10 +2642,10 @@ class fully_connected_gpu_tests: public ::testing::Test {
                 max_diff = abs_diff;
             avg += abs_diff;
             count++;
-            OPENVINO_ASSERT(abs_diff < 10);
+            OPENVINO_ASSERT(abs_diff < 5);
        }
        GPU_DEBUG_LOG << "---> count: " << count << ", max_diff:" << max_diff << ", avg_diff: " << (avg/count) << std::endl;
-        OPENVINO_ASSERT((avg/count) < 1);
+        OPENVINO_ASSERT((avg/count) < 0.5);
    }
 };
@@ -3666,6 +3668,35 @@ TEST_F(fully_connected_gpu_tests, compressed_int4_scale_dynamic_quantize_edge_ca
     this->test_compressed_int4_scale_dyn_quan_weight_i4(true, 359, 1536, 2560);
 }
 
+TEST_F(fully_connected_gpu_tests, compressed_int4_scale_dynamic_quantize_edge_case_12_groupsize) {
+    // Expect no dynamic-quantized FC
+    this->test_compressed_int4_scale_dyn_quan_weight_i4(true, 269, 512, 1024, 12);
+}
+
+TEST_F(fully_connected_gpu_tests, compressed_int4_scale_dynamic_quantize_edge_case_34_groupsize) {
+    this->test_compressed_int4_scale_dyn_quan_weight_i4(true, 359, 1536, 2560, 34);
+}
+
+TEST_F(fully_connected_gpu_tests, compressed_int4_scale_dynamic_quantize_edge_case_64_groupsize) {
+    this->test_compressed_int4_scale_dyn_quan_weight_i4(true, 359, 1536, 2560, 64);
+}
+
+TEST_F(fully_connected_gpu_tests, compressed_int4_scale_dynamic_quantize_edge_case_148_groupsize) {
+    this->test_compressed_int4_scale_dyn_quan_weight_i4(true, 359, 1536, 2560, 148);
+}
+
+TEST_F(fully_connected_gpu_tests, compressed_int4_scale_dynamic_quantize_edge_case_128_groupsize) {
+    this->test_compressed_int4_scale_dyn_quan_weight_i4(true, 359, 1536, 2560, 128);
+}
+
+TEST_F(fully_connected_gpu_tests, compressed_int4_scale_dynamic_quantize_edge_case_128_groupsize_32_scale) {
+    this->test_compressed_int4_scale_dyn_quan_weight_i4(true, 359, 1536, 2560, 128, 32);
+}
+
+TEST_F(fully_connected_gpu_tests, compressed_int4_scale_dynamic_quantize_edge_case_128_groupsize_64_scale) {
+    this->test_compressed_int4_scale_dyn_quan_weight_i4(true, 359, 1536, 2560, 128, 64);
+}
+
 TEST_F(fully_connected_gpu_tests, compressed_scale_bias) {
     this->test_compressed_scale_bias(false);
 }
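For reference, here is a minimal host-side sketch, illustrative only and not part of the patch, of the two mechanisms the changes above introduce. The helper names select_dyn_quan_group_size and quantize_activations are hypothetical stand-ins: the first mirrors get_dynamic_quantize_group_size(), which picks the largest supported group size (128, 64 or 32) that does not exceed the requested hint and clamps it to the weight-decompression scale group size; the second mirrors the quantize_input kernel, which quantizes activations per group with scale = max(|x|) / 128 and stores int8 values plus one scale per group for later de-quantization.

// Hypothetical C++ sketch of the selection and per-group quantization logic above.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Largest supported group size (128/64/32) not exceeding the requested hint,
// never larger than the weight-decompression scale group size.
static size_t select_dyn_quan_group_size(size_t requested, size_t scale_group_size) {
    static const size_t available[] = {128, 64, 32};
    for (size_t group_size : available) {
        if (requested >= group_size)
            return std::min(group_size, scale_group_size);
    }
    return 0;  // below 32 the FC falls back to the non-dyn-quan path
}

// Per-group activation quantization: scale = max(|x|) / 128, int8 payload,
// one scale per group (assumes input.size() is a multiple of group_size).
static void quantize_activations(const std::vector<float>& input, size_t group_size,
                                 std::vector<int8_t>& quantized, std::vector<float>& scales) {
    quantized.resize(input.size());
    scales.resize(input.size() / group_size);
    for (size_t g = 0; g < scales.size(); ++g) {
        float max_abs = 0.001f;  // same non-zero floor as the kernel's max_value initializer
        for (size_t i = 0; i < group_size; ++i)
            max_abs = std::max(max_abs, std::fabs(input[g * group_size + i]));
        const float scale = max_abs / 128.0f;
        for (size_t i = 0; i < group_size; ++i) {
            // Clamped here to keep the host sketch well-defined; the kernel itself
            // relies on OpenCL convert_char semantics for the float-to-int8 conversion.
            const float q = std::clamp(input[g * group_size + i] / scale, -127.0f, 127.0f);
            quantized[g * group_size + i] = static_cast<int8_t>(q);
        }
        scales[g] = scale;  // consumed as de_quantize_scale when the int8 dot products are accumulated
    }
}

Under these rules, the quantize_group_size = 148 hint used by the new edge_case_148_groupsize test together with the default scales_group_size = 128 resolves to a group size of 128, while the hint of 12 in edge_case_12_groupsize resolves to 0, so that FC is not dynamic-quantized and the test accordingly expects only two kernels.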