Skip to content

Commit

Permalink
[GPU] Gather shape agnostic kernel fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Feb 7, 2024
1 parent 67c270f commit 455bc48
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
#include "include/batch_headers/int4_utils.cl"

#ifdef INDEX_DIM
inline uint FUNC(get_positive_index)(int in)
inline uint FUNC(get_positive_index)(OPTIONAL_SHAPE_INFO_ARG int in)
{
if(in < 0)
if (in < 0)
return in + INDEX_DIM;
else
return in;
}
#define INPUT_AXIS_INDEX (uint)FUNC_CALL(get_positive_index)(indices[indices_idx])
#define INPUT_AXIS_INDEX (uint)FUNC_CALL(get_positive_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[indices_idx])
#else
#define INPUT_AXIS_INDEX (uint)(indices[indices_idx])
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,17 @@ JitConstants GatherKernelRef::GetJitConstants(const gather_params& params) const

jit.AddConstant(MakeJitConstant("DICTIONARY_INDEX_ORDER", GetDictionaryIndexOrder(params, GetGatherChannelIndex(params))));
jit.AddConstant(MakeJitConstant("INDICES_INDEX_ORDER", GetIndicesIdxOrder(params, GetGatherChannelIndex(params), GetGatherBatchDim(params))));
if (params.support_neg_ind)
jit.AddConstant(MakeJitConstant("INDEX_DIM", GetGatherMaxIndexDim(params)));

if (!GetGatherIndexDim(params).is_dynamic)
bool dyn_gather_idx_dim = GetGatherIndexDim(params).is_dynamic;
if (params.support_neg_ind) {
if (!dyn_gather_idx_dim) {
jit.AddConstant(MakeJitConstant("INDEX_DIM", GetGatherMaxIndexDim(params)));
} else {
jit.AddConstant(MakeJitConstant("INDEX_DIM", "shape_info[" + std::to_string(GetGatherAxisIndexInShapeInfo(params)) + "]"));
}
}

if (!dyn_gather_idx_dim)
jit.AddConstant(MakeJitConstant("AXIS_DIM", GetGatherMaxIndexDim(params)));

if (params.is_shape_agnostic)
Expand Down

0 comments on commit 455bc48

Please sign in to comment.