Skip to content

Commit

Permalink
[IE CLDNN] TopK registry spill avoiding for sort-by-value mode (openv…
Browse files Browse the repository at this point in the history
  • Loading branch information
lznamens authored and mryzhov committed Dec 15, 2020
1 parent 5d955e4 commit 6785542
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2018 Intel Corporation
// Copyright (c) 2018-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,6 +29,18 @@ size_t getOperationNumber(const arg_max_min_params& params) {
}
}

// Returns the extent of the first input tensor along the axis selected by
// params.argMaxMinAxis (BATCH, FEATURE, Z, Y or X).
//
// Throws std::invalid_argument("Unsupported axis") for any other axis value.
size_t getSortSize(const arg_max_min_params& params) {
    const auto axis = params.argMaxMinAxis;

    // The input is only touched once a supported axis has been matched.
    if (axis == ArgMaxMinAxis::BATCH)
        return params.inputs[0].Batch().v;
    if (axis == ArgMaxMinAxis::FEATURE)
        return params.inputs[0].Feature().v;
    if (axis == ArgMaxMinAxis::Z)
        return params.inputs[0].Z().v;
    if (axis == ArgMaxMinAxis::Y)
        return params.inputs[0].Y().v;
    if (axis == ArgMaxMinAxis::X)
        return params.inputs[0].X().v;

    throw std::invalid_argument("Unsupported axis");
}

ParamsKey ArgMaxMinKernelAxis::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::F16);
Expand Down Expand Up @@ -72,19 +84,24 @@ KernelsData ArgMaxMinKernelAxis::GetKernelsData(const Params& params, const opti
if (!Validate(params, options)) {
return {};
}

const arg_max_min_params& orgParams = static_cast<const arg_max_min_params&>(params);

DispatchData runInfo;
runInfo.fp16UnitUsed = orgParams.inputs[0].GetDType() == Datatype::F16;

runInfo.gws0 = Align(getOperationNumber(orgParams), 32);
runInfo.gws1 = 1;
runInfo.gws2 = 1;
size_t sort_size = orgParams.argMaxMinSortType == ArgMaxMinSortType::VALUE ? getSortSize(orgParams) : 1;

std::vector<size_t> local, global;
global = { Align(getOperationNumber(orgParams), 32), sort_size, 1 };
local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);

runInfo.gws0 = global[0];
runInfo.gws1 = global[1];
runInfo.gws2 = global[2];

runInfo.lws0 = 32;
runInfo.lws1 = 1;
runInfo.lws2 = 1;
runInfo.lws0 = local[0];
runInfo.lws1 = local[1];
runInfo.lws2 = local[2];

KernelData kd = KernelData::Default<arg_max_min_params>(params);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,15 @@
#define COMPARE_SIGN <
#define COMPARE_PARTIAL_SIGN >=
#define COMPARE_MERGE_SIGN >
#define COMPARE_PARALLEL_SIGN_1 <=
#define COMPARE_PARALLEL_SIGN_2 <
#define INPUT0_FILL_VAL INPUT0_VAL_MIN
#else
#define COMPARE_SIGN >
#define COMPARE_PARTIAL_SIGN <=
#define COMPARE_MERGE_SIGN <
#define COMPARE_PARALLEL_SIGN_1 >=
#define COMPARE_PARALLEL_SIGN_2 >
#define INPUT0_FILL_VAL INPUT0_VAL_MAX
#endif

Expand Down Expand Up @@ -83,17 +87,19 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
)
{
#include "include/arg_max_min_common.cl"
#if (TOP_K == 1)
#if SORT_BY_VALUE
const uint sort_idx = (uint)get_global_id(1);
#elif TOP_K == 1
iav_type result[TOP_K];
#else
iav_type result[VALUES_NUM], temp_buf[VALUES_NUM];
const uint group_size = TOP_K >= 8 ? TOP_K : 8;
const uint group_num = ((VALUES_NUM - 1) / group_size) + 1;
const uint last_group_size = (VALUES_NUM % group_size > 0) ? (VALUES_NUM % group_size) : group_size;
const uint last_group_offset = (group_num - 1) * group_size;
#endif // (TOP_K == 1)
#endif // SORT_BY_VALUE

uint output_idx = (uint)get_global_id(0);
const uint output_idx = (uint)get_global_id(0);

if (output_idx >= OPERATION_NUM)
return;
Expand Down Expand Up @@ -162,9 +168,35 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
#endif
#endif

// Using simple sorting for (TOP_K == 1)
#if (TOP_K == 1)
// Using parallel sorting for sorting by values
#if SORT_BY_VALUE
uint sort_position = 0;
indices[AXIS] = sort_idx;

iav_type result;
result.value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
result.index = sort_idx;

for (uint i = 0; i < sort_idx; i++) {
indices[AXIS] = i;
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_1 test_value)
sort_position++;
if (sort_position >= TOP_K)
return;
}

for (uint i = sort_idx + 1; i < VALUES_NUM; i++) {
indices[AXIS] = i;
INPUT0_TYPE test_value = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
if (result.value COMPARE_PARALLEL_SIGN_2 test_value)
sort_position++;
if (sort_position >= TOP_K)
return;
}

// Sorting by indices with TOP_K == 1: use simple sorting
#elif TOP_K == 1
INPUT0_TYPE val = input[FUNC_CALL(get_input_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])];
result[0].index = 0;
result[0].value = val;
Expand Down Expand Up @@ -194,9 +226,8 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
val = INPUT0_FILL_VAL;
}

// Using merge sorting when (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING)
// Sorting by indices with (TOP_K >= (VALUES_NUM / 2)) or (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING): use merge sorting
#elif ((TOP_K >= (VALUES_NUM / 2)) || (VALUES_NUM < MINIMUM_NUMBER_FOR_PARTIAL_SORTING))

for (uint i = 0; i < VALUES_NUM / 8; i++) {
uint index_offset = i * 8;
indices[AXIS] = result[index_offset].index = index_offset;
Expand Down Expand Up @@ -245,9 +276,8 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input
}
}

// In other cases using mixed partial/merge sorting
#else // (TOP_K == 1)

// All remaining sort-by-indices cases: use mixed partial/merge sorting
#else // SORT_BY_VALUE
for (uint i = 0; i < VALUES_NUM / 8; i++) {
uint index_offset = i * 8;
indices[AXIS] = temp_buf[index_offset].index = index_offset;
Expand Down Expand Up @@ -365,41 +395,56 @@ KERNEL(arg_max_min_modified)(const __global INPUT0_TYPE* input

result[i] = merge_buf;
}
#endif // SORT_BY_VALUE

#endif // (TOP_K == 1)
#if SORT_BY_VALUE
indices[AXIS] = sort_position;
#ifdef TOP_K_ORDER
output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.value);
#else
output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result.index);
#endif
#ifdef SECOND_OUTPUT_EXIST
#ifdef TOP_K_ORDER
second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.index);
#else
second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result.value);
#endif
#endif

#else // SORT_BY_VALUE
for (uint top_k = 0; top_k < TOP_K; ++top_k) {
#ifdef SORT_BY_VALUE
indices[AXIS] = top_k;
#endif
#ifdef SORT_BY_INDEX
uint out_position = 0;

for (uint i = 0; i < TOP_K; ++i) {
if (i == top_k)
continue;
if (result[i].index < result[top_k].index)
out_position++;
}

indices[AXIS] = out_position;
#endif
#ifdef TOP_K_ORDER
output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value);
output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].value);
#else
output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index);
output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_OUTPUT_TYPE(result[top_k].index);
#endif
#ifdef SECOND_OUTPUT_EXIST
#ifdef TOP_K_ORDER
second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index);
#else
second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value);
#endif
#ifdef TOP_K_ORDER
second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].index);
#else
second_output[FUNC_CALL(get_output_offset)(indices[0], indices[1], indices[2], indices[3], indices[4])] = TO_INPUT1_TYPE(result[top_k].value);
#endif
#endif
}
#endif
}

#undef COMPARE_SIGN
#undef COMPARE_PARTIAL_SIGN
#undef COMPARE_MERGE_SIGN
#undef COMPARE_PARALLEL_SIGN_1
#undef COMPARE_PARALLEL_SIGN_2
#undef INPUT0_FILL_VAL
#undef AXIS
#undef VALUES_NUM
Expand Down

0 comments on commit 6785542

Please sign in to comment.