Skip to content

Commit

Permalink
[GPU] Added fp16 support for GatherTree (openvinotoolkit#15983)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Lyamin-Roman authored and andrei-cv committed Mar 21, 2023
1 parent 32ee20a commit 04ae165
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
7 changes: 6 additions & 1 deletion src/plugins/intel_gpu/src/graph/impls/ocl/gather_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ struct gather_tree_impl : typed_primitive_impl_ocl<gather_tree> {

namespace detail {
attach_gather_tree_impl::attach_gather_tree_impl() {
auto types = {data_types::i32, data_types::f32};
auto types = {
data_types::f32,
data_types::f16,
data_types::i32
};

auto formats = {
format::yxfb,
format::bfyx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ KERNEL(gather_tree_gpu_ref)(
}

for (int parent = beam; time >= 0; time--) {
output[OUTPUT_GET_INDEX(time, batch, beam, 0)] = step_input[INPUT0_GET_INDEX(time, batch, parent, 0)];
parent = parent_input[INPUT1_GET_INDEX(time, batch, parent, 0)];
output[OUTPUT_GET_INDEX(time, batch, beam, 0)] = TO_OUTPUT_TYPE(step_input[INPUT0_GET_INDEX(time, batch, parent, 0)]);
parent = (int)parent_input[INPUT1_GET_INDEX(time, batch, parent, 0)];
}
bool finished = false;
for (int time = 0; time < max_sequence_in_beam; time++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ ParamsKey GatherTreeKernelRef::GetSupportedKey() const {

k.EnableInputDataType(Datatype::INT32);
k.EnableOutputDataType(Datatype::INT32);

k.EnableInputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F32);

k.EnableInputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::F16);

k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace {

// File-local list of InferenceEngine precisions: FP32, FP16 and I32.
// NOTE(review): presumably the precisions the GatherTree GPU test cases are
// parameterized over — confirm against the test instantiation that uses it.
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32
};

Expand Down

0 comments on commit 04ae165

Please sign in to comment.