Skip to content

Commit

Permalink
[IE CLDNN] Update kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-chaiko committed Nov 18, 2020
1 parent 0816c6d commit 30398d0
Showing 1 changed file with 35 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,20 @@
* limitations under the License.
*******************************************************************************/

#include "ocl_types.h"
#include "include/fetch.cl"
#include "include/data_types.cl"

#define INPUT_TYPE8 MAKE_VECTOR_TYPE(INPUT0_TYPE, 8)
#define FILTER_TYPE8 MAKE_VECTOR_TYPE(FILTER_TYPE, 8)

#if DT_F16 == 1
#define FMA_ARG_TYPE half
#define FMA_ARG_TYPE8 half8
#else
#define FMA_ARG_TYPE INPUT0_TYPE
#define FMA_ARG_TYPE8 INPUT_TYPE8
#endif

#if ID > 1
#define CASE_3D 1
#else
Expand Down Expand Up @@ -76,11 +86,11 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
diff_dst += input_offset + mb * OC_FULL * G * OD_FULL * OH_FULL * OW_FULL + g * OC * OD_FULL * OH_FULL * OW_FULL * MB_BLOCK;

#if WITH_BIAS
MAKE_VECTOR_TYPE(INPUT0_TYPE, 8) blockC00 = (MAKE_VECTOR_TYPE(INPUT0_TYPE, 8))bias[g * IC + gic * IC_BLOCK + local_id];
MAKE_VECTOR_TYPE(INPUT0_TYPE, 8) blockC01 = (MAKE_VECTOR_TYPE(INPUT0_TYPE, 8))bias[g * IC + gic * IC_BLOCK + local_id];
INPUT_TYPE8 blockC00 = (INPUT_TYPE8)bias[g * IC + gic * IC_BLOCK + local_id];
INPUT_TYPE8 blockC01 = (INPUT_TYPE8)bias[g * IC + gic * IC_BLOCK + local_id];
#else
MAKE_VECTOR_TYPE(INPUT0_TYPE, 8) blockC00 = INPUT0_VAL_ZERO;
MAKE_VECTOR_TYPE(INPUT0_TYPE, 8) blockC01 = INPUT0_VAL_ZERO;
INPUT_TYPE8 blockC00 = INPUT0_VAL_ZERO;
INPUT_TYPE8 blockC01 = INPUT0_VAL_ZERO;
#endif

wei += gic * KD * KH * KW * OC_BLOCK * IC_BLOCK
Expand Down Expand Up @@ -156,36 +166,22 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
const __global FILTER_TYPE *wei1 = wei;
#endif

#define LOAD_DIFF_DST(_block, _diff_dst, mb_chunk) \
{ \
(_block) = AS_DATA8_T( \
BLOCK_READ8((const __global BLOCK_DATA_T *)((_diff_dst) \
+ (mb_chunk)*OC_BLOCK))); \
}

#define SAVE_SRC_DIFF(_block, _diff_src, mb_chunk) \
{ \
BLOCK_WRITE8((const __global BLOCK_DATA_T *)(&( \
_diff_src)[(mb_chunk)*IC_BLOCK]), \
AS_BLOCK_DATA8_T((_block))); \
}

#if DT_F32
#define TRANSPOSE_8(_block, _col) \
(DATA8_T)(intel_sub_group_shuffle(_block, _col))
(intel_sub_group_shuffle(_block, _col))
#else
#define TRANSPOSE_8(_block, _col) \
(DATA8_T)(intel_sub_group_shuffle(_block[0], _col), \
intel_sub_group_shuffle(_block[1], _col), \
intel_sub_group_shuffle(_block[2], _col), \
intel_sub_group_shuffle(_block[3], _col), \
intel_sub_group_shuffle(_block[4], _col), \
intel_sub_group_shuffle(_block[5], _col), \
intel_sub_group_shuffle(_block[6], _col), \
intel_sub_group_shuffle(_block[7], _col))
(intel_sub_group_shuffle(_block[0], _col), \
intel_sub_group_shuffle(_block[1], _col), \
intel_sub_group_shuffle(_block[2], _col), \
intel_sub_group_shuffle(_block[3], _col), \
intel_sub_group_shuffle(_block[4], _col), \
intel_sub_group_shuffle(_block[5], _col), \
intel_sub_group_shuffle(_block[6], _col), \
intel_sub_group_shuffle(_block[7], _col))
#endif

#define FMA8(a, b, c) fma((DATA8_T)(a), (DATA8_T)b, (DATA8_T)c)
#define FMA8(a, b, c) fma((FMA_ARG_TYPE8)(a), (FMA_ARG_TYPE8)b, (FMA_ARG_TYPE8)c)

#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, _blockB1) \
{ \
Expand All @@ -207,14 +203,10 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
_result = FMA8(_blockB1.s7, TRANSPOSE_8(_blockA, 15), _result); \
}

DATA8_T blockA0, blockA1;
LOAD_DIFF_DST(blockA0, diff_dst1, 0);
LOAD_DIFF_DST(blockA1, diff_dst1, 8);
DATA8_T blockB00 = AS_DATA8_T(
BLOCK_READ8((const __global BLOCK_DATA_T *)wei1));
DATA8_T blockB01 = AS_DATA8_T(
BLOCK_READ8((const __global BLOCK_DATA_T *)(wei1
+ 8 * IC_BLOCK)));
INPUT_TYPE8 blockA0 = DT_INPUT_BLOCK_READ(diff_dst1, 0);
INPUT_TYPE8 blockA1 = DT_INPUT_BLOCK_READ(diff_dst1, 8 * OC_BLOCK);
FILTER_TYPE8 blockB00 = DT_FILTER_BLOCK_READ8(wei1, 0);
FILTER_TYPE8 blockB01 = DT_FILTER_BLOCK_READ8(wei1, 8 * IC_BLOCK);
MULTIPLY_BLOCKS_8x8(blockC00, blockA0, blockB00, blockB01);
MULTIPLY_BLOCKS_8x8(blockC01, blockA1, blockB00, blockB01);

Expand Down Expand Up @@ -252,8 +244,8 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
}
#endif

SAVE_SRC_DIFF(blockC00, src_write0, 0);
SAVE_SRC_DIFF(blockC01, src_write0, 8);
DT_OUTPUT_BLOCK_WRITE8(src_write0, 0, blockC00);
DT_OUTPUT_BLOCK_WRITE8(src_write0, 8 * IC_BLOCK, blockC01);

#endif
#if VER_8OW16C == 1
Expand Down Expand Up @@ -353,9 +345,9 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
do {

#define TRANSPOSE_1(_block, _col) \
(DATA_T)(intel_sub_group_shuffle(_block, _col))
(intel_sub_group_shuffle(_block, _col))

#define FMA1(a, b, c) fma((DATA_T)(a), (DATA_T)b, (DATA_T)c)
#define FMA1(a, b, c) fma((FMA_ARG_TYPE)(a), (FMA_ARG_TYPE)b, (FMA_ARG_TYPE)c)

#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, _blockB1) \
{ \
Expand All @@ -377,8 +369,8 @@ KERNEL(gen9_common_conv_bwd_data_kernel)(
_result = FMA1(_blockB1.s7, TRANSPOSE_1(_blockA, 15), _result); \
}

MAKE_VECTOR_TYPE(FILTER_TYPE, 8) blockB00 = DT_FILTER_BLOCK_READ8(wei1, 0);
MAKE_VECTOR_TYPE(FILTER_TYPE, 8) blockB01 = DT_FILTER_BLOCK_READ8(wei1, 8 * IC_BLOCK);
FILTER_TYPE8 blockB00 = DT_FILTER_BLOCK_READ8(wei1, 0);
FILTER_TYPE8 blockB01 = DT_FILTER_BLOCK_READ8(wei1, 8 * IC_BLOCK);
INPUT0_TYPE blockA[IW_BLOCK];

__attribute__((
Expand Down

0 comments on commit 30398d0

Please sign in to comment.