Skip to content

Commit

Permalink
Merge pull request openvinotoolkit#8 from EunsooSheenIntel/taylor/per…
Browse files Browse the repository at this point in the history
…mute_opt_new

Add GetBestLwsFromGws in PermuteKernel_tile_8x8
  • Loading branch information
yeonbok authored Feb 17, 2021
2 parents 588fb8d + 283498f commit 868b41c
Showing 1 changed file with 50 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
#include "kernel_selector_utils.h"
#include <string>
#include <functional>
#include <cmath>

#define CEIL_DIV(A, B) ((A + B - 1)/(B))

namespace kernel_selector {
ParamsKey PermuteKernel_tile_8x8::GetSupportedKey() const {
ParamsKey k;
Expand Down Expand Up @@ -159,39 +162,72 @@ JitConstants PermuteKernel_tile_8x8::GetJitConstants(const permute_params& param
return jit;
}

static std::vector<size_t> GetBestLwsFromGws(const std::vector<size_t>& gws, const size_t tile_height, const size_t tile_width)
{
std::vector<size_t> lws{1,1,1};

// SLM size: float32 * tile_height * tile_width * work_items <= 64K
size_t max_num_work_items = std::min(256lu, 65536 / (4 * tile_height * tile_width));

// set lws[0] first
size_t max_divider = static_cast<size_t>(std::sqrt(gws[0])+1);
for (size_t divider = 1; divider <= max_divider; ++divider)
{
if (gws[0] % divider == 0)
{
const size_t lws0 = gws[0] / divider;
if (lws0 <= max_num_work_items)
{
lws[0] = std::max(lws[0], lws0);
}
if (divider <= max_num_work_items)
{
lws[0] = std::max(lws[0], divider);
}
}
}

// set lws[2]
max_num_work_items /= lws[0];
max_divider = static_cast<size_t>(std::sqrt(gws[2])+1);
for (size_t divider = 1; divider <= max_divider; ++divider)
{
if (gws[2] % divider == 0)
{
const size_t lws2 = gws[2] / divider;
if (lws2 <= max_num_work_items)
{
lws[2] = std::max(lws[2], lws2);
}
if (divider <= max_num_work_items)
{
lws[2] = std::max(lws[2], divider);
}
}
}
return lws;
}

CommonDispatchData PermuteKernel_tile_8x8::SetDefault(const permute_params& params) const {
CommonDispatchData dispatchData;
const auto& in = params.inputs[0];
const auto& tile_w = params.tile_w;
const auto& tile_h = params.tile_h;
switch (in.GetLayout()) {
case DataLayout::bfyx:
// for f800, y64, x64
// gws : 64/8, 64, 800
//dispatchData.gws = {in.X().v / tile_w, in.Y().v, (in.Feature().v / tile_h) * in.Batch().v};
//dispatchData.lws = {2, 1, 50}; // TODO

// for f800, y64, x64
// gws : 64/8, 64, 810
dispatchData.gws = {CEIL_DIV(in.X().v , tile_w), in.Y().v, CEIL_DIV(in.Feature().v, tile_h) * in.Batch().v};
dispatchData.lws = {2, 1, 51}; // TODO
break;
case DataLayout::bfzyx:
dispatchData.gws = {CEIL_DIV(in.X().v , tile_w), in.Y().v * in.Z().v, CEIL_DIV(in.Feature().v, tile_h) * in.Batch().v};
dispatchData.lws = {64, 1, 2}; // TODO
// dispatchData.lws = {128, 1, 2}; // TODO
// dispatchData.lws = {64, 1, 2}; // TODO
// dispatchData.lws = {3, 1, 2}; // TODO
break;
case DataLayout::bfwzyx:
dispatchData.gws = {CEIL_DIV(in.X().v , tile_w), in.Y().v * in.Z().v * in.W().v, CEIL_DIV(in.Feature().v, tile_h) * in.Batch().v};
dispatchData.lws = {64, 1, 2}; // TODO
break;
default:
throw std::runtime_error("Unsupported combination\n");
break;
}

dispatchData.lws = GetBestLwsFromGws(dispatchData.gws, tile_h, tile_w);
return dispatchData;
}

Expand Down

0 comments on commit 868b41c

Please sign in to comment.