Skip to content

Commit

Permalink
[IE CLDNN] Add Select int32/int16 input support (openvinotoolkit#5877)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman authored and rnugmanx committed Aug 26, 2021
1 parent 1b0fea4 commit 5035e16
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ using namespace LayerTestsDefinitions;

const std::vector<InferenceEngine::Precision> inputPrecision = {
InferenceEngine::Precision::U8,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::I16,
InferenceEngine::Precision::FP32
InferenceEngine::Precision::I32
};

const std::vector<std::vector<std::vector<size_t>>> noneShapes = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,21 @@ JitConstants SelectKernelBase::GetJitConstantsCommon(const select_params& params
// f32, f32, u8
// f16, f16, i8
// f16, f16, u8
// i32, i32, i8
// i32, i32, u8
// i16, i16, i8
// i16, i16, u8
} else {
absType = "abs";
}

// f32, f32, x
if (params.inputs[1].GetDType() == Datatype::F32) {
// i32, i32, x
if (params.inputs[1].GetDType() == Datatype::F32 || params.inputs[1].GetDType() == Datatype::INT32) {
destType = "int";
// f16, f16, x
} else if (params.inputs[1].GetDType() == Datatype::F16) {
// i16, i16, x
} else if (params.inputs[1].GetDType() == Datatype::F16 || params.inputs[1].GetDType() == Datatype::INT16) {
destType = "short";
// i8, i8, f32
// i8, i8, f16
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@ ParamsKey SelectKernelRef::GetSupportedKey() const {
k.EnableInputDataType(Datatype::F32);
k.EnableInputDataType(Datatype::INT8);
k.EnableInputDataType(Datatype::UINT8);
k.EnableInputDataType(Datatype::INT16);
k.EnableInputDataType(Datatype::INT32);

k.EnableOutputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F16);
k.EnableOutputDataType(Datatype::INT8);
k.EnableOutputDataType(Datatype::UINT8);
k.EnableOutputDataType(Datatype::INT16);
k.EnableOutputDataType(Datatype::INT32);

k.EnableInputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::yxfb);
Expand Down

0 comments on commit 5035e16

Please sign in to comment.