Skip to content

Commit

Permalink
[IE CLDNN] Optimize 1x1 imad convolution kernel (#757)
Browse files Browse the repository at this point in the history
  • Loading branch information
Konrad Dobros authored Jun 8, 2020
1 parent 626bc4f commit d155483
Show file tree
Hide file tree
Showing 8 changed files with 624 additions and 250 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
{ WeightsLayout::os_i_osv16__ai8, { -1, -1, -1, 0, 1, -1, -1, -1 } },
{ WeightsLayout::os_i_osv16, { -1, -1, -1, 0, 1, -1, -1, -1 } },
{ WeightsLayout::os_is_yx_osv16_isv16, { 0, 1, -1, 2, 3, -1, -1, -1 } },
{ WeightsLayout::os_is_zyx_osv32_isv16, { 0, 1, 2, 3, 4, -1, -1, -1 } },
{ WeightsLayout::os_is_zyx_osv64_isv16, { 0, 1, 2, 3, 4, -1, -1, -1 } },
{ WeightsLayout::i_yxs_os_yxsv2_osv16, { 1, 2, -1, 3, 0, -1, -1, -1 } },
{ WeightsLayout::iy_xs_os_xsv2_osv16__ao32, { 1, 2, -1, 3, 0, -1, -1, -1 } },
{ WeightsLayout::iy_xs_os_xsv2_osv8__ao32, { 1, 2, -1, 3, 0, -1, -1, -1 } },
Expand Down Expand Up @@ -633,6 +635,14 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
newDims[2] = RoundUp(newDims[2], 16);
newDims[3] = RoundUp(newDims[3], 16);
break;
case os_is_zyx_osv32_isv16:
newDims[3] = RoundUp(newDims[3], 16);
newDims[4] = RoundUp(newDims[4], 32);
break;
case os_is_zyx_osv64_isv16:
newDims[3] = RoundUp(newDims[3], 16);
newDims[4] = RoundUp(newDims[4], 64);
break;
case gs_oi_yxs_gsv16_yxsv4:
newDims[4] = RoundUp(newDims[4], 16);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ enum WeightsLayout {
os_i_osv16__ai8,
os_i_osv16,
os_is_yx_osv16_isv16, // wieghts for int8 blocked conv
os_is_zyx_osv32_isv16,
os_is_zyx_osv64_isv16,
i_yxs_os_yxsv2_osv16,
iy_xs_os_xsv2_osv16__ao32,
iy_xs_os_xsv2_osv8__ao32,
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ class Convolution_kernel_b_fs_yx_fsv16_imad_1x1 : public ConvolutionKernelBase {
JitConstants GetJitConstants(const convolution_params& params, const DispatchData& kd) const override;
DispatchData SetDefault(const convolution_params& params, int autoTuneIndex = -1) const override;
bool NeedPaddedInput() const override { return true; }
WeightsLayout GetPreferredWeightsLayout(const convolution_params&) const override {
return WeightsLayout::os_is_yx_osv16_isv16;
}
WeightsLayout GetPreferredWeightsLayout(const convolution_params&) const override;

std::vector<FusedOpType> GetSupportedFusedOps() const override {
return { FusedOpType::ELTWISE,
Expand All @@ -50,13 +48,17 @@ class Convolution_kernel_b_fs_yx_fsv16_imad_1x1 : public ConvolutionKernelBase {
}

struct AutoTuneParams {
size_t out_block_width;
bool k_slicing;
size_t out_block_spatial;
size_t out_block_features;
size_t feature_slm_split;
std::string exe_mode;
};
std::vector<AutoTuneParams> all_tune_params;

bool ValidateAutoTuneParams(const convolution_params& params, const AutoTuneParams& tune_params) const;
AutoTuneParams GetAutoTuneParams(const convolution_params& params, int index) const;

float EstimateOccupancy(const convolution_params& params, const AutoTuneParams& tune) const;
float EstimateSLMUsage(const convolution_params& params, const AutoTuneParams& tune) const;
};
} // namespace kernel_selector

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -281,32 +281,65 @@ inline uint FUNC(get_b_fs_yx_fsv_index_safe)(uint b, uint f, uint y, uint x,
CAT(prefix, _OFFSET) \
)

#define GET_FILTER_OS_IS_YX_OSV16_ISV16_INDEX(prefix, o, i, y, x) \
FUNC_CALL(get_os_is_yx_osv16_isv16_index)( \
o, i, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _IFM_NUM), \
CAT(prefix, _OFM_NUM))

inline uint FUNC(get_os_is_yx_osv16_isv16_index)(uint o, uint i, uint y, uint x,
uint x_size, uint y_size, uint i_size, uint o_size)
inline uint FUNC(get_os_is_zyx_osv_isv_index)(uint o, uint i, uint z, uint y, uint x,
uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size)
{
const uint isv = i % 16;
const uint osv = o % 16;
const uint is = i / 16;
const uint os = o / 16;
const uint isv = i % isv_size;
const uint osv = o % osv_size;
const uint is = i / isv_size;
const uint os = o / osv_size;

const uint x_pitch = 16 * 16;
const uint x_pitch = osv_size * isv_size;
const uint y_pitch = x_pitch * x_size;
const uint is_pitch = y_pitch * y_size;
const uint os_pitch = is_pitch * ((i_size + 16 - 1) / 16);

const uint output_offset = isv + osv * 16 + x * x_pitch + y * y_pitch + is * is_pitch + os * os_pitch;
const uint z_pitch = y_pitch * y_size;
const uint is_pitch = z_pitch * z_size;
const uint os_pitch = is_pitch * ((i_size + isv_size - 1) / isv_size);

const uint output_offset =
isv +
osv * isv_size +
x * x_pitch +
y * y_pitch +
z * z_pitch +
is * is_pitch +
os * os_pitch;

return output_offset;
}

#define GET_FILTER_OS_IS_YX_OSV16_ISV16_INDEX(prefix, o, i, y, x) \
FUNC_CALL(get_os_is_zyx_osv_isv_index)( \
o, i, 0, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
1, \
CAT(prefix, _IFM_NUM), \
CAT(prefix, _OFM_NUM), \
16, \
16)

#define GET_FILTER_OS_IS_ZYX_OSV32_ISV16_INDEX(prefix, o, i, z, y, x) \
FUNC_CALL(get_os_is_zyx_osv_isv_index)( \
o, i, z, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _SIZE_Z), \
CAT(prefix, _IFM_NUM), \
CAT(prefix, _OFM_NUM), \
32, \
16)

#define GET_FILTER_OS_IS_ZYX_OSV64_ISV16_INDEX(prefix, o, i, z, y, x) \
FUNC_CALL(get_os_is_zyx_osv_isv_index)( \
o, i, z, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _SIZE_Z), \
CAT(prefix, _IFM_NUM), \
CAT(prefix, _OFM_NUM), \
64, \
16)

#define GET_FILTER_G_OS_IS_YX_ISV8_OSV16_ISV2_INDEX(prefix, g, o, i, y, x, sub_group_size) \
FUNC_CALL(get_os_is_zyx_isv8_osv16_isv2_index)( \
g, o, i, 0, y, x, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
return GET_FILTER_GOIYX(INPUT0, g, o, i, y, x);
#elif defined INPUT0_LAYOUT_OS_IS_YX_OSV16_ISV16
return GET_FILTER_OS_IS_YX_OSV16_ISV16_INDEX(INPUT0, o, i, y, x);
#elif defined INPUT0_LAYOUT_OS_IS_ZYX_OSV32_ISV16
return GET_FILTER_OS_IS_ZYX_OSV32_ISV16_INDEX(INPUT0, o, i, z, y, x);
#elif defined INPUT0_LAYOUT_OS_IS_ZYX_OSV64_ISV16
return GET_FILTER_OS_IS_ZYX_OSV64_ISV16_INDEX(INPUT0, o, i, z, y, x);
#elif defined INPUT0_LAYOUT_GS_OI_YXS_GSV16_YXSV4
return GET_FILTER_GS_OI_YXS_GSV16_YXSV4_INDEX(INPUT0, g, o, i, y, x);
#elif defined INPUT0_LAYOUT_GS_OI_YXS_GSV32_YXSV4
Expand Down Expand Up @@ -220,6 +224,10 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
return GET_FILTER_G_OS_IS_YX_ISV16_OSV16_INDEX(OUTPUT, g, o, i, y, x, SUB_GROUP_SIZE);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_OSV16_ISV16
return GET_FILTER_OS_IS_YX_OSV16_ISV16_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSV32_ISV16
return GET_FILTER_OS_IS_ZYX_OSV32_ISV16_INDEX(OUTPUT, o, i, z, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSV64_ISV16
return GET_FILTER_OS_IS_ZYX_OSV64_ISV16_INDEX(OUTPUT, o, i, z, y, x);
#elif defined OUTPUT_LAYOUT_GS_OI_YXS_GSV16_YXSV4
return GET_FILTER_GS_OI_YXS_GSV16_YXSV4_INDEX(OUTPUT, g, o, i, y, x);
#elif defined OUTPUT_LAYOUT_GS_OI_YXS_GSV32_YXSV4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ std::string toString(WeightsLayout layout) {
case WeightsLayout::yxio: return "YXIO";
case WeightsLayout::os_is_yx_isv16_osv16: return "OS_IS_YX_ISV16_OSV16";
case WeightsLayout::os_is_yx_osv16_isv16: return "OS_IS_YX_OSV16_ISV16";
case WeightsLayout::os_is_zyx_osv32_isv16: return "OS_IS_ZYX_OSV32_ISV16";
case WeightsLayout::os_is_zyx_osv64_isv16: return "OS_IS_ZYX_OSV64_ISV16";
case WeightsLayout::os_iyx_osv16: return "OS_IYX_OSV16";
case WeightsLayout::os_iyx_osv32: return "OS_IYX_OSV32";
case WeightsLayout::os_iyx_osv32__ai32: return "OS_IYX_OSV32__AI32";
Expand Down

0 comments on commit d155483

Please sign in to comment.