Skip to content

Commit

Permalink
[CPU] Native dynamic shapes support in the Bucketize node (openvinoto…
Browse files Browse the repository at this point in the history
  • Loading branch information
usstq authored Dec 2, 2021
1 parent 0bede45 commit 87ea55f
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 21 deletions.
63 changes: 43 additions & 20 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_bucketize_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ using namespace InferenceEngine;

bool MKLDNNBucketizeNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
if (isDynamicNgraphNode(op)) {
errorMessage = "Doesn't support op with dynamic shapes";
return false;
}
const auto bucketsize = std::dynamic_pointer_cast<const ngraph::opset3::Bucketize>(op);
if (!bucketsize) {
errorMessage = "Only opset3 Bucketize operation is supported";
Expand Down Expand Up @@ -49,22 +45,6 @@ MKLDNNBucketizeNode::MKLDNNBucketizeNode(const std::shared_ptr<ngraph::Node>& op

// check one attribute
with_right = bucketsize->get_with_right_bound();

// check dimensions of input tensors
SizeVector input_tensor_dims = op->get_input_shape(INPUT_TENSOR_PORT);
if (input_tensor_dims.size() < 1) {
IE_THROW() << errorPrefix << " has incorrect dimensions of the input.";
}
SizeVector input_bin_dims = op->get_input_shape(INPUT_BINS_PORT);
if (input_bin_dims.size() != 1) {
IE_THROW() << errorPrefix << " has incorrect dimensions of the boundaries tensor.";
}
if (input_bin_dims[0] != 0) {
with_bins = true;
}
num_bin_values = input_bin_dims[0];

num_values = std::accumulate(input_tensor_dims.begin(), input_tensor_dims.end(), size_t(1), std::multiplies<size_t>());
}

void MKLDNNBucketizeNode::initSupportedPrimitiveDescriptors() {
Expand Down Expand Up @@ -192,6 +172,49 @@ void MKLDNNBucketizeNode::execute(mkldnn::stream strm) {
}
}

void MKLDNNBucketizeNode::prepareParams() {
auto& inputTensorMemPtr = getParentEdgeAt(INPUT_TENSOR_PORT)->getMemoryPtr();
auto& inputBinsMemPtr = getParentEdgeAt(INPUT_BINS_PORT)->getMemoryPtr();
auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
IE_THROW() << "Destination memory didn't allocate.";
if (!inputTensorMemPtr || !inputTensorMemPtr->GetPrimitivePtr())
IE_THROW() << "Input tensor didn't allocate.";
if (!inputBinsMemPtr || !inputBinsMemPtr->GetPrimitivePtr())
IE_THROW() << "Input bins didn't allocate.";
if (getSelectedPrimitiveDescriptor() == nullptr)
IE_THROW() << "Preferable primitive descriptor is not set.";

// update with_bins/num_values/num_bin_values
auto input_tensor_dims = inputTensorMemPtr->getStaticDims();
if (input_tensor_dims.size() < 1) {
IE_THROW() << errorPrefix << " has incorrect dimensions of the input.";
}
auto input_bin_dims = inputBinsMemPtr->getStaticDims();
if (input_bin_dims.size() != 1) {
IE_THROW() << errorPrefix << " has incorrect dimensions of the boundaries tensor.";
}
if (input_bin_dims[0] != 0) {
with_bins = true;
}
num_bin_values = input_bin_dims[0];

num_values =
std::accumulate(input_tensor_dims.begin(), input_tensor_dims.end(), size_t(1), std::multiplies<size_t>());
}

void MKLDNNBucketizeNode::createPrimitive() {
if (inputShapesDefined()) {
if (needPrepareParams())
prepareParams();
updateLastInputDims();
}
}

std::vector<VectorDims> MKLDNNBucketizeNode::shapeInfer() const {
return {getParentEdgesAtPort(0)[0]->getMemory().getStaticDims()};
}

template <typename T, typename T_BOUNDARIES, typename T_IND>
void MKLDNNBucketizeNode::bucketize() {
const auto *input_data = reinterpret_cast<const T *>(getParentEdgeAt(0)->getMemoryPtr()->GetPtr());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@ class MKLDNNBucketizeNode : public MKLDNNNode {

void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override;
void createPrimitive() override {};
void createPrimitive() override;
void execute(mkldnn::stream strm) override;
bool created() const override;
void executeDynamicImpl(mkldnn::stream strm) override {
execute(strm);
}
void prepareParams() override;
std::vector<VectorDims> shapeInfer() const override;

static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "functional_test_utils/ov_tensor_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/cpu_test_utils.hpp"

using namespace InferenceEngine;
using namespace CPUTestUtils;
using namespace ngraph::opset3;
using namespace ov::test;

namespace CPULayerTestsDefinitions {

using BucketizeCPUParamsTuple = std::tuple<InputShape, // Data shape
InputShape, // Buckets shape
bool, // Right edge of interval
ElementType, // Data input precision
ElementType, // Buckets input precision
ElementType // Output precision
>;

class BucketizeLayerCPUTest : public testing::WithParamInterface<BucketizeCPUParamsTuple>,
virtual public SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<BucketizeCPUParamsTuple>& obj) {
InputShape dataShape;
InputShape bucketsShape;
bool with_right_bound;
ElementType inDataPrc;
ElementType inBucketsPrc;
ElementType netPrc;

std::tie(dataShape, bucketsShape, with_right_bound, inDataPrc, inBucketsPrc, netPrc) = obj.param;

std::ostringstream result;
result << "IS=" << CommonTestUtils::partialShape2str({dataShape.first}) << "_"
<< CommonTestUtils::partialShape2str({bucketsShape.first}) << "_";

result << "TS=";
for (const auto& item : dataShape.second) {
result << CommonTestUtils::vec2str(item) << "_";
}
result << "BS=";
for (const auto& item : bucketsShape.second) {
result << CommonTestUtils::vec2str(item) << "_";
}

result << "with_right_bound=" << with_right_bound;
result << "inDataPrc=" << inDataPrc << "_";
result << "inBucketsPrc=" << inBucketsPrc << "_";
result << "netPrc=" << netPrc << "_";
return result.str();
}

void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
inputs.clear();
const auto& funcInputs = function->inputs();

auto data_size = shape_size(targetInputStaticShapes[0]);
ov::runtime::Tensor tensorData = ov::test::utils::create_and_fill_tensor(funcInputs[0].get_element_type(),
targetInputStaticShapes[0],
data_size * 5,
0,
10,
7235346);

ov::runtime::Tensor tensorBucket =
ov::test::utils::create_and_fill_tensor_unique_sequence(funcInputs[1].get_element_type(),
targetInputStaticShapes[1],
0,
10,
8234231);

inputs.insert({funcInputs[0].get_node_shared_ptr(), tensorData});
inputs.insert({funcInputs[1].get_node_shared_ptr(), tensorBucket});
}

protected:
void SetUp() override {
InputShape dataShape;
InputShape bucketsShape;
bool with_right_bound;
ElementType inDataPrc;
ElementType inBucketsPrc;
ElementType netPrc;

targetDevice = CommonTestUtils::DEVICE_CPU;
std::tie(dataShape, bucketsShape, with_right_bound, inDataPrc, inBucketsPrc, netPrc) = this->GetParam();
init_input_shapes({dataShape, bucketsShape});

auto data = std::make_shared<ngraph::op::Parameter>(inDataPrc, inputDynamicShapes[0]);
data->set_friendly_name("a_data");
auto buckets = std::make_shared<ngraph::op::Parameter>(inBucketsPrc, inputDynamicShapes[1]);
buckets->set_friendly_name("b_buckets");
auto bucketize = std::make_shared<ngraph::op::v3::Bucketize>(data, buckets, netPrc, with_right_bound);
function = std::make_shared<ngraph::Function>(std::make_shared<ngraph::opset1::Result>(bucketize),
ngraph::ParameterVector{data, buckets},
"Bucketize");
}
};

TEST_P(BucketizeLayerCPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
}

namespace {

const std::vector<ov::test::InputShape> dataShapesDynamic = {
{{ngraph::Dimension(1, 10), ngraph::Dimension::dynamic(), ngraph::Dimension::dynamic()},
{{1, 20, 20}, {3, 16, 16}, {10, 16, 16}}},
{{ngraph::Dimension(1, 10), 3, 50, 50}, {{1, 3, 50, 50}, {2, 3, 50, 50}, {10, 3, 50, 50}}}};

const std::vector<ov::test::InputShape> bucketsShapesDynamic = {{{ngraph::Dimension::dynamic()}, {{5}, {20}, {100}}}};

const std::vector<ov::test::ElementType> inPrc = {ov::element::f32, ov::element::i64, ov::element::i32};
const std::vector<ov::test::ElementType> outPrc = {ov::element::i64, ov::element::i32};

const auto test_Bucketize_right_edge_Dynamic = ::testing::Combine(::testing::ValuesIn(dataShapesDynamic),
::testing::ValuesIn(bucketsShapesDynamic),
::testing::Values(true),
::testing::ValuesIn(inPrc),
::testing::ValuesIn(inPrc),
::testing::ValuesIn(outPrc));

const auto test_Bucketize_left_edge_Dynamic = ::testing::Combine(::testing::ValuesIn(dataShapesDynamic),
::testing::ValuesIn(bucketsShapesDynamic),
::testing::Values(false),
::testing::ValuesIn(inPrc),
::testing::ValuesIn(inPrc),
::testing::ValuesIn(outPrc));

INSTANTIATE_TEST_SUITE_P(smoke_TestsBucketize_right_Dynamic,
BucketizeLayerCPUTest,
test_Bucketize_right_edge_Dynamic,
BucketizeLayerCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_TestsBucketize_left_Dynamic,
BucketizeLayerCPUTest,
test_Bucketize_left_edge_Dynamic,
BucketizeLayerCPUTest::getTestCaseName);

} // namespace
} // namespace CPULayerTestsDefinitions
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,54 @@ fill_data_random(T *pointer, std::size_t size, const uint32_t range = 10, int32_
}
}

/** @brief Fill a memory area with a sorted sequence of unique elements randomly generated.
*
* This function generates and fills a blob of a certain precision, with a
* sorted sequence of unique elements.
*
* @param rawBlobDataPtr pointer to destination memory area
* @param size number of elements in destination memory
* @param range Values range
* @param start_from Value from which range should start
* @param k Resolution of floating point numbers.
* - With k = 1 every random number will be basically integer number.
* - With k = 2 numbers resolution will 1/2 so outputs only .0 or .50
* - With k = 4 numbers resolution will 1/4 so outputs only .0 .25 .50 0.75 and etc.
* @param seed seed of random generator
*/
template <typename T>
void inline fill_random_unique_sequence(T* rawBlobDataPtr,
std::size_t size,
uint64_t range,
int64_t start_from = 0,
const int64_t k = 1,
const int32_t seed = 1) {
if (start_from < 0 && !std::is_signed<T>::value) {
start_from = 0;
}

if (range < size) {
range = size * 2;
}

std::mt19937 generator(seed);
std::uniform_int_distribution<int64_t> dist(k * start_from, k * (start_from + range));

std::set<T> elems;
while (elems.size() != size) {
auto value = static_cast<float>(dist(generator));
value /= static_cast<float>(k);
if (std::is_same<ngraph::float16, T>::value) {
elems.insert(static_cast<T>(ngraph::float16(value).to_bits()));
} else if (std::is_same<ngraph::bfloat16, T>::value) {
elems.insert(static_cast<T>(ngraph::bfloat16(value).to_bits()));
} else {
elems.insert(static_cast<T>(value));
}
}
std::copy(elems.begin(), elems.end(), rawBlobDataPtr);
}

/** @brief Fill blob with random data.
*
* @param blob Target blob
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ ov::runtime::Tensor create_and_fill_tensor(
const int32_t resolution = 1,
const int seed = 1);

ov::runtime::Tensor create_and_fill_tensor_unique_sequence(
const ov::element::Type element_type,
const ov::Shape& shape,
const int32_t start_from = 0,
const int32_t resolution = 1,
const int seed = 1);

void compare(
const ov::runtime::Tensor &expected,
const ov::runtime::Tensor &actual,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,54 @@ ov::runtime::Tensor create_and_fill_tensor(
return tensor;
}

ov::runtime::Tensor create_and_fill_tensor_unique_sequence(const ov::element::Type element_type,
const ov::Shape& shape,
const int32_t start_from,
const int32_t resolution,
const int seed) {
auto tensor = ov::runtime::Tensor{element_type, shape};
auto range = shape_size(shape) * 2;
#define CASE(X) \
case X: \
::CommonTestUtils::fill_random_unique_sequence(tensor.data<element_type_traits<X>::value_type>(), \
shape_size(shape), \
range, \
start_from, \
resolution, \
seed); \
break;

switch (element_type) {
CASE(ov::element::Type_t::boolean)
CASE(ov::element::Type_t::i8)
CASE(ov::element::Type_t::i16)
CASE(ov::element::Type_t::i32)
CASE(ov::element::Type_t::i64)
CASE(ov::element::Type_t::u8)
CASE(ov::element::Type_t::u16)
CASE(ov::element::Type_t::u32)
CASE(ov::element::Type_t::u64)
CASE(ov::element::Type_t::bf16)
CASE(ov::element::Type_t::f16)
CASE(ov::element::Type_t::f32)
CASE(ov::element::Type_t::f64)
case ov::element::Type_t::u1:
case ov::element::Type_t::i4:
case ov::element::Type_t::u4:
::CommonTestUtils::fill_random_unique_sequence(static_cast<uint8_t*>(tensor.data()),
tensor.get_byte_size(),
range,
start_from,
resolution,
seed);
break;
default:
OPENVINO_UNREACHABLE("Unsupported element type: ", element_type);
}
#undef CASE
return tensor;
}

template<typename ExpectedT, typename ActualT>
void compare(const ov::runtime::Tensor& expected,
const ov::runtime::Tensor& actual,
Expand Down

0 comments on commit 87ea55f

Please sign in to comment.