From 12e5706719be3df822928482f72d1900d7312535 Mon Sep 17 00:00:00 2001 From: Aleksandr Pertovsky Date: Thu, 29 Apr 2021 18:31:55 +0300 Subject: [PATCH] Validate missing precisions --- inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.cpp | 3 +-- .../cpu/shared_tests_instances/single_layer_tests/roll.cpp | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.cpp index 1716fb43d29ac7..f2970d8145a753 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.cpp @@ -35,7 +35,7 @@ MKLDNNRollNode::MKLDNNRollNode(const InferenceEngine::CNNLayerPtr& layer, const const auto &dataShape = dataTensor.getDims(); const auto &dataPrecision = dataTensor.getPrecision(); - if (!MKLDNNPlugin::one_of(dataPrecision, Precision::I8, Precision::U8, Precision::I16, Precision::I32, Precision::FP32, Precision::I64)) { + if (!MKLDNNPlugin::one_of(dataPrecision, Precision::I8, Precision::U8, Precision::I16, Precision::I32, Precision::FP32, Precision::I64, Precision::BF16)) { IE_THROW() << layerErrorPrefix << " has unsupported 'data' input precision: " << dataPrecision.name(); } if (dataShape.size() < 1) { @@ -95,7 +95,6 @@ void MKLDNNRollNode::initSupportedPrimitiveDescriptors() { InferenceEngine::Precision precision = inputData->getPrecision(); - auto dataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision); auto srcDims = getParentEdgeAt(0)->getDims(); diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/roll.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/roll.cpp index d60aa3ad3708dd..d7a2e5ea6f7f6f 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/roll.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/roll.cpp @@ -17,6 +17,8 @@ const std::vector inputPrecision = { InferenceEngine::Precision::I16, InferenceEngine::Precision::I32, InferenceEngine::Precision::FP32, + InferenceEngine::Precision::I64, + InferenceEngine::Precision::BF16 }; const auto testCase2DZeroShifts = ::testing::Combine(