Skip to content

Commit

Permalink
[IE CLDNN] Allow fusing FQ to deconvolution (openvinotoolkit#2875)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-chaiko authored and mryzhov committed Dec 15, 2020
1 parent e718172 commit 31ac9ac
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 595 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ ParamsKey DeconvolutionKernel_b_fs_zyx_fsv16::GetSupportedKey() const {
k.EnableInputWeightsType(WeightsType::F32);
k.EnableInputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputWeightsType(WeightsType::F16);
k.EnableInputLayout(DataLayout::b_fs_yx_fsv16);
k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16);
Expand All @@ -44,6 +46,7 @@ ParamsKey DeconvolutionKernel_b_fs_zyx_fsv16::GetSupportedKey() const {
k.EnableBatching();
k.EnableSubGroup();
k.EnableSubGroupShort();
k.EnableDifferentTypes();
return k;
}

Expand Down Expand Up @@ -155,10 +158,11 @@ JitConstants DeconvolutionKernel_b_fs_zyx_fsv16::GetJitConstants(const deconvolu
}
jit.AddConstant(MakeJitConstant("OC_BLOCK", 16));

if (output.GetDType() == Datatype::F32)
if (input.GetDType() == Datatype::F32) {
jit.AddConstant(MakeJitConstant("DT_F32", 1));
else
} else {
jit.AddConstant(MakeJitConstant("DT_F16", 1));
}

auto mb_block = 1;
auto ic_block = 16;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ ParamsKey DeconvolutionKernel_bfyx_opt::GetSupportedKey() const {
k.EnableInputWeightsType(WeightsType::F32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableTensorOffset();
Expand All @@ -36,6 +38,7 @@ ParamsKey DeconvolutionKernel_bfyx_opt::GetSupportedKey() const {
k.EnableSplitSupport();
k.EnableDepthwiseSeparableOpt();
k.EnableGroupedConvolution();
k.EnableDifferentTypes();
return k;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019 Intel Corporation
* Copyright 2019-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,10 +14,21 @@
* limitations under the License.
*******************************************************************************/

#include "ocl_types.h"
#include "include/fetch.cl"
#include "include/data_types.cl"

#define INPUT_TYPE8 MAKE_VECTOR_TYPE(INPUT0_TYPE, 8)
#define OUTPUT_TYPE8 MAKE_VECTOR_TYPE(OUTPUT_TYPE, 8)
#define FILTER_TYPE8 MAKE_VECTOR_TYPE(FILTER_TYPE, 8)

#if DT_F16 == 1
#define FMA_ARG_TYPE half
#define FMA_ARG_TYPE8 half8
#else
#define FMA_ARG_TYPE INPUT0_TYPE
#define FMA_ARG_TYPE8 INPUT_TYPE8
#endif

#if ID > 1
#define CASE_3D 1
#else
Expand All @@ -31,11 +42,11 @@ __attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) // attr:no-format
__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) // attr:no-format
#endif
KERNEL(gen9_common_conv_bwd_data_kernel)(
const __global DATA_T *diff_dst,
__global DATA_T * restrict diff_src,
const __global DATA_T *wei,
const __global INPUT0_TYPE *diff_dst,
__global OUTPUT_TYPE * restrict diff_src,
const __global FILTER_TYPE *wei,
#if WITH_BIAS
const __global DATA_T *bias,
const __global BIAS_TYPE *bias,
#endif
#if HAS_FUSED_OPS_DECLS
FUSED_OPS_DECLS,
Expand Down Expand Up @@ -76,11 +87,11 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
diff_dst += input_offset + mb * OC_FULL * G * OD_FULL * OH_FULL * OW_FULL + g * OC * OD_FULL * OH_FULL * OW_FULL * MB_BLOCK;

#if WITH_BIAS
DATA8_T blockC00 = (DATA8_T)bias[g * IC + gic * IC_BLOCK + local_id];
DATA8_T blockC01 = (DATA8_T)bias[g * IC + gic * IC_BLOCK + local_id];
INPUT_TYPE8 blockC00 = (INPUT_TYPE8)bias[g * IC + gic * IC_BLOCK + local_id];
INPUT_TYPE8 blockC01 = (INPUT_TYPE8)bias[g * IC + gic * IC_BLOCK + local_id];
#else
DATA8_T blockC00 = 0.0f;
DATA8_T blockC01 = 0.0f;
INPUT_TYPE8 blockC00 = INPUT0_VAL_ZERO;
INPUT_TYPE8 blockC01 = INPUT0_VAL_ZERO;
#endif

wei += gic * KD * KH * KW * OC_BLOCK * IC_BLOCK
Expand Down Expand Up @@ -111,13 +122,13 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
#endif
if (oh >= OH || ow >= OW) continue;

const __global DATA_T *diff_dst1 = diff_dst
const __global INPUT0_TYPE *diff_dst1 = diff_dst
+ ow * OC_BLOCK * MB_BLOCK
+ oh * OW_FULL * OC_BLOCK * MB_BLOCK;
#if CASE_3D
diff_dst1 += od * OH_FULL * OW_FULL * OC_BLOCK * MB_BLOCK;
#endif
const __global DATA_T *wei1 = wei
const __global FILTER_TYPE *wei1 = wei
#if CASE_3D
+ kd * KH * KW * OC_BLOCK * IC_BLOCK
#endif
Expand Down Expand Up @@ -148,44 +159,30 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
#if SW != 1 || SH != 1 || SD != 1 || PH != 0 || PW != 0 || PD != 0
if (do_ker) {
#endif
const __global DATA_T *diff_dst1 = diff_dst
const __global INPUT0_TYPE *diff_dst1 = diff_dst
+ ow * OC_BLOCK * MB_BLOCK + oh * OW_FULL * OC_BLOCK * MB_BLOCK;
#if CASE_3D
diff_dst1 += od * OH_FULL * OW_FULL * OC_BLOCK * MB_BLOCK;
#endif
const __global DATA_T *wei1 = wei;
const __global FILTER_TYPE *wei1 = wei;
#endif

#define LOAD_DIFF_DST(_block, _diff_dst, mb_chunk) \
{ \
(_block) = AS_DATA8_T( \
BLOCK_READ8((const __global BLOCK_DATA_T *)((_diff_dst) \
+ (mb_chunk)*OC_BLOCK))); \
}

#define SAVE_SRC_DIFF(_block, _diff_src, mb_chunk) \
{ \
BLOCK_WRITE8((const __global BLOCK_DATA_T *)(&( \
_diff_src)[(mb_chunk)*IC_BLOCK]), \
AS_BLOCK_DATA8_T((_block))); \
}

#if DT_F32
#define TRANSPOSE_8(_block, _col) \
(DATA8_T)(intel_sub_group_shuffle(_block, _col))
(intel_sub_group_shuffle(_block, _col))
#else
#define TRANSPOSE_8(_block, _col) \
(DATA8_T)(intel_sub_group_shuffle(_block[0], _col), \
intel_sub_group_shuffle(_block[1], _col), \
intel_sub_group_shuffle(_block[2], _col), \
intel_sub_group_shuffle(_block[3], _col), \
intel_sub_group_shuffle(_block[4], _col), \
intel_sub_group_shuffle(_block[5], _col), \
intel_sub_group_shuffle(_block[6], _col), \
intel_sub_group_shuffle(_block[7], _col))
(intel_sub_group_shuffle(_block[0], _col), \
intel_sub_group_shuffle(_block[1], _col), \
intel_sub_group_shuffle(_block[2], _col), \
intel_sub_group_shuffle(_block[3], _col), \
intel_sub_group_shuffle(_block[4], _col), \
intel_sub_group_shuffle(_block[5], _col), \
intel_sub_group_shuffle(_block[6], _col), \
intel_sub_group_shuffle(_block[7], _col))
#endif

#define FMA8(a, b, c) fma((DATA8_T)(a), (DATA8_T)b, (DATA8_T)c)
#define FMA8(a, b, c) fma((FMA_ARG_TYPE8)(a), (FMA_ARG_TYPE8)b, (FMA_ARG_TYPE8)c)

#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, _blockB1) \
{ \
Expand All @@ -207,14 +204,10 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
_result = FMA8(_blockB1.s7, TRANSPOSE_8(_blockA, 15), _result); \
}

DATA8_T blockA0, blockA1;
LOAD_DIFF_DST(blockA0, diff_dst1, 0);
LOAD_DIFF_DST(blockA1, diff_dst1, 8);
DATA8_T blockB00 = AS_DATA8_T(
BLOCK_READ8((const __global BLOCK_DATA_T *)wei1));
DATA8_T blockB01 = AS_DATA8_T(
BLOCK_READ8((const __global BLOCK_DATA_T *)(wei1
+ 8 * IC_BLOCK)));
INPUT_TYPE8 blockA0 = DT_INPUT_BLOCK_READ(diff_dst1, 0);
INPUT_TYPE8 blockA1 = DT_INPUT_BLOCK_READ(diff_dst1, 8 * OC_BLOCK);
FILTER_TYPE8 blockB00 = DT_FILTER_BLOCK_READ8(wei1, 0);
FILTER_TYPE8 blockB01 = DT_FILTER_BLOCK_READ8(wei1, 8 * IC_BLOCK);
MULTIPLY_BLOCKS_8x8(blockC00, blockA0, blockB00, blockB01);
MULTIPLY_BLOCKS_8x8(blockC01, blockA1, blockB00, blockB01);

Expand All @@ -232,28 +225,32 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
ocb += OC_BLOCK;
} while (ocb < OC);

__global DATA_T *src_write0 = diff_src + OUTPUT_OFFSET + mb * IC_FULL * G * ID_FULL * IH_FULL * IW_FULL
__global OUTPUT_TYPE *src_write0 = diff_src + OUTPUT_OFFSET + mb * IC_FULL * G * ID_FULL * IH_FULL * IW_FULL
+ gic * ID_FULL * IH_FULL * IW_FULL * IC_BLOCK * MB_BLOCK
+ g * IC * ID_FULL * IH_FULL * IW_FULL * MB_BLOCK
+ id * IH_FULL * IW_FULL * IC_BLOCK * MB_BLOCK + ih * IW_FULL * IC_BLOCK * MB_BLOCK
+ iw * IC_BLOCK * MB_BLOCK;

blockC00 = ACTIVATION(blockC00, ACTIVATION_PARAMS);
blockC01 = ACTIVATION(blockC01, ACTIVATION_PARAMS);
OUTPUT_TYPE8 res0, res1;

#if HAS_FUSED_OPS
{
FUSED_OPS_BLOCK_C00;
blockC00 = FUSED_OPS_RESULT_BLOCK_C00;
res0 = FUSED_OPS_RESULT_BLOCK_C00;
}
{
FUSED_OPS_BLOCK_C01;
blockC01 = FUSED_OPS_RESULT_BLOCK_C01;
res1 = FUSED_OPS_RESULT_BLOCK_C01;
}
#else
res0 = blockC00;
res1 = blockC01;
#endif

SAVE_SRC_DIFF(blockC00, src_write0, 0);
SAVE_SRC_DIFF(blockC01, src_write0, 8);
DT_OUTPUT_BLOCK_WRITE8(src_write0, 0, res0);
DT_OUTPUT_BLOCK_WRITE8(src_write0, 8 * IC_BLOCK, res1);

#endif
#if VER_8OW16C == 1
Expand All @@ -278,7 +275,7 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
const int iw = (ihw % IWB) * IW_BLOCK;

diff_dst += input_offset + mb * OC_FULL * G * OD_FULL * OH_FULL * OW_FULL + g * OC * OD_FULL * OH_FULL * OW_FULL * MB_BLOCK;
DATA_T blockC00[IW_BLOCK] = {0.0f};
INPUT0_TYPE blockC00[IW_BLOCK] = {INPUT0_VAL_ZERO};

#if WITH_BIAS
for (int i = 0; i < IW_BLOCK; i++)
Expand Down Expand Up @@ -307,12 +304,12 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
oh /= SH;
if (oh >= OH) continue;

const __global DATA_T *diff_dst1
const __global INPUT0_TYPE *diff_dst1
= diff_dst + oh * OW_FULL * OC_BLOCK * MB_BLOCK;
#if CASE_3D
diff_dst1 += od * OH_FULL * OW_FULL * OC_BLOCK * MB_BLOCK;
#endif
const __global DATA_T *wei1 = wei
const __global FILTER_TYPE *wei1 = wei
#if CASE_3D
+ kd * KH * KW * OC_BLOCK * IC_BLOCK
#endif
Expand Down Expand Up @@ -341,21 +338,21 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
#if SW != 1 || SH != 1 || SD != 1 || PH != 0 || PW != 0 || PD != 0
if (do_ker) {
#endif
const __global DATA_T *diff_dst1
const __global INPUT0_TYPE *diff_dst1
= diff_dst + oh * OW_FULL * OC_BLOCK * MB_BLOCK;
#if CASE_3D
diff_dst1 += od * OH_FULL * OW_FULL * OC_BLOCK * MB_BLOCK;
#endif
const __global DATA_T *wei1 = wei;
const __global FILTER_TYPE *wei1 = wei;
#endif

int ocb = 0;
do {

#define TRANSPOSE_1(_block, _col) \
(DATA_T)(intel_sub_group_shuffle(_block, _col))
(intel_sub_group_shuffle(_block, _col))

#define FMA1(a, b, c) fma((DATA_T)(a), (DATA_T)b, (DATA_T)c)
#define FMA1(a, b, c) fma((FMA_ARG_TYPE)(a), (FMA_ARG_TYPE)b, (FMA_ARG_TYPE)c)

#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, _blockB1) \
{ \
Expand All @@ -377,12 +374,9 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
_result = FMA1(_blockB1.s7, TRANSPOSE_1(_blockA, 15), _result); \
}

DATA8_T blockB00 = AS_DATA8_T(
BLOCK_READ8((const __global BLOCK_DATA_T *)wei1));
DATA8_T blockB01 = AS_DATA8_T(
BLOCK_READ8((const __global BLOCK_DATA_T *)(wei1
+ 8 * IC_BLOCK)));
DATA_T blockA[IW_BLOCK];
FILTER_TYPE8 blockB00 = DT_FILTER_BLOCK_READ8(wei1, 0);
FILTER_TYPE8 blockB01 = DT_FILTER_BLOCK_READ8(wei1, 8 * IC_BLOCK);
INPUT0_TYPE blockA[IW_BLOCK];

__attribute__((
opencl_unroll_hint(IW_BLOCK))) // attr:no-format
Expand All @@ -407,9 +401,7 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
blockA[i] = 0.0;
continue;
}
blockA[i] = AS_DATA_T(
BLOCK_READ((const __global BLOCK_DATA_T *)(&(
diff_dst1)[ow * OC_BLOCK])));
blockA[i] = DT_INPUT_BLOCK_READ(diff_dst1, ow * OC_BLOCK);
}

__attribute__((
Expand All @@ -434,7 +426,7 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
#endif
#endif

__global DATA_T *src_write0 = diff_src + output_offset + mb * IC_FULL * G * ID_FULL * IH_FULL * IW_FULL
__global OUTPUT_TYPE *src_write0 = diff_src + output_offset + mb * IC_FULL * G * ID_FULL * IH_FULL * IW_FULL
+ gic * ID_FULL * IH_FULL * IW_FULL * IC_BLOCK * MB_BLOCK
+ g * IC * ID_FULL * IH_FULL * IW_FULL * MB_BLOCK
+ id * IH_FULL * IW_FULL * IC_BLOCK * MB_BLOCK + ih * IW_FULL * IC_BLOCK * MB_BLOCK
Expand All @@ -443,12 +435,14 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
for (int i = 0; i < IW_BLOCK; i++) {
blockC00[i] = ACTIVATION(blockC00[i], ACTIVATION_PARAMS);
if (iw + i >= IW) continue;
OUTPUT_TYPE res;
#if HAS_FUSED_OPS
FUSED_OPS_BLOCK_CI;
blockC00[i] = FUSED_OPS_RESULT_BLOCK_CI;
res = FUSED_OPS_RESULT_BLOCK_CI;
#else
res = blockC00[i];
#endif
BLOCK_WRITE((__global BLOCK_DATA_T *)(&(src_write0)[i * IC_BLOCK]),
AS_BLOCK_DATA_T(blockC00[i]));
DT_OUTPUT_BLOCK_WRITE(src_write0, i * IC_BLOCK, res);
}
#endif
}
Expand Down
Loading

0 comments on commit 31ac9ac

Please sign in to comment.