Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IE CLDNN] Allow fusing FQ to deconvolution #2875

Merged
merged 6 commits into from
Nov 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ ParamsKey DeconvolutionKernel_b_fs_zyx_fsv16::GetSupportedKey() const {
k.EnableInputWeightsType(WeightsType::F32);
k.EnableInputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputWeightsType(WeightsType::F16);
k.EnableInputLayout(DataLayout::b_fs_yx_fsv16);
k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16);
Expand All @@ -44,6 +46,7 @@ ParamsKey DeconvolutionKernel_b_fs_zyx_fsv16::GetSupportedKey() const {
k.EnableBatching();
k.EnableSubGroup();
k.EnableSubGroupShort();
k.EnableDifferentTypes();
return k;
}

Expand Down Expand Up @@ -155,10 +158,11 @@ JitConstants DeconvolutionKernel_b_fs_zyx_fsv16::GetJitConstants(const deconvolu
}
jit.AddConstant(MakeJitConstant("OC_BLOCK", 16));

if (output.GetDType() == Datatype::F32)
if (input.GetDType() == Datatype::F32) {
jit.AddConstant(MakeJitConstant("DT_F32", 1));
else
} else {
jit.AddConstant(MakeJitConstant("DT_F16", 1));
}

auto mb_block = 1;
auto ic_block = 16;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ ParamsKey DeconvolutionKernel_bfyx_opt::GetSupportedKey() const {
k.EnableInputWeightsType(WeightsType::F32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableTensorOffset();
Expand All @@ -36,6 +38,7 @@ ParamsKey DeconvolutionKernel_bfyx_opt::GetSupportedKey() const {
k.EnableSplitSupport();
k.EnableDepthwiseSeparableOpt();
k.EnableGroupedConvolution();
k.EnableDifferentTypes();
return k;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019 Intel Corporation
* Copyright 2019-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,10 +14,21 @@
* limitations under the License.
*******************************************************************************/

#include "ocl_types.h"
#include "include/fetch.cl"
#include "include/data_types.cl"

#define INPUT_TYPE8 MAKE_VECTOR_TYPE(INPUT0_TYPE, 8)
#define OUTPUT_TYPE8 MAKE_VECTOR_TYPE(OUTPUT_TYPE, 8)
#define FILTER_TYPE8 MAKE_VECTOR_TYPE(FILTER_TYPE, 8)

#if DT_F16 == 1
#define FMA_ARG_TYPE half
#define FMA_ARG_TYPE8 half8
#else
#define FMA_ARG_TYPE INPUT0_TYPE
#define FMA_ARG_TYPE8 INPUT_TYPE8
#endif

#if ID > 1
#define CASE_3D 1
#else
Expand All @@ -31,11 +42,11 @@ __attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) // attr:no-format
__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) // attr:no-format
#endif
KERNEL(gen9_common_conv_bwd_data_kernel)(
const __global DATA_T *diff_dst,
__global DATA_T * restrict diff_src,
const __global DATA_T *wei,
const __global INPUT0_TYPE *diff_dst,
__global OUTPUT_TYPE * restrict diff_src,
const __global FILTER_TYPE *wei,
#if WITH_BIAS
const __global DATA_T *bias,
const __global BIAS_TYPE *bias,
#endif
#if HAS_FUSED_OPS_DECLS
FUSED_OPS_DECLS,
Expand Down Expand Up @@ -76,11 +87,11 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
diff_dst += input_offset + mb * OC_FULL * G * OD_FULL * OH_FULL * OW_FULL + g * OC * OD_FULL * OH_FULL * OW_FULL * MB_BLOCK;

#if WITH_BIAS
DATA8_T blockC00 = (DATA8_T)bias[g * IC + gic * IC_BLOCK + local_id];
DATA8_T blockC01 = (DATA8_T)bias[g * IC + gic * IC_BLOCK + local_id];
INPUT_TYPE8 blockC00 = (INPUT_TYPE8)bias[g * IC + gic * IC_BLOCK + local_id];
INPUT_TYPE8 blockC01 = (INPUT_TYPE8)bias[g * IC + gic * IC_BLOCK + local_id];
#else
DATA8_T blockC00 = 0.0f;
DATA8_T blockC01 = 0.0f;
INPUT_TYPE8 blockC00 = INPUT0_VAL_ZERO;
INPUT_TYPE8 blockC01 = INPUT0_VAL_ZERO;
#endif

wei += gic * KD * KH * KW * OC_BLOCK * IC_BLOCK
Expand Down Expand Up @@ -111,13 +122,13 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
#endif
if (oh >= OH || ow >= OW) continue;

const __global DATA_T *diff_dst1 = diff_dst
const __global INPUT0_TYPE *diff_dst1 = diff_dst
+ ow * OC_BLOCK * MB_BLOCK
+ oh * OW_FULL * OC_BLOCK * MB_BLOCK;
#if CASE_3D
diff_dst1 += od * OH_FULL * OW_FULL * OC_BLOCK * MB_BLOCK;
#endif
const __global DATA_T *wei1 = wei
const __global FILTER_TYPE *wei1 = wei
#if CASE_3D
+ kd * KH * KW * OC_BLOCK * IC_BLOCK
#endif
Expand Down Expand Up @@ -148,44 +159,30 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
#if SW != 1 || SH != 1 || SD != 1 || PH != 0 || PW != 0 || PD != 0
if (do_ker) {
#endif
const __global DATA_T *diff_dst1 = diff_dst
const __global INPUT0_TYPE *diff_dst1 = diff_dst
+ ow * OC_BLOCK * MB_BLOCK + oh * OW_FULL * OC_BLOCK * MB_BLOCK;
#if CASE_3D
diff_dst1 += od * OH_FULL * OW_FULL * OC_BLOCK * MB_BLOCK;
#endif
const __global DATA_T *wei1 = wei;
const __global FILTER_TYPE *wei1 = wei;
#endif

#define LOAD_DIFF_DST(_block, _diff_dst, mb_chunk) \
{ \
(_block) = AS_DATA8_T( \
BLOCK_READ8((const __global BLOCK_DATA_T *)((_diff_dst) \
+ (mb_chunk)*OC_BLOCK))); \
}

#define SAVE_SRC_DIFF(_block, _diff_src, mb_chunk) \
{ \
BLOCK_WRITE8((const __global BLOCK_DATA_T *)(&( \
_diff_src)[(mb_chunk)*IC_BLOCK]), \
AS_BLOCK_DATA8_T((_block))); \
}

#if DT_F32
#define TRANSPOSE_8(_block, _col) \
(DATA8_T)(intel_sub_group_shuffle(_block, _col))
(intel_sub_group_shuffle(_block, _col))
#else
#define TRANSPOSE_8(_block, _col) \
(DATA8_T)(intel_sub_group_shuffle(_block[0], _col), \
intel_sub_group_shuffle(_block[1], _col), \
intel_sub_group_shuffle(_block[2], _col), \
intel_sub_group_shuffle(_block[3], _col), \
intel_sub_group_shuffle(_block[4], _col), \
intel_sub_group_shuffle(_block[5], _col), \
intel_sub_group_shuffle(_block[6], _col), \
intel_sub_group_shuffle(_block[7], _col))
(intel_sub_group_shuffle(_block[0], _col), \
intel_sub_group_shuffle(_block[1], _col), \
intel_sub_group_shuffle(_block[2], _col), \
intel_sub_group_shuffle(_block[3], _col), \
intel_sub_group_shuffle(_block[4], _col), \
intel_sub_group_shuffle(_block[5], _col), \
intel_sub_group_shuffle(_block[6], _col), \
intel_sub_group_shuffle(_block[7], _col))
#endif

#define FMA8(a, b, c) fma((DATA8_T)(a), (DATA8_T)b, (DATA8_T)c)
#define FMA8(a, b, c) fma((FMA_ARG_TYPE8)(a), (FMA_ARG_TYPE8)b, (FMA_ARG_TYPE8)c)

#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, _blockB1) \
{ \
Expand All @@ -207,14 +204,10 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
_result = FMA8(_blockB1.s7, TRANSPOSE_8(_blockA, 15), _result); \
}

DATA8_T blockA0, blockA1;
LOAD_DIFF_DST(blockA0, diff_dst1, 0);
LOAD_DIFF_DST(blockA1, diff_dst1, 8);
DATA8_T blockB00 = AS_DATA8_T(
BLOCK_READ8((const __global BLOCK_DATA_T *)wei1));
DATA8_T blockB01 = AS_DATA8_T(
BLOCK_READ8((const __global BLOCK_DATA_T *)(wei1
+ 8 * IC_BLOCK)));
INPUT_TYPE8 blockA0 = DT_INPUT_BLOCK_READ(diff_dst1, 0);
INPUT_TYPE8 blockA1 = DT_INPUT_BLOCK_READ(diff_dst1, 8 * OC_BLOCK);
FILTER_TYPE8 blockB00 = DT_FILTER_BLOCK_READ8(wei1, 0);
FILTER_TYPE8 blockB01 = DT_FILTER_BLOCK_READ8(wei1, 8 * IC_BLOCK);
MULTIPLY_BLOCKS_8x8(blockC00, blockA0, blockB00, blockB01);
MULTIPLY_BLOCKS_8x8(blockC01, blockA1, blockB00, blockB01);

Expand All @@ -232,28 +225,32 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
ocb += OC_BLOCK;
} while (ocb < OC);

__global DATA_T *src_write0 = diff_src + OUTPUT_OFFSET + mb * IC_FULL * G * ID_FULL * IH_FULL * IW_FULL
__global OUTPUT_TYPE *src_write0 = diff_src + OUTPUT_OFFSET + mb * IC_FULL * G * ID_FULL * IH_FULL * IW_FULL
+ gic * ID_FULL * IH_FULL * IW_FULL * IC_BLOCK * MB_BLOCK
+ g * IC * ID_FULL * IH_FULL * IW_FULL * MB_BLOCK
+ id * IH_FULL * IW_FULL * IC_BLOCK * MB_BLOCK + ih * IW_FULL * IC_BLOCK * MB_BLOCK
+ iw * IC_BLOCK * MB_BLOCK;

blockC00 = ACTIVATION(blockC00, ACTIVATION_PARAMS);
blockC01 = ACTIVATION(blockC01, ACTIVATION_PARAMS);
OUTPUT_TYPE8 res0, res1;

#if HAS_FUSED_OPS
{
FUSED_OPS_BLOCK_C00;
blockC00 = FUSED_OPS_RESULT_BLOCK_C00;
res0 = FUSED_OPS_RESULT_BLOCK_C00;
}
{
FUSED_OPS_BLOCK_C01;
blockC01 = FUSED_OPS_RESULT_BLOCK_C01;
res1 = FUSED_OPS_RESULT_BLOCK_C01;
}
#else
res0 = blockC00;
res1 = blockC01;
#endif

SAVE_SRC_DIFF(blockC00, src_write0, 0);
SAVE_SRC_DIFF(blockC01, src_write0, 8);
DT_OUTPUT_BLOCK_WRITE8(src_write0, 0, res0);
DT_OUTPUT_BLOCK_WRITE8(src_write0, 8 * IC_BLOCK, res1);

#endif
#if VER_8OW16C == 1
Expand All @@ -278,7 +275,7 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
const int iw = (ihw % IWB) * IW_BLOCK;

diff_dst += input_offset + mb * OC_FULL * G * OD_FULL * OH_FULL * OW_FULL + g * OC * OD_FULL * OH_FULL * OW_FULL * MB_BLOCK;
DATA_T blockC00[IW_BLOCK] = {0.0f};
INPUT0_TYPE blockC00[IW_BLOCK] = {INPUT0_VAL_ZERO};

#if WITH_BIAS
for (int i = 0; i < IW_BLOCK; i++)
Expand Down Expand Up @@ -307,12 +304,12 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
oh /= SH;
if (oh >= OH) continue;

const __global DATA_T *diff_dst1
const __global INPUT0_TYPE *diff_dst1
= diff_dst + oh * OW_FULL * OC_BLOCK * MB_BLOCK;
#if CASE_3D
diff_dst1 += od * OH_FULL * OW_FULL * OC_BLOCK * MB_BLOCK;
#endif
const __global DATA_T *wei1 = wei
const __global FILTER_TYPE *wei1 = wei
#if CASE_3D
+ kd * KH * KW * OC_BLOCK * IC_BLOCK
#endif
Expand Down Expand Up @@ -341,21 +338,21 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
#if SW != 1 || SH != 1 || SD != 1 || PH != 0 || PW != 0 || PD != 0
if (do_ker) {
#endif
const __global DATA_T *diff_dst1
const __global INPUT0_TYPE *diff_dst1
= diff_dst + oh * OW_FULL * OC_BLOCK * MB_BLOCK;
#if CASE_3D
diff_dst1 += od * OH_FULL * OW_FULL * OC_BLOCK * MB_BLOCK;
#endif
const __global DATA_T *wei1 = wei;
const __global FILTER_TYPE *wei1 = wei;
#endif

int ocb = 0;
do {

#define TRANSPOSE_1(_block, _col) \
(DATA_T)(intel_sub_group_shuffle(_block, _col))
(intel_sub_group_shuffle(_block, _col))

#define FMA1(a, b, c) fma((DATA_T)(a), (DATA_T)b, (DATA_T)c)
#define FMA1(a, b, c) fma((FMA_ARG_TYPE)(a), (FMA_ARG_TYPE)b, (FMA_ARG_TYPE)c)

#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, _blockB1) \
{ \
Expand All @@ -377,12 +374,9 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
_result = FMA1(_blockB1.s7, TRANSPOSE_1(_blockA, 15), _result); \
}

DATA8_T blockB00 = AS_DATA8_T(
BLOCK_READ8((const __global BLOCK_DATA_T *)wei1));
DATA8_T blockB01 = AS_DATA8_T(
BLOCK_READ8((const __global BLOCK_DATA_T *)(wei1
+ 8 * IC_BLOCK)));
DATA_T blockA[IW_BLOCK];
FILTER_TYPE8 blockB00 = DT_FILTER_BLOCK_READ8(wei1, 0);
FILTER_TYPE8 blockB01 = DT_FILTER_BLOCK_READ8(wei1, 8 * IC_BLOCK);
INPUT0_TYPE blockA[IW_BLOCK];

__attribute__((
opencl_unroll_hint(IW_BLOCK))) // attr:no-format
Expand All @@ -407,9 +401,7 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
blockA[i] = 0.0;
continue;
}
blockA[i] = AS_DATA_T(
BLOCK_READ((const __global BLOCK_DATA_T *)(&(
diff_dst1)[ow * OC_BLOCK])));
blockA[i] = DT_INPUT_BLOCK_READ(diff_dst1, ow * OC_BLOCK);
}

__attribute__((
Expand All @@ -434,7 +426,7 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
#endif
#endif

__global DATA_T *src_write0 = diff_src + output_offset + mb * IC_FULL * G * ID_FULL * IH_FULL * IW_FULL
__global OUTPUT_TYPE *src_write0 = diff_src + output_offset + mb * IC_FULL * G * ID_FULL * IH_FULL * IW_FULL
+ gic * ID_FULL * IH_FULL * IW_FULL * IC_BLOCK * MB_BLOCK
+ g * IC * ID_FULL * IH_FULL * IW_FULL * MB_BLOCK
+ id * IH_FULL * IW_FULL * IC_BLOCK * MB_BLOCK + ih * IW_FULL * IC_BLOCK * MB_BLOCK
Expand All @@ -443,12 +435,14 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
for (int i = 0; i < IW_BLOCK; i++) {
blockC00[i] = ACTIVATION(blockC00[i], ACTIVATION_PARAMS);
if (iw + i >= IW) continue;
OUTPUT_TYPE res;
#if HAS_FUSED_OPS
FUSED_OPS_BLOCK_CI;
blockC00[i] = FUSED_OPS_RESULT_BLOCK_CI;
res = FUSED_OPS_RESULT_BLOCK_CI;
#else
res = blockC00[i];
#endif
BLOCK_WRITE((__global BLOCK_DATA_T *)(&(src_write0)[i * IC_BLOCK]),
AS_BLOCK_DATA_T(blockC00[i]));
DT_OUTPUT_BLOCK_WRITE(src_write0, i * IC_BLOCK, res);
}
#endif
}
Expand Down
Loading