From d15548357380f43c9f33dd825f18b20b65202426 Mon Sep 17 00:00:00 2001 From: Konrad Dobros Date: Mon, 8 Jun 2020 15:44:50 +0200 Subject: [PATCH] [IE CLDNN] Optimize 1x1 imad convolution kernel (#757) --- .../kernel_selector/common/tensor_type.cpp | 10 + .../kernel_selector/common/tensor_type.h | 2 + ...volution_kernel_b_fs_yx_fsv16_imad_1x1.cpp | 304 +++++++++--- ...onvolution_kernel_b_fs_yx_fsv16_imad_1x1.h | 12 +- .../convolution_gpu_b_fs_yx_fsv16_imad_1x1.cl | 465 ++++++++++++------ .../core/cl_kernels/include/fetch.cl | 71 ++- .../core/cl_kernels/reorder_weights.cl | 8 + .../core/kernel_selector_common.cpp | 2 + 8 files changed, 624 insertions(+), 250 deletions(-) diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/common/tensor_type.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/common/tensor_type.cpp index 02ade1b5b6dda9..b9a29d0d4c2998 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/common/tensor_type.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/common/tensor_type.cpp @@ -73,6 +73,8 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{ { WeightsLayout::os_i_osv16__ai8, { -1, -1, -1, 0, 1, -1, -1, -1 } }, { WeightsLayout::os_i_osv16, { -1, -1, -1, 0, 1, -1, -1, -1 } }, { WeightsLayout::os_is_yx_osv16_isv16, { 0, 1, -1, 2, 3, -1, -1, -1 } }, + { WeightsLayout::os_is_zyx_osv32_isv16, { 0, 1, 2, 3, 4, -1, -1, -1 } }, + { WeightsLayout::os_is_zyx_osv64_isv16, { 0, 1, 2, 3, 4, -1, -1, -1 } }, { WeightsLayout::i_yxs_os_yxsv2_osv16, { 1, 2, -1, 3, 0, -1, -1, -1 } }, { WeightsLayout::iy_xs_os_xsv2_osv16__ao32, { 1, 2, -1, 3, 0, -1, -1, -1 } }, { WeightsLayout::iy_xs_os_xsv2_osv8__ao32, { 1, 2, -1, 3, 0, -1, -1, -1 } }, @@ -633,6 +635,14 @@ NDims WeightsTensor::GetSimpleDims(const std::vector& d, WeightsLayout l newDims[2] = RoundUp(newDims[2], 16); newDims[3] = RoundUp(newDims[3], 16); break; + case os_is_zyx_osv32_isv16: + newDims[3] = RoundUp(newDims[3], 16); + newDims[4] = RoundUp(newDims[4], 32); + break; + case os_is_zyx_osv64_isv16: + newDims[3] = RoundUp(newDims[3], 16); + newDims[4] = RoundUp(newDims[4], 64); + break; case gs_oi_yxs_gsv16_yxsv4: newDims[4] = RoundUp(newDims[4], 16); break; diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/common/tensor_type.h b/inference-engine/thirdparty/clDNN/kernel_selector/common/tensor_type.h index 7b7064fe2edf1c..3dbdfd0b229191 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/common/tensor_type.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/common/tensor_type.h @@ -91,6 +91,8 @@ enum WeightsLayout { os_i_osv16__ai8, os_i_osv16, os_is_yx_osv16_isv16, // wieghts for int8 blocked conv + os_is_zyx_osv32_isv16, + os_is_zyx_osv64_isv16, i_yxs_os_yxsv2_osv16, iy_xs_os_xsv2_osv16__ao32, iy_xs_os_xsv2_osv8__ao32, diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv16_imad_1x1.cpp b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv16_imad_1x1.cpp index 1362af8bb46859..64144f2f930409 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv16_imad_1x1.cpp +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv16_imad_1x1.cpp @@ -19,66 +19,30 @@ #include #include #include +#include // // Kernel specific constants // -#define SIMD_SIZE 16 +static constexpr size_t fsv = 16; +static 
constexpr size_t simd = 16; namespace kernel_selector { -namespace { - -size_t getOutBlock_X(size_t output_size_x, size_t stride_x) { - size_t output_block_width = 0; - size_t max_block_size = std::min((SIMD_SIZE - 1) / stride_x + 1, output_size_x); - - if (output_size_x <= max_block_size) - return output_size_x; - - for (size_t block = 4; block <= max_block_size; ++block) { - if (output_size_x % block == 0) - output_block_width = block; - } - if (output_block_width == 0 && output_size_x < max_block_size * 3) { - size_t min_overhang = max_block_size; - for (size_t block = 4; block <= max_block_size; ++block) { - size_t overhang = block - output_size_x % block; - if (overhang <= min_overhang) { - min_overhang = overhang; - output_block_width = block; - } - } - } - - if (output_block_width == 0) { - output_block_width = max_block_size; - } - return output_block_width; -} - -bool should_k_slice(const convolution_params& params, size_t output_block_width) { - constexpr float preferred_eu_occupancy = 5.f; - if (params.inputs[0].Feature().v % (16 * 4) != 0) - return false; - - size_t eu_count = params.engineInfo.computeUnitsCount; - auto global_size = CeilDiv(params.output.X().v, output_block_width) * - params.output.Y().v * - params.output.Batch().v * Align(CeilDiv(params.output.Feature().v, 2), SIMD_SIZE); - auto threads = global_size / SIMD_SIZE; - auto optimal_threads_num = eu_count * preferred_eu_occupancy; - return threads < optimal_threads_num; -} - -} // namespace - Convolution_kernel_b_fs_yx_fsv16_imad_1x1::Convolution_kernel_b_fs_yx_fsv16_imad_1x1() : ConvolutionKernelBase("convolution_gpu_b_fs_yx_fsv16_imad_1x1") { - for (size_t bw = 1; bw <= SIMD_SIZE; ++bw) { - for (auto exe : ConvolutionKernelBase::autoTuneOptions) { - all_tune_params.push_back(AutoTuneParams{ bw, true, exe }); - all_tune_params.push_back(AutoTuneParams{ bw, false, exe }); + constexpr size_t max_block_elements = 32; + for (size_t bs = 1; bs <= 2 * simd; ++bs) { + for (size_t bf = 1; bf <= 4; ++bf) { + if (bs * bf > max_block_elements) + continue; + for (size_t split = 1; split <= 8; ++split) { + if (bf > split) + continue; + for (auto exe : ConvolutionKernelBase::autoTuneOptions) { + all_tune_params.push_back(AutoTuneParams{ bs, bf, split, exe }); + } + } } } } @@ -94,6 +58,7 @@ ParamsKey Convolution_kernel_b_fs_yx_fsv16_imad_1x1::GetSupportedKey() const { k.EnableOutputDataType(Datatype::F16); k.EnableInputWeightsType(WeightsType::INT8); + k.EnableInputWeightsType(WeightsType::UINT8); k.EnableInputLayout(DataLayout::b_fs_yx_fsv16); k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16); @@ -106,19 +71,33 @@ ParamsKey Convolution_kernel_b_fs_yx_fsv16_imad_1x1::GetSupportedKey() const { k.EnableNonBiasTerm(); k.EnableBatching(); k.EnableQuantization(QuantizationType::SYMMETRIC); + k.DisableTuning(); return k; } JitConstants Convolution_kernel_b_fs_yx_fsv16_imad_1x1::GetJitConstants(const convolution_params& params, const DispatchData& kd) const { auto mem_consts = Parent::GetJitConstants(params, kd); - mem_consts.AddConstant(MakeJitConstant("OUT_BLOCK_WIDTH", kd.cldnnStyle.blockWidth)); - mem_consts.AddConstant(MakeJitConstant("FEATURE_LWS_SPLIT", kd.cldnnStyle.prefetch)); + mem_consts.AddConstant(MakeJitConstant("OUT_BLOCK_SPATIAL", kd.cldnnStyle.blockWidth)); + mem_consts.AddConstant(MakeJitConstant("OUT_BLOCK_FEATURES", kd.cldnnStyle.blockHeight)); + mem_consts.AddConstant(MakeJitConstant("FEATURE_SLM_SPLIT", kd.cldnnStyle.prefetch)); + mem_consts.Merge(MakeTypeJitConstants(GetAccumulatorType(params), "ACCUMULATOR")); + 
mem_consts.Merge(MakeTypeJitConstants(GetActivationType(params), "ACTIVATION")); if (!params.fused_ops.empty()) { auto input_dt = GetActivationType(params); - FusedOpsConfiguration conf_scalar = {"", {"out_b", "out_f + out_f_offset", "out_y", "out_x + i"}, "dequantized", input_dt, 1 }; - conf_scalar.SetLoopAxes({ Tensor::DataChannelName::X }, true); + std::vector idx_order = { "out_b", + "(out_f + ofb * SIMD)", + "intel_sub_group_shuffle(out_y_shuffle[os / SIMD], os % SIMD)", + "intel_sub_group_shuffle(out_x_shuffle[os / SIMD], os % SIMD)" }; + FusedOpsConfiguration conf_scalar = {"_SCALAR", + idx_order, + "dequantized[ofb][os]", + input_dt, + 1, + LoadType::LT_UNALIGNED, + BoundaryCheck::DISABLED }; + conf_scalar.SetLoopAxes({ Tensor::DataChannelName::X, Tensor::DataChannelName::Y }, true); mem_consts.Merge(MakeFusedOpsJitConstants(params, {conf_scalar})); } @@ -130,24 +109,62 @@ ConvolutionKernelBase::DispatchData Convolution_kernel_b_fs_yx_fsv16_imad_1x1::S DispatchData kd; const auto& output = params.output; auto tune_params = GetAutoTuneParams(params, index); - size_t k_slices = tune_params.k_slicing ? 4 : 1; + size_t k_slices = tune_params.feature_slm_split; - kd.gws0 = CeilDiv(output.X().v, tune_params.out_block_width); - kd.gws1 = output.Y().v; - kd.gws2 = output.Batch().v * Align(CeilDiv(output.Feature().v, 2), SIMD_SIZE) * k_slices; + kd.gws0 = CeilDiv(output.X().v * output.Y().v, tune_params.out_block_spatial); + kd.gws1 = CeilDiv(output.Feature().v, tune_params.out_block_features * simd) * simd * k_slices; + kd.gws2 = output.Batch().v; kd.lws0 = 1; - kd.lws1 = 1; - kd.lws2 = SIMD_SIZE * k_slices; + kd.lws1 = simd * k_slices; + kd.lws2 = 1; kd.cldnnStyle = {0, 0, 0, 0, 0}; kd.gemmStyle = {0, 0, 0, 0, 0, 0}; - kd.cldnnStyle.blockWidth = tune_params.out_block_width; + kd.cldnnStyle.blockWidth = tune_params.out_block_spatial; + kd.cldnnStyle.blockHeight = tune_params.out_block_features; kd.cldnnStyle.prefetch = k_slices; kd.efficiency = FORCE_PRIORITY_2; + auto in_f = params.weights.IFM().v; + auto out_f = params.weights.OFM().v; + auto batch = output.Batch().v; + auto out_x = output.X().v; + auto out_y = output.Y().v; + + bool x_strided = params.stride.x != 1; + bool general_is_faster = false; + + // This kernel cannot split for large x, but general could + general_is_faster |= CeilDiv(in_f, fsv) % 4 == 0 + && (out_x % 15 == 0 || out_x % 16 == 0) + && tune_params.feature_slm_split == 1 + && tune_params.out_block_spatial <= 8; + + // List of known cases where general kernel is better + general_is_faster |= in_f == 24 && out_f == 144 && out_x == 75 && out_y == 75 && batch == 1; + general_is_faster |= in_f == 192 && out_f == 64 && out_x == 28 && out_y == 28 && batch == 1; + general_is_faster |= in_f == 576 && out_f == 96 && out_x == 19 && out_y == 19 && batch == 1; + general_is_faster |= in_f == 384 && out_f == 96 && out_x == 19 && out_y == 19 && batch == 1; + general_is_faster |= in_f == 384 && out_f == 64 && out_x == 19 && out_y == 19 && batch == 1; + general_is_faster |= in_f == 192 && out_f == 64 && out_x == 19 && out_y == 19 && batch == 1; + general_is_faster |= in_f == 96 && out_f == 576 && out_x == 19 && out_y == 19 && batch == 1; + general_is_faster |= in_f == 1024 && out_f == 256 && out_x == 14 && out_y == 14 && batch == 1; + general_is_faster |= in_f == 256 && out_f == 256 && out_x == 14 && out_y == 14 && batch == 1; + general_is_faster |= in_f == 136 && out_f == 816 && out_x == 14 && out_y == 14 && batch == 1; + general_is_faster |= in_f == 1280 && out_f == 256 && out_x == 
10 && out_y == 10 && batch == 1; + general_is_faster |= in_f == 256 && out_f == 128 && out_x == 3 && out_y == 3 && batch == 1; + + if (general_is_faster && !x_strided) { + kd.efficiency = FORCE_PRIORITY_3; + } + + // Better to use kernel with 4 input features in a loop + if (static_cast(params.weights.IFM().v) / static_cast(Align(params.weights.IFM().v, fsv)) < 0.5f) + kd.efficiency = FORCE_PRIORITY_4; + return kd; } // SetDefault @@ -165,40 +182,169 @@ bool Convolution_kernel_b_fs_yx_fsv16_imad_1x1::Validate(const Params& params, c return false; } - if ((newParams.stride.x != newParams.stride.y) || - (newParams.stride.x != 1 && newParams.stride.x != 2)) { - // Strides must be 1x1 or 2x2 - return false; - } - if (newParams.groups != 1 || newParams.split != 1) return false; return true; } +WeightsLayout Convolution_kernel_b_fs_yx_fsv16_imad_1x1::GetPreferredWeightsLayout(const convolution_params& params) const { + // TODO Auto tune index is needed in GetPreferredWeightsLayout to select correct weights layout + auto tparams = GetAutoTuneParams(params, -1); + if (tparams.out_block_features == 2) + return WeightsLayout::os_is_zyx_osv32_isv16; + if (tparams.out_block_features == 4) + return WeightsLayout::os_is_zyx_osv64_isv16; + + return WeightsLayout::os_is_yx_osv16_isv16; +} + Convolution_kernel_b_fs_yx_fsv16_imad_1x1::AutoTuneParams Convolution_kernel_b_fs_yx_fsv16_imad_1x1::GetAutoTuneParams(const convolution_params& params, int index) const { if (index >= 0 && index < static_cast(all_tune_params.size())) { return all_tune_params[index]; } - AutoTuneParams default_params; - default_params.out_block_width = getOutBlock_X(params.output.X().v, params.stride.x); - default_params.k_slicing = should_k_slice(params, default_params.out_block_width); - default_params.exe_mode = DEFAULT; - return default_params; + + size_t block_spatial = 1; + size_t block_features = 1; + size_t feature_slm_split = 1; + std::string exe_mode = DEFAULT; + + size_t total_spatial = params.output.X().v * params.output.Y().v; + // Try two features per work-item + if (params.output.Feature().v % 32 == 0 || params.output.Feature().v > 32 * 2) + block_features = 2; + + // Non strict inequality here leads to some regressions, ie: [1, 64, 19, 19] (*) [384, 64, 1, 1] + bool can_split = params.weights.IFM().v > 4 * fsv; + + // Select block size in spatial dimension + { + size_t max_spatial = std::min(2 * simd / block_features, total_spatial); + size_t min_efficient_spatial = 8; + + if (max_spatial <= min_efficient_spatial) { + block_spatial = max_spatial; + } else { + auto minimum_params = AutoTuneParams{ min_efficient_spatial, block_features, 1, exe_mode }; + bool preserve_occupancy = EstimateOccupancy(params, minimum_params) >= 1.f; + + size_t min_overhang = max_spatial; + size_t best_block = min_efficient_spatial; + bool block_write_found = false; + bool output_pad = params.output.X().pad.Total() != 0; + + for (size_t block = min_efficient_spatial; block <= max_spatial; ++block) { + bool c_occupancy = EstimateOccupancy(params, { block, block_features, 1, exe_mode }) >= 1.f; + auto overhang = Align(total_spatial, block) - total_spatial; + bool c_block_write = (overhang == 0 && !output_pad) || params.output.X().v % block == 0; + + // Kernel work-around for spills/inefficient loop order + if (can_split && !c_occupancy && block > 14 && block_features > 1) + break; + + if (preserve_occupancy && !c_occupancy) + break; + + if (overhang <= min_overhang && (!block_write_found || c_block_write)) { + best_block = block; + min_overhang = 
overhang; + block_write_found = c_block_write; + } + } + + block_spatial = best_block; + } + } + + // Try to split features using slm to increase occupancy + { + auto dummy_params = AutoTuneParams{ block_spatial, block_features, 1, exe_mode }; + bool enough_occupancy = EstimateOccupancy(params, dummy_params) >= 1.f; + if (!enough_occupancy && can_split) { + std::vector check_split = { 4 }; + size_t ifm_blocks = CeilDiv(params.weights.IFM().v, fsv); + for (auto split : check_split) { + if (split > ifm_blocks) + break; + + auto tmp_tune = AutoTuneParams{ block_spatial, block_features, split, exe_mode }; + + bool c_lws = split * simd <= params.engineInfo.maxWorkGroupSize; + bool c_slm = EstimateSLMUsage(params, tmp_tune) <= 1.f; + bool c_fb = block_features <= split; + bool c_occupancy = EstimateOccupancy(params, tmp_tune) >= 1.f; + + if (c_lws && c_slm && c_fb) { + feature_slm_split = split; + } + + // Increasing split will only increase memory and work-group size, don't check bigger split + if (!c_slm || !c_lws || c_occupancy) + break; + } + } + } + + // Occupancy is still extremely low, try to decrease spatials + { + auto dummy_params = AutoTuneParams{ block_spatial, block_features, feature_slm_split, exe_mode }; + constexpr float default_threshold = 5.f / 7.f; + constexpr float split_threshold = 4.f / 7.f; + float threshold_occupancy = feature_slm_split == 1 ? default_threshold : split_threshold; + + if (EstimateOccupancy(params, dummy_params) < threshold_occupancy && block_spatial != 1) { + for (size_t block = block_spatial - 1; block >= 4; --block) { + auto tmp_params = AutoTuneParams{ block, block_features, feature_slm_split, exe_mode }; + bool c_mul = total_spatial % block == 0; + bool c_occupancy = EstimateOccupancy(params, tmp_params) >= threshold_occupancy; + + if (c_mul) { + block_spatial = block; + if (c_occupancy) + break; + } + } + } + } + + return AutoTuneParams{ block_spatial, block_features, feature_slm_split, exe_mode }; +} + +float Convolution_kernel_b_fs_yx_fsv16_imad_1x1::EstimateOccupancy(const convolution_params& params, const AutoTuneParams& tparams) const { + size_t blocks_s = CeilDiv(params.output.X().v * params.output.Y().v, tparams.out_block_spatial); + size_t blocks_f = CeilDiv(params.output.Feature().v, tparams.out_block_features * simd) * tparams.feature_slm_split; + size_t block_b = params.output.Batch().v; + + auto threads = blocks_s * blocks_f * block_b; + constexpr size_t max_threads_per_cu = 7; + size_t compute_units = params.engineInfo.computeUnitsCount; + size_t max_threads = compute_units * max_threads_per_cu; + + return static_cast(threads) / static_cast(max_threads); +} + +float Convolution_kernel_b_fs_yx_fsv16_imad_1x1::EstimateSLMUsage(const convolution_params& params, const AutoTuneParams& tparams) const { + size_t slm_elements = tparams.out_block_spatial * tparams.out_block_features * fsv * (tparams.feature_slm_split - 1); + size_t slm_bytes = slm_elements * BytesPerElement(GetAccumulatorType(params)); + + // TODO Actual maximum slm should also depend on number of work-groups, but this is device specific + size_t max_slm_bytes = params.engineInfo.maxLocalMemSize; + + return static_cast(slm_bytes) / static_cast(max_slm_bytes); } bool Convolution_kernel_b_fs_yx_fsv16_imad_1x1::ValidateAutoTuneParams(const convolution_params& params, const AutoTuneParams& tune_params) const { - if (tune_params.k_slicing && params.inputs[0].Feature().v % (16 * 4) != 0) - return false; + bool c_ifm = CeilDiv(params.weights.IFM().v, fsv) >= 
tune_params.feature_slm_split; + bool c_slm = EstimateSLMUsage(params, tune_params) <= 1.f; + bool c_lws = tune_params.feature_slm_split * simd <= params.engineInfo.maxWorkGroupSize; - size_t max_block_size = std::min(static_cast((SIMD_SIZE - 1) / params.stride.x + 1), params.output.X().v); - if (tune_params.out_block_width > max_block_size) - return false; + // Work-around for lack of actual AutoTuneParams in GetPreferredWeightsLayout + auto default_params = GetAutoTuneParams(params, -1); + bool c_wa_fb = default_params.out_block_features == tune_params.out_block_features; - return true; + return c_ifm && c_slm && c_lws && c_wa_fb; } KernelsData Convolution_kernel_b_fs_yx_fsv16_imad_1x1::GetKernelsData(const Params& params, diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv16_imad_1x1.h b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv16_imad_1x1.h index 7133d2dde97307..44f3f4ac82aeab 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv16_imad_1x1.h +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/actual_kernels/convolution/convolution_kernel_b_fs_yx_fsv16_imad_1x1.h @@ -38,9 +38,7 @@ class Convolution_kernel_b_fs_yx_fsv16_imad_1x1 : public ConvolutionKernelBase { JitConstants GetJitConstants(const convolution_params& params, const DispatchData& kd) const override; DispatchData SetDefault(const convolution_params& params, int autoTuneIndex = -1) const override; bool NeedPaddedInput() const override { return true; } - WeightsLayout GetPreferredWeightsLayout(const convolution_params&) const override { - return WeightsLayout::os_is_yx_osv16_isv16; - } + WeightsLayout GetPreferredWeightsLayout(const convolution_params&) const override; std::vector GetSupportedFusedOps() const override { return { FusedOpType::ELTWISE, @@ -50,13 +48,17 @@ class Convolution_kernel_b_fs_yx_fsv16_imad_1x1 : public ConvolutionKernelBase { } struct AutoTuneParams { - size_t out_block_width; - bool k_slicing; + size_t out_block_spatial; + size_t out_block_features; + size_t feature_slm_split; std::string exe_mode; }; std::vector all_tune_params; bool ValidateAutoTuneParams(const convolution_params& params, const AutoTuneParams& tune_params) const; AutoTuneParams GetAutoTuneParams(const convolution_params& params, int index) const; + + float EstimateOccupancy(const convolution_params& params, const AutoTuneParams& tune) const; + float EstimateSLMUsage(const convolution_params& params, const AutoTuneParams& tune) const; }; } // namespace kernel_selector diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/convolution_gpu_b_fs_yx_fsv16_imad_1x1.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/convolution_gpu_b_fs_yx_fsv16_imad_1x1.cl index 4cacde12381beb..ee52ae87747a74 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/convolution_gpu_b_fs_yx_fsv16_imad_1x1.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/convolution_gpu_b_fs_yx_fsv16_imad_1x1.cl @@ -17,28 +17,38 @@ #include "include/fetch.cl" #include "include/imad.cl" #include "include/mmad.cl" +#include "include/data_types.cl" + +#define FSV 16 +#define SIMD 16 + +#if FILTER_LAYOUT_OS_IS_YX_OSV16_ISV16 +# define GET_WEIGHTS_INDEX(o, i, z, y, x) GET_FILTER_OS_IS_YX_OSV16_ISV16_INDEX(FILTER, o, i, y, x) +# define WEIGHTS_FEATURE_BLOCK_PITCH 
(ALIGN(FILTER_IFM_NUM, FSV) * FILTER_SIZE_X * FILTER_SIZE_Y * FSV) +# define WEIGHTS_IS_PITCH (FSV * FSV * FILTER_SIZE_X * FILTER_SIZE_Y) + +#elif FILTER_LAYOUT_OS_IS_ZYX_OSV32_ISV16 +# define GET_WEIGHTS_INDEX(o, i, z, y, x) GET_FILTER_OS_IS_ZYX_OSV32_ISV16_INDEX(FILTER, o, i, z, y, x) +# define WEIGHTS_FEATURE_BLOCK_PITCH (FSV * FSV) +# define WEIGHTS_IS_PITCH (2 * FSV * FSV * FILTER_SIZE_X * FILTER_SIZE_Y * FILTER_SIZE_Z) + +#elif FILTER_LAYOUT_OS_IS_ZYX_OSV64_ISV16 +# define GET_WEIGHTS_INDEX(o, i, z, y, x) GET_FILTER_OS_IS_ZYX_OSV64_ISV16_INDEX(FILTER, o, i, z, y, x) +# define WEIGHTS_FEATURE_BLOCK_PITCH (FSV * FSV) +# define WEIGHTS_IS_PITCH (4 * FSV * FSV * FILTER_SIZE_X * FILTER_SIZE_Y * FILTER_SIZE_Z) -#if QUANTIZATION_TERM - #define ACCUMULATOR_TYPE int - #define TO_ACCUMULATOR_TYPE(x) convert_int(x) - #define ACTIVATION_TYPE float - #define TO_ACTIVATION_TYPE(x) convert_float(x) -#else - #define ACCUMULATOR_TYPE INPUT0_TYPE - #define TO_ACCUMULATOR_TYPE(x) TO_INPUT0_TYPE(x) - #define ACTIVATION_TYPE INPUT0_TYPE - #define TO_ACTIVATION_TYPE(x) TO_INPUT0_TYPE(x) #endif -#define MAKE_VECTOR_TYPE(elem_type, size) CAT(elem_type, size) #define AS_TYPE_N_(type, n, x) as_##type##n(x) #define AS_TYPE_N(type, n, x) AS_TYPE_N_(type, n, x) #define AS_INPUT0_TYPE_4(x) AS_TYPE_N(INPUT0_TYPE, 4, x) +#define AS_FILTER_TYPE_4(x) AS_TYPE_N(FILTER_TYPE, 4, x) #define CEIL_DIV(a, b) (((a) + (b) - 1)/(b)) #define ALIGN(a, b) (CEIL_DIV(a, b) * (b)) -__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((intel_reqd_sub_group_size(SIMD))) +__attribute__((reqd_work_group_size(1, SIMD * FEATURE_SLM_SPLIT, 1))) KERNEL(convolution_gpu_b_fs_yx_fsv16_imad_1x1)( const __global INPUT0_TYPE *conv_input, __global OUTPUT_TYPE *output, @@ -51,182 +61,343 @@ KERNEL(convolution_gpu_b_fs_yx_fsv16_imad_1x1)( #endif uint split_idx) { - #define LUT_VALUE_CLAMP(x) ((x) < (OUT_BLOCK_WIDTH - 1) * STRIDE_SIZE_X + 1 ? 
(x) : 0) - const int tmp[16] = { - LUT_VALUE_CLAMP(0), - LUT_VALUE_CLAMP(1), - LUT_VALUE_CLAMP(2), - LUT_VALUE_CLAMP(3), - LUT_VALUE_CLAMP(4), - LUT_VALUE_CLAMP(5), - LUT_VALUE_CLAMP(6), - LUT_VALUE_CLAMP(7), - LUT_VALUE_CLAMP(8), - LUT_VALUE_CLAMP(9), - LUT_VALUE_CLAMP(10), - LUT_VALUE_CLAMP(11), - LUT_VALUE_CLAMP(12), - LUT_VALUE_CLAMP(13), - LUT_VALUE_CLAMP(14), - LUT_VALUE_CLAMP(15) - }; - #undef LUT_VALUE_CLAMP - -#if FEATURE_LWS_SPLIT != 1 - const uint subgroup_id = get_sub_group_id(); -#else - const uint subgroup_id = 0; -#endif - const uint subgroup_local_id = get_sub_group_local_id(); + // Use group ids to ease sub-group uniform variables optimization for compiler + const uint out_yx_sg = (uint)get_group_id(0) * OUT_BLOCK_SPATIAL; + uint out_fg = (uint)get_group_id(1) * OUT_BLOCK_FEATURES * SIMD; + const uint out_b = (uint)get_group_id(2); + uint out_f = out_fg + get_sub_group_local_id(); - const uint out_x = (uint)get_global_id(0) * OUT_BLOCK_WIDTH; - const uint out_y = get_global_id(1); - const uint out_b = (uint)(get_group_id(2) * 32) / ALIGN(OUTPUT_FEATURE_NUM, 32); - const uint out_fg = (uint)(get_group_id(2) * 32) % ALIGN(OUTPUT_FEATURE_NUM, 32); - const uint out_f = out_fg + subgroup_local_id; + const uint sglid = get_sub_group_local_id(); - const uint feature_offset = subgroup_id * INPUT0_FEATURE_NUM / FEATURE_LWS_SPLIT; + uint out_x_shuffle[CEIL_DIV(OUT_BLOCK_SPATIAL, SIMD)] = { }; + uint out_y_shuffle[CEIL_DIV(OUT_BLOCK_SPATIAL, SIMD)] = { }; - ACCUMULATOR_TYPE dotProd[OUT_BLOCK_WIDTH * 2] = { 0 }; + const uint max_out_yx = OUTPUT_SIZE_X * OUTPUT_SIZE_Y; + uint max_local_yx = min(max_out_yx, out_yx_sg + OUT_BLOCK_SPATIAL); + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < CEIL_DIV(OUT_BLOCK_SPATIAL, SIMD); ++os) { + uint out_yx_shuffle = out_yx_sg + sglid + os * SIMD; + uint out_yx_clamp = max_out_yx % OUT_BLOCK_SPATIAL == 0 + ? out_yx_shuffle + : min(out_yx_shuffle, max_local_yx - 1); + out_x_shuffle[os] = out_yx_clamp % OUTPUT_SIZE_X; + out_y_shuffle[os] = out_yx_clamp / OUTPUT_SIZE_X; + } - const int input_x = out_x * STRIDE_SIZE_X - PADDING_SIZE_X; - const int input_y = out_y * STRIDE_SIZE_Y - PADDING_SIZE_Y; - - uint filter_idx = GET_FILTER_OS_IS_YX_OSV16_ISV16_INDEX(FILTER, out_f, feature_offset, 0, 0); - uint filter_idx2 = GET_FILTER_OS_IS_YX_OSV16_ISV16_INDEX(FILTER, out_f + 16, feature_offset, 0, 0); + const uint ifm_blocks = CEIL_DIV(INPUT0_FEATURE_NUM, FSV); + const uint ifm_blocks_per_sg = ifm_blocks / FEATURE_SLM_SPLIT; + const uint ifm_per_sg = ifm_blocks_per_sg * FSV; + + uint feature_offset = 0; + uint feature_blocks = ifm_blocks_per_sg; +#if FEATURE_SLM_SPLIT != 1 + feature_offset = get_sub_group_id() * ifm_per_sg; + + if (ifm_blocks % FEATURE_SLM_SPLIT != 0) { + bool bigger_sg = get_sub_group_id() < ifm_blocks % FEATURE_SLM_SPLIT; + feature_blocks = bigger_sg ? ifm_blocks_per_sg + 1 : ifm_blocks_per_sg; + feature_offset += bigger_sg ? 
get_sub_group_id() * FSV : ifm_blocks % FEATURE_SLM_SPLIT * FSV; + } +#endif + + uint filter_idx = GET_WEIGHTS_INDEX(out_f, feature_offset, 0, 0, 0); + + uint input_idx[CEIL_DIV(OUT_BLOCK_SPATIAL, SIMD)] = { }; + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < CEIL_DIV(OUT_BLOCK_SPATIAL, SIMD); ++os) { + uint input_x = out_x_shuffle[os] * STRIDE_SIZE_X - PADDING_SIZE_X; + uint input_y = out_y_shuffle[os] * STRIDE_SIZE_Y - PADDING_SIZE_Y; + input_idx[os] = INPUT0_GET_INDEX(out_b, feature_offset, input_y, input_x); + } + + ACCUMULATOR_TYPE dotProd[OUT_BLOCK_FEATURES][OUT_BLOCK_SPATIAL] = { }; __attribute__((opencl_unroll_hint(1))) - for(uint k = 0; k < CEIL_DIV(INPUT0_FEATURE_NUM, 16)/FEATURE_LWS_SPLIT; k++ ) { - uint4 weights_val = vload4(0, (__global uint*)(weights + filter_idx)); - uint4 weights_val2 = vload4(0, (__global uint *)(weights + filter_idx2)); - - uint input_idx = GET_DATA_B_FS_YX_FSV16_INDEX(INPUT0, out_b, feature_offset + k * 16, input_y, input_x + tmp[get_sub_group_local_id()]); - uint4 input_val0 = vload4(0, (__global uint *)(conv_input + input_idx)); - - __attribute__((opencl_unroll_hint(OUT_BLOCK_WIDTH))) - for(uint ow = 0; ow < OUT_BLOCK_WIDTH; ow++) { - const uint ow_offset = ow + OUT_BLOCK_WIDTH; - dotProd[ow] = TO_ACCUMULATOR_TYPE(IMAD(dotProd[ow], AS_INPUT0_TYPE_4(intel_sub_group_shuffle(input_val0.s0, ow * STRIDE_SIZE_X)), as_char4(weights_val.s0))); - dotProd[ow] = TO_ACCUMULATOR_TYPE(IMAD(dotProd[ow], AS_INPUT0_TYPE_4(intel_sub_group_shuffle(input_val0.s1, ow * STRIDE_SIZE_X)), as_char4(weights_val.s1))); - dotProd[ow] = TO_ACCUMULATOR_TYPE(IMAD(dotProd[ow], AS_INPUT0_TYPE_4(intel_sub_group_shuffle(input_val0.s2, ow * STRIDE_SIZE_X)), as_char4(weights_val.s2))); - dotProd[ow] = TO_ACCUMULATOR_TYPE(IMAD(dotProd[ow], AS_INPUT0_TYPE_4(intel_sub_group_shuffle(input_val0.s3, ow * STRIDE_SIZE_X)), as_char4(weights_val.s3))); - - dotProd[ow_offset] = TO_ACCUMULATOR_TYPE(IMAD(dotProd[ow_offset], AS_INPUT0_TYPE_4(intel_sub_group_shuffle(input_val0.s0, ow * STRIDE_SIZE_X)), as_char4(weights_val2.s0))); - dotProd[ow_offset] = TO_ACCUMULATOR_TYPE(IMAD(dotProd[ow_offset], AS_INPUT0_TYPE_4(intel_sub_group_shuffle(input_val0.s1, ow * STRIDE_SIZE_X)), as_char4(weights_val2.s1))); - dotProd[ow_offset] = TO_ACCUMULATOR_TYPE(IMAD(dotProd[ow_offset], AS_INPUT0_TYPE_4(intel_sub_group_shuffle(input_val0.s2, ow * STRIDE_SIZE_X)), as_char4(weights_val2.s2))); - dotProd[ow_offset] = TO_ACCUMULATOR_TYPE(IMAD(dotProd[ow_offset], AS_INPUT0_TYPE_4(intel_sub_group_shuffle(input_val0.s3, ow * STRIDE_SIZE_X)), as_char4(weights_val2.s3))); + for (uint k = 0; k < feature_blocks; ++k) { + uint4 weights_val[OUT_BLOCK_FEATURES] = { }; + __attribute__((opencl_unroll_hint)) + for (uint ofb = 0; ofb < OUT_BLOCK_FEATURES; ++ofb) { + weights_val[ofb] = vload4(0, (__global uint*)(weights + filter_idx + ofb * WEIGHTS_FEATURE_BLOCK_PITCH)); } - filter_idx += 16 * 16; - filter_idx2 += 16 * 16; - } + uint4 input_val[CEIL_DIV(OUT_BLOCK_SPATIAL, SIMD)] = { }; + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < CEIL_DIV(OUT_BLOCK_SPATIAL, SIMD); ++os) { + input_val[os] = vload4(0, (__global uint *)(conv_input + input_idx[os])); + } -#if FEATURE_LWS_SPLIT != 1 - __local ACCUMULATOR_TYPE partial_acc[16 * OUT_BLOCK_WIDTH * (FEATURE_LWS_SPLIT - 1) * 2]; - if (subgroup_id == 0) { - __attribute__((opencl_unroll_hint(OUT_BLOCK_WIDTH))) - for (uint i = 0; i < OUT_BLOCK_WIDTH; i++) { - partial_acc[16 * OUT_BLOCK_WIDTH + i * 16 + subgroup_local_id] = dotProd[i + OUT_BLOCK_WIDTH]; +#if 
OUT_BLOCK_FEATURES > 1 && FEATURE_SLM_SPLIT != 1 && OUT_BLOCK_SPATIAL > 14 + // For some cases compiler spills here due to loop order + // Use suboptimal order to avoid this at cost of instruction dispatch delays. + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < OUT_BLOCK_SPATIAL; ++os) { + __attribute__((opencl_unroll_hint)) + for (uint ive = 0; ive < 4; ++ive) { + __attribute__((opencl_unroll_hint)) + for (uint ofb = 0; ofb < OUT_BLOCK_FEATURES; ++ofb) { +#else + __attribute__((opencl_unroll_hint)) + for (uint ive = 0; ive < 4; ++ive) { + __attribute__((opencl_unroll_hint)) + for (uint ofb = 0; ofb < OUT_BLOCK_FEATURES; ++ofb) { + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < OUT_BLOCK_SPATIAL; ++os) { +#endif + dotProd[ofb][os] = IMAD(dotProd[ofb][os], + AS_INPUT0_TYPE_4(intel_sub_group_shuffle(input_val[os / SIMD][ive], os % SIMD)), + AS_FILTER_TYPE_4(weights_val[ofb][ive])); + } + } } - } else if (subgroup_id == 1) { - __attribute__((opencl_unroll_hint(OUT_BLOCK_WIDTH))) - for (uint i = 0; i < OUT_BLOCK_WIDTH; i++) { - partial_acc[i * 16 + subgroup_local_id] = dotProd[i]; - dotProd[i] = dotProd[i + OUT_BLOCK_WIDTH]; + + filter_idx += WEIGHTS_IS_PITCH; + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < CEIL_DIV(OUT_BLOCK_SPATIAL, SIMD); ++os) { + input_idx[os] += INPUT0_FEATURE_PITCH * FSV; } - } else if (subgroup_id == 2) { - __attribute__((opencl_unroll_hint(OUT_BLOCK_WIDTH))) - for (uint i = 0; i < OUT_BLOCK_WIDTH; i++) { - partial_acc[2 * 16 * OUT_BLOCK_WIDTH + i * 16 + subgroup_local_id] = dotProd[i]; - partial_acc[3 * 16 * OUT_BLOCK_WIDTH + i * 16 + subgroup_local_id] = dotProd[i + OUT_BLOCK_WIDTH]; + } + +#if FEATURE_SLM_SPLIT != 1 + // Additional local memory reduction for feature split mode +# if FEATURE_SLM_SPLIT < OUT_BLOCK_FEATURES +# error convolution_gpu_b_fs_yx_fsv16_imad_1x1.cl - OUT_BLOCK_FEATURES must be less or equal to FEATURE_SLM_SPLIT +# endif + + const uint partial_acc_size = (FEATURE_SLM_SPLIT - 1) * OUT_BLOCK_FEATURES * SIMD * OUT_BLOCK_SPATIAL; + __local ACCUMULATOR_TYPE partial_acc[partial_acc_size]; + + uint sgid_start_idx = get_sub_group_id(); + sgid_start_idx = sgid_start_idx == 0 ? 0 : sgid_start_idx - 1; + __local ACCUMULATOR_TYPE* partial_acc_ptr = partial_acc + sgid_start_idx * OUT_BLOCK_FEATURES * SIMD * OUT_BLOCK_SPATIAL + sglid; + + if (get_sub_group_id() < OUT_BLOCK_FEATURES) { + __attribute__((opencl_unroll_hint)) + for (uint wg = 0; wg < OUT_BLOCK_FEATURES; ++wg) { + if (get_sub_group_id() == wg) { + __attribute__((opencl_unroll_hint)) + for (uint ofb = 0; ofb < wg; ++ofb) { + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < OUT_BLOCK_SPATIAL; ++os) { + const uint partial_acc_ptr_idx = + ofb * OUT_BLOCK_SPATIAL * SIMD + + os * SIMD; + partial_acc_ptr[partial_acc_ptr_idx] = dotProd[ofb][os]; + } + } + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < OUT_BLOCK_SPATIAL; ++os) { + dotProd[0][os] = dotProd[wg][os]; + } + __attribute__((opencl_unroll_hint)) + for (uint ofb = wg + 1; ofb < OUT_BLOCK_FEATURES; ++ofb) { + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < OUT_BLOCK_SPATIAL; ++os) { + const uint partial_acc_ptr_idx = + ((wg != 0) ? 
OUT_BLOCK_SPATIAL * OUT_BLOCK_FEATURES * SIMD : 0) + + ofb * OUT_BLOCK_SPATIAL * SIMD + + os * SIMD; + partial_acc_ptr[partial_acc_ptr_idx] = dotProd[ofb][os]; + } + } + } } - } else if (subgroup_id == 3) { - __attribute__((opencl_unroll_hint(OUT_BLOCK_WIDTH))) - for (uint i = 0; i < OUT_BLOCK_WIDTH; i++) { - partial_acc[4 * 16 * OUT_BLOCK_WIDTH + i * 16 + subgroup_local_id] = dotProd[i]; - partial_acc[5 * 16 * OUT_BLOCK_WIDTH + i * 16 + subgroup_local_id] = dotProd[i + OUT_BLOCK_WIDTH]; + } else { + __attribute__((opencl_unroll_hint)) + for (uint ofb = 0; ofb < OUT_BLOCK_FEATURES; ++ofb) { + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < OUT_BLOCK_SPATIAL; ++os) { + const uint partial_acc_ptr_idx = + ofb * OUT_BLOCK_SPATIAL * SIMD + + os * SIMD; + partial_acc_ptr[partial_acc_ptr_idx] = dotProd[ofb][os]; + } } } barrier(CLK_LOCAL_MEM_FENCE); - if (subgroup_id >= 2) + + if (get_sub_group_id() >= OUT_BLOCK_FEATURES) return; - __attribute__((opencl_unroll_hint(OUT_BLOCK_WIDTH))) - for (uint i = 0; i < OUT_BLOCK_WIDTH; i++) { - dotProd[i] += partial_acc[(i + subgroup_id * OUT_BLOCK_WIDTH) * 16 + subgroup_local_id]; - dotProd[i] += partial_acc[(i + (subgroup_id + 2) * OUT_BLOCK_WIDTH) * 16 + subgroup_local_id]; - dotProd[i] += partial_acc[(i + (subgroup_id + 4) * OUT_BLOCK_WIDTH) * 16 + subgroup_local_id]; + + partial_acc_ptr = partial_acc + get_sub_group_id() * OUT_BLOCK_SPATIAL * SIMD + sglid; + __attribute__((opencl_unroll_hint)) + for (uint wg = 0; wg < FEATURE_SLM_SPLIT - 1; ++wg) { + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < OUT_BLOCK_SPATIAL; ++os) { + const uint partial_acc_ptr_idx = + wg * OUT_BLOCK_FEATURES * SIMD * OUT_BLOCK_SPATIAL + + os * SIMD; + dotProd[0][os] += partial_acc_ptr[partial_acc_ptr_idx]; + } } #endif -#if FEATURE_LWS_SPLIT == 1 -# define OUTPUT_FEATURES_PER_WI 2 -# if BIAS_TERM - BIAS_TYPE bias[OUTPUT_FEATURES_PER_WI] = { biases[out_f], biases[out_f + 16] }; -# endif +#if FEATURE_SLM_SPLIT == 1 +# define FINAL_OUT_BLOCK_FEATURES (OUT_BLOCK_FEATURES) #else -# define OUTPUT_FEATURES_PER_WI 1 -# if BIAS_TERM - BIAS_TYPE bias[OUTPUT_FEATURES_PER_WI] = { biases[out_f + subgroup_id * 16] }; -# endif -#endif +# define FINAL_OUT_BLOCK_FEATURES 1 + out_f += get_sub_group_id() * SIMD; + out_fg += get_sub_group_id() * SIMD; - for (uint j = 0; j < OUTPUT_FEATURES_PER_WI; j++) { - uint out_f_offset = subgroup_id * 16 + j * 16; - -#if OUTPUT_FEATURE_NUM % 32 != 0 && OUTPUT_FEATURE_NUM % 32 <= 16 - if (out_fg + 32 > OUTPUT_FEATURE_NUM && out_f_offset >= OUTPUT_FEATURE_NUM % 32) - break; + if (CEIL_DIV(OUTPUT_FEATURE_NUM, SIMD) % OUT_BLOCK_FEATURES != 0 && out_fg >= OUTPUT_FEATURE_NUM) + return; #endif - const uint dst_index = GET_DATA_B_FS_YX_FSV16_INDEX(OUTPUT, out_b, out_f + out_f_offset, out_y, out_x); -#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD - FUSED_OPS_PRELOAD +#if BIAS_TERM + // Preload bias + BIAS_TYPE bias_val[FINAL_OUT_BLOCK_FEATURES]; + for (uint ofb = 0; ofb < FINAL_OUT_BLOCK_FEATURES; ++ofb) { + bias_val[ofb] = biases[out_f + ofb * SIMD]; + } #endif - __attribute__((opencl_unroll_hint(OUT_BLOCK_WIDTH))) - for (uint i = 0; i < OUT_BLOCK_WIDTH; i++) { -#if OUTPUT_SIZE_X % OUT_BLOCK_WIDTH != 0 - if (out_x + OUT_BLOCK_WIDTH > OUTPUT_SIZE_X && i >= OUTPUT_SIZE_X % OUT_BLOCK_WIDTH) - break; -#endif - ACTIVATION_TYPE dequantized = (ACTIVATION_TYPE)0; + // Convert accumulator type to activation type + ACTIVATION_TYPE dequantized[FINAL_OUT_BLOCK_FEATURES][OUT_BLOCK_SPATIAL]; + __attribute__((opencl_unroll_hint)) + for (uint ofb = 0; ofb < 
FINAL_OUT_BLOCK_FEATURES; ++ofb) { + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < OUT_BLOCK_SPATIAL; ++os) { + dequantized[ofb][os] = TO_ACTIVATION_TYPE(dotProd[ofb][os]); + #if BIAS_TERM - dequantized = (ACTIVATION_TYPE)dotProd[OUT_BLOCK_WIDTH * j + i] + bias[j]; -#else - dequantized = (ACTIVATION_TYPE)dotProd[OUT_BLOCK_WIDTH * j + i]; + dequantized[ofb][os] += TO_ACTIVATION_TYPE(bias_val[ofb]); +#endif + } + } + + // Fused ops/activation + OUTPUT_TYPE result[FINAL_OUT_BLOCK_FEATURES][OUT_BLOCK_SPATIAL]; + __attribute__((opencl_unroll_hint)) + for (uint ofb = 0; ofb < FINAL_OUT_BLOCK_FEATURES; ++ofb) { +#if HAS_FUSED_OPS && FUSED_OPS_CAN_USE_PRELOAD_SCALAR + FUSED_OPS_PRELOAD_SCALAR; #endif - OUTPUT_TYPE result; + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < OUT_BLOCK_SPATIAL; ++os) { #if HAS_FUSED_OPS - #if FUSED_OPS_CAN_USE_PRELOAD - FUSED_OPS_CALC - #else - FUSED_OPS - #endif - result = FUSED_OPS_RESULT; + #if FUSED_OPS_CAN_USE_PRELOAD_SCALAR + FUSED_OPS_CALC_SCALAR; + #else + FUSED_OPS_SCALAR; + #endif + result[ofb][os] = FUSED_OPS_RESULT_SCALAR; #else - result = TO_OUTPUT_TYPE(dequantized); + result[ofb][os] = TO_OUTPUT_TYPE(ACTIVATION(dequantized[ofb][os], ACTIVATION_PARAMS)); #endif + } + } + + // Store output + // Check if can use block writes + bool only_x_block = OUTPUT_SIZE_X % OUT_BLOCK_SPATIAL == 0; + bool at_least_one_x_block = OUTPUT_SIZE_X >= OUT_BLOCK_SPATIAL; + bool full_x = out_yx_sg % OUTPUT_SIZE_X <= OUTPUT_SIZE_X - OUT_BLOCK_SPATIAL; + bool can_write_x = only_x_block || (at_least_one_x_block && full_x); -#if OUTPUT_FEATURE_NUM % 16 != 0 - if (out_fg + out_f_offset + 16 > OUTPUT_FEATURE_NUM && subgroup_local_id >= OUTPUT_FEATURE_NUM % 16) - result = (OUTPUT_TYPE)0; + bool no_x_pad = OUTPUT_PAD_BEFORE_SIZE_X == 0 && OUTPUT_PAD_AFTER_SIZE_X == 0; + bool exact_spatial = max_out_yx % OUT_BLOCK_SPATIAL == 0; + bool full_spatial = out_yx_sg <= max_out_yx - OUT_BLOCK_SPATIAL; + bool can_write_spatial = no_x_pad && (exact_spatial || full_spatial); + + bool full_feature_block = (OUTPUT_FEATURE_NUM % SIMD == 0) || (out_fg + FINAL_OUT_BLOCK_FEATURES * SIMD <= OUTPUT_FEATURE_NUM); + + bool can_use_full_block_write = full_feature_block && (can_write_x || can_write_spatial); + if (can_use_full_block_write) { + uint output_idx = OUTPUT_GET_INDEX(out_b, + out_fg, + intel_sub_group_shuffle(out_y_shuffle[0], 0), + intel_sub_group_shuffle(out_x_shuffle[0], 0)); + __attribute__((opencl_unroll_hint)) + for (uint ofb = 0; ofb < FINAL_OUT_BLOCK_FEATURES; ++ofb) { + bool good_of_block = (CEIL_DIV(OUTPUT_FEATURE_NUM, SIMD) % FINAL_OUT_BLOCK_FEATURES == 0) + || (out_fg + FINAL_OUT_BLOCK_FEATURES * SIMD <= OUTPUT_FEATURE_NUM) + || (ofb < CEIL_DIV(OUTPUT_FEATURE_NUM, SIMD) % FINAL_OUT_BLOCK_FEATURES); + if (good_of_block) { + uint os = 0; +#if OUTPUT_TYPE_SIZE == 1 + for (; os + 8 <= OUT_BLOCK_SPATIAL; os += 8) { + MAKE_VECTOR_TYPE(OUTPUT_TYPE, 8) result_val; + __attribute__((opencl_unroll_hint)) + for (uint i = 0; i < 8; ++i) { + result_val[i] = result[ofb][os + i]; + } + DT_OUTPUT_BLOCK_WRITE8(output, output_idx, result_val); + output_idx += 8 * SIMD; + } +#endif +#if OUTPUT_TYPE_SIZE <= 2 + for (; os + 4 <= OUT_BLOCK_SPATIAL; os += 4) { + MAKE_VECTOR_TYPE(OUTPUT_TYPE, 4) result_val; + __attribute__((opencl_unroll_hint)) + for (uint i = 0; i < 4; ++i) { + result_val[i] = result[ofb][os + i]; + } + DT_OUTPUT_BLOCK_WRITE4(output, output_idx, result_val); + output_idx += 4 * SIMD; + } #endif - output[dst_index + i * 16] = result; + for (; os + 2 <= 
OUT_BLOCK_SPATIAL; os += 2) { + MAKE_VECTOR_TYPE(OUTPUT_TYPE, 2) result_val; + __attribute__((opencl_unroll_hint)) + for (uint i = 0; i < 2; ++i) { + result_val[i] = result[ofb][os + i]; + } + DT_OUTPUT_BLOCK_WRITE2(output, output_idx, result_val); + output_idx += 2 * SIMD; + } + if (OUT_BLOCK_SPATIAL % 2 == 1) { + OUTPUT_TYPE result_val = result[ofb][os]; + DT_OUTPUT_BLOCK_WRITE(output, output_idx, result_val); + output_idx += 1 * SIMD; + } + } + output_idx += OUTPUT_FEATURE_PITCH * FSV - OUT_BLOCK_SPATIAL * SIMD; + } + } else { + uint output_idx_shuffle[CEIL_DIV(OUT_BLOCK_SPATIAL, SIMD)] = { }; + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < CEIL_DIV(OUT_BLOCK_SPATIAL, SIMD); ++os) { + output_idx_shuffle[os] = OUTPUT_GET_INDEX(out_b, out_fg, out_y_shuffle[os], out_x_shuffle[os]); + } + __attribute__((opencl_unroll_hint)) + for (uint ofb = 0; ofb < FINAL_OUT_BLOCK_FEATURES; ++ofb) { + bool good_of_block = (CEIL_DIV(OUTPUT_FEATURE_NUM, SIMD) % FINAL_OUT_BLOCK_FEATURES == 0) + || (out_fg + FINAL_OUT_BLOCK_FEATURES * SIMD <= OUTPUT_FEATURE_NUM) + || (ofb < CEIL_DIV(OUTPUT_FEATURE_NUM, SIMD) % FINAL_OUT_BLOCK_FEATURES); + if (good_of_block) { + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < OUT_BLOCK_SPATIAL; ++os) { + bool good_os = (max_out_yx % OUT_BLOCK_SPATIAL == 0) || (out_yx_sg <= max_out_yx - OUT_BLOCK_SPATIAL) || (os < max_out_yx % OUT_BLOCK_SPATIAL); + if (!good_os) + break; + + uint output_idx = intel_sub_group_shuffle(output_idx_shuffle[os / SIMD], os % SIMD); + bool good_of = (OUTPUT_FEATURE_NUM % SIMD == 0) || (out_f + ofb * SIMD < OUTPUT_FEATURE_NUM); + + if (!good_of) + result[ofb][os] = (OUTPUT_TYPE)0; + + output[output_idx + sglid] = result[ofb][os]; + } + } + + __attribute__((opencl_unroll_hint)) + for (uint os = 0; os < CEIL_DIV(OUT_BLOCK_SPATIAL, SIMD); ++os) { + output_idx_shuffle[os] += OUTPUT_FEATURE_PITCH * FSV; + } } } -#undef OUTPUT_FEATURES_PER_WI +#undef FINAL_OUT_BLOCK_FEATURES } #undef AS_INPUT0_TYPE_4 +#undef AS_FILTER_TYPE_4 #undef AS_TYPE_N #undef AS_TYPE_N_ -#undef MAKE_VECTOR_TYPE -#undef TO_ACTIVATION_TYPE -#undef ACTIVATION_TYPE -#undef TO_ACCUMULATOR_TYPE -#undef ACCUMULATOR_TYPE #undef CEIL_DIV #undef ALIGN + +#undef FSV +#undef SIMD diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/include/fetch.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/include/fetch.cl index f51dcb7a39594d..e48227fefcb1e7 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/include/fetch.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/include/fetch.cl @@ -281,32 +281,65 @@ inline uint FUNC(get_b_fs_yx_fsv_index_safe)(uint b, uint f, uint y, uint x, CAT(prefix, _OFFSET) \ ) -#define GET_FILTER_OS_IS_YX_OSV16_ISV16_INDEX(prefix, o, i, y, x) \ - FUNC_CALL(get_os_is_yx_osv16_isv16_index)( \ - o, i, y, x, \ - CAT(prefix, _SIZE_X), \ - CAT(prefix, _SIZE_Y), \ - CAT(prefix, _IFM_NUM), \ - CAT(prefix, _OFM_NUM)) - -inline uint FUNC(get_os_is_yx_osv16_isv16_index)(uint o, uint i, uint y, uint x, - uint x_size, uint y_size, uint i_size, uint o_size) +inline uint FUNC(get_os_is_zyx_osv_isv_index)(uint o, uint i, uint z, uint y, uint x, + uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size) { - const uint isv = i % 16; - const uint osv = o % 16; - const uint is = i / 16; - const uint os = o / 16; + const uint isv = i % isv_size; + const uint osv = o % osv_size; + const uint is = i / isv_size; + const uint os = o / 
osv_size; - const uint x_pitch = 16 * 16; + const uint x_pitch = osv_size * isv_size; const uint y_pitch = x_pitch * x_size; - const uint is_pitch = y_pitch * y_size; - const uint os_pitch = is_pitch * ((i_size + 16 - 1) / 16); - - const uint output_offset = isv + osv * 16 + x * x_pitch + y * y_pitch + is * is_pitch + os * os_pitch; + const uint z_pitch = y_pitch * y_size; + const uint is_pitch = z_pitch * z_size; + const uint os_pitch = is_pitch * ((i_size + isv_size - 1) / isv_size); + + const uint output_offset = + isv + + osv * isv_size + + x * x_pitch + + y * y_pitch + + z * z_pitch + + is * is_pitch + + os * os_pitch; return output_offset; } +#define GET_FILTER_OS_IS_YX_OSV16_ISV16_INDEX(prefix, o, i, y, x) \ + FUNC_CALL(get_os_is_zyx_osv_isv_index)( \ + o, i, 0, y, x, \ + CAT(prefix, _SIZE_X), \ + CAT(prefix, _SIZE_Y), \ + 1, \ + CAT(prefix, _IFM_NUM), \ + CAT(prefix, _OFM_NUM), \ + 16, \ + 16) + +#define GET_FILTER_OS_IS_ZYX_OSV32_ISV16_INDEX(prefix, o, i, z, y, x) \ + FUNC_CALL(get_os_is_zyx_osv_isv_index)( \ + o, i, z, y, x, \ + CAT(prefix, _SIZE_X), \ + CAT(prefix, _SIZE_Y), \ + CAT(prefix, _SIZE_Z), \ + CAT(prefix, _IFM_NUM), \ + CAT(prefix, _OFM_NUM), \ + 32, \ + 16) + +#define GET_FILTER_OS_IS_ZYX_OSV64_ISV16_INDEX(prefix, o, i, z, y, x) \ + FUNC_CALL(get_os_is_zyx_osv_isv_index)( \ + o, i, z, y, x, \ + CAT(prefix, _SIZE_X), \ + CAT(prefix, _SIZE_Y), \ + CAT(prefix, _SIZE_Z), \ + CAT(prefix, _IFM_NUM), \ + CAT(prefix, _OFM_NUM), \ + 64, \ + 16) + #define GET_FILTER_G_OS_IS_YX_ISV8_OSV16_ISV2_INDEX(prefix, g, o, i, y, x, sub_group_size) \ FUNC_CALL(get_os_is_zyx_isv8_osv16_isv2_index)( \ g, o, i, 0, y, x, \ diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/reorder_weights.cl b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/reorder_weights.cl index 5bdb29ecde971b..2fdfaf94568885 100644 --- a/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/reorder_weights.cl +++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/cl_kernels/reorder_weights.cl @@ -93,6 +93,10 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x return GET_FILTER_GOIYX(INPUT0, g, o, i, y, x); #elif defined INPUT0_LAYOUT_OS_IS_YX_OSV16_ISV16 return GET_FILTER_OS_IS_YX_OSV16_ISV16_INDEX(INPUT0, o, i, y, x); +#elif defined INPUT0_LAYOUT_OS_IS_ZYX_OSV32_ISV16 + return GET_FILTER_OS_IS_ZYX_OSV32_ISV16_INDEX(INPUT0, o, i, z, y, x); +#elif defined INPUT0_LAYOUT_OS_IS_ZYX_OSV64_ISV16 + return GET_FILTER_OS_IS_ZYX_OSV64_ISV16_INDEX(INPUT0, o, i, z, y, x); #elif defined INPUT0_LAYOUT_GS_OI_YXS_GSV16_YXSV4 return GET_FILTER_GS_OI_YXS_GSV16_YXSV4_INDEX(INPUT0, g, o, i, y, x); #elif defined INPUT0_LAYOUT_GS_OI_YXS_GSV32_YXSV4 @@ -220,6 +224,10 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint return GET_FILTER_G_OS_IS_YX_ISV16_OSV16_INDEX(OUTPUT, g, o, i, y, x, SUB_GROUP_SIZE); #elif defined OUTPUT_LAYOUT_OS_IS_YX_OSV16_ISV16 return GET_FILTER_OS_IS_YX_OSV16_ISV16_INDEX(OUTPUT, o, i, y, x); +#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSV32_ISV16 + return GET_FILTER_OS_IS_ZYX_OSV32_ISV16_INDEX(OUTPUT, o, i, z, y, x); +#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSV64_ISV16 + return GET_FILTER_OS_IS_ZYX_OSV64_ISV16_INDEX(OUTPUT, o, i, z, y, x); #elif defined OUTPUT_LAYOUT_GS_OI_YXS_GSV16_YXSV4 return GET_FILTER_GS_OI_YXS_GSV16_YXSV4_INDEX(OUTPUT, g, o, i, y, x); #elif defined OUTPUT_LAYOUT_GS_OI_YXS_GSV32_YXSV4 diff --git a/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp 
b/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp
index 04d72b396eee9f..d2539c19171d72 100644
--- a/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp
+++ b/inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp
@@ -307,6 +307,8 @@ std::string toString(WeightsLayout layout) {
         case WeightsLayout::yxio: return "YXIO";
         case WeightsLayout::os_is_yx_isv16_osv16: return "OS_IS_YX_ISV16_OSV16";
         case WeightsLayout::os_is_yx_osv16_isv16: return "OS_IS_YX_OSV16_ISV16";
+        case WeightsLayout::os_is_zyx_osv32_isv16: return "OS_IS_ZYX_OSV32_ISV16";
+        case WeightsLayout::os_is_zyx_osv64_isv16: return "OS_IS_ZYX_OSV64_ISV16";
         case WeightsLayout::os_iyx_osv16: return "OS_IYX_OSV16";
         case WeightsLayout::os_iyx_osv32: return "OS_IYX_OSV32";
         case WeightsLayout::os_iyx_osv32__ai32: return "OS_IYX_OSV32__AI32";
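
Notes on the tuning heuristics introduced by this patch (illustrative, not part of the diff):

The auto-tune logic revolves around a simple occupancy model. Below is a minimal standalone sketch of what the patch's EstimateOccupancy computes; all names are illustrative stand-ins, not the actual kernel-selector API, and the 24-EU device in the example is hypothetical.

#include <cstddef>
using std::size_t;

static size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

// Mirrors the patch's EstimateOccupancy: one hardware thread per 16-wide
// sub-group, with ~7 resident threads per EU counted as full occupancy.
static float estimate_occupancy(size_t out_x, size_t out_y, size_t out_f, size_t batch,
                                size_t block_spatial, size_t block_features,
                                size_t feature_slm_split, size_t compute_units) {
    const size_t simd = 16;
    const size_t max_threads_per_eu = 7;
    size_t blocks_s = ceil_div(out_x * out_y, block_spatial);
    size_t blocks_f = ceil_div(out_f, block_features * simd) * feature_slm_split;
    size_t threads = blocks_s * blocks_f * batch;
    return static_cast<float>(threads) / static_cast<float>(compute_units * max_threads_per_eu);
}

int main() {
    // Hypothetical 24-EU device, 14x14x256 output, batch 1, with
    // block_spatial = 8, block_features = 2, no feature split:
    // ceil(196/8) * ceil(256/32) * 1 = 200 threads vs 24 * 7 = 168 -> ~1.19.
    return estimate_occupancy(14, 14, 256, 1, 8, 2, 1, 24) >= 1.f ? 0 : 1;
}

A ratio of at least 1.0 is treated as enough parallelism: GetAutoTuneParams stops growing the spatial block once occupancy would fall below 1, only considers the SLM feature split when it cannot reach 1, and as a last resort shrinks the spatial block against the 5/7 (or 4/7 when split) thresholds.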
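
The feature split is additionally bounded by local memory: every sub-group in the split except one parks its partial accumulators in SLM for the final reduction, which is what EstimateSLMUsage models. A sketch under the same assumptions (int32 accumulators, as on the quantized path; names are again illustrative):

#include <cstddef>
#include <cstdint>
using std::size_t;

// Mirrors the patch's EstimateSLMUsage: local-memory demand scales with
// (feature_slm_split - 1), since only the extra sub-groups spill partials.
static float estimate_slm_usage(size_t block_spatial, size_t block_features,
                                size_t feature_slm_split, size_t max_local_mem_bytes) {
    const size_t fsv = 16;  // feature slice of the b_fs_yx_fsv16 layout
    size_t slm_elements = block_spatial * block_features * fsv * (feature_slm_split - 1);
    size_t slm_bytes = slm_elements * sizeof(std::int32_t);
    return static_cast<float>(slm_bytes) / static_cast<float>(max_local_mem_bytes);
}

For example, block_spatial = 16, block_features = 2, feature_slm_split = 4 needs 16 * 2 * 16 * 3 * 4 = 6144 bytes, comfortably below a typical 64 KiB maxLocalMemSize. ValidateAutoTuneParams rejects any point where this ratio exceeds 1 or where feature_slm_split * simd exceeds maxWorkGroupSize.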
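
The new os_is_zyx_osv32_isv16 and os_is_zyx_osv64_isv16 layouts exist so the kernel can step through OUT_BLOCK_FEATURES sub-group-sized chunks of output features at a fixed pitch. The sketch below is a C++ port of fetch.cl's get_os_is_zyx_osv_isv_index, for illustration only; the function name is mine.

#include <cstddef>
using std::size_t;

static size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

// Weights are tiled into blocks of osv output features x isv input features,
// with isv fastest, then osv, then x, y, z, then input and output blocks.
static size_t os_is_zyx_osv_isv_offset(size_t o, size_t i, size_t z, size_t y, size_t x,
                                       size_t x_size, size_t y_size, size_t z_size,
                                       size_t i_size, size_t osv, size_t isv) {
    size_t x_pitch  = osv * isv;
    size_t y_pitch  = x_pitch * x_size;
    size_t z_pitch  = y_pitch * y_size;
    size_t is_pitch = z_pitch * z_size;
    size_t os_pitch = is_pitch * ceil_div(i_size, isv);
    return (i % isv)
         + (o % osv) * isv
         + x * x_pitch
         + y * y_pitch
         + z * z_pitch
         + (i / isv) * is_pitch
         + (o / osv) * os_pitch;
}

With osv = 32, isv = 16 and a 1x1x1 filter, moving from output feature o to o + 16 (o < 16) advances the offset by 16 * 16 = 256 elements, matching the kernel's WEIGHTS_FEATURE_BLOCK_PITCH of FSV * FSV, and stepping one input slice advances by is_pitch = 2 * FSV * FSV elements, matching WEIGHTS_IS_PITCH.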