Skip to content

Commit

Permalink
Revise GatherTree reference implementation (#7275)
Browse files Browse the repository at this point in the history
* Add visitor api test

* Review ngraph op shell with type_prop tests

* Add op to list of trusted operations

* Change name of struct with information of inputs

* Add include of array data structure to fix windowds compilation error

* Add template plugin test class

* Remove usage of CoordinateTransform index function call from reference implementation

* Rename SLT test suite

* Add template plugin unit test

* Add serialization SLTs

* Add indentation on GatherTreeParams class data members
ggalieroc authored Sep 10, 2021
1 parent 288a763 commit deeb964
Showing 8 changed files with 506 additions and 93 deletions.
100 changes: 100 additions & 0 deletions docs/template_plugin/tests/functional/op_reference/gather_tree.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <ie_core.hpp>
#include <ie_ngraph_utils.hpp>
#include <limits>
#include <algorithm>
#include <ngraph/ngraph.hpp>
#include <shared_test_classes/base/layer_test_utils.hpp>

#include "base_reference_test.hpp"

using namespace reference_tests;
using namespace ngraph;
using namespace InferenceEngine;

namespace {
struct GatherTreeParams {
template <class IN_ET>
GatherTreeParams(const ngraph::Shape inShape, std::vector<IN_ET> stepIds, const std::vector<IN_ET> parentIds,
const std::vector<IN_ET> maxSeqLen, const std::vector<IN_ET> endToken, std::vector<IN_ET> output) :
stepIdsTensor(inShape, element::from<IN_ET>(), stepIds), parentIdsTensor(inShape, element::from<IN_ET>(), parentIds),
maxSeqLenTensor(ngraph::Shape{inShape[1]}, element::from<IN_ET>(), maxSeqLen), endTokenTensor(ngraph::Shape{}, element::from<IN_ET>(), endToken),
expectedTensor(inShape, element::from<IN_ET>(), output) {}
Tensor stepIdsTensor;
Tensor parentIdsTensor;
Tensor maxSeqLenTensor;
Tensor endTokenTensor;
Tensor expectedTensor;
};

class ReferenceGatherTreeTest : public testing::TestWithParam<GatherTreeParams>, public CommonReferenceTest {
public:
void SetUp() override {
auto params = GetParam();
function = CreateFunction(params);
inputData = {params.stepIdsTensor.data, params.parentIdsTensor.data, params.maxSeqLenTensor.data, params.endTokenTensor.data};
refOutData = {params.expectedTensor.data};
}
static std::string getTestCaseName(const testing::TestParamInfo<GatherTreeParams>& obj) {
auto param = obj.param;
std::ostringstream result;
result << "iType=" << param.stepIdsTensor.type << "_";
result << "iShape=" << param.stepIdsTensor.shape;
return result.str();
}

private:
static std::shared_ptr<Function> CreateFunction(const GatherTreeParams& params) {
const auto stepIds = std::make_shared<op::Parameter>(params.stepIdsTensor.type, params.stepIdsTensor.shape);
const auto parentIds = std::make_shared<op::Parameter>(params.parentIdsTensor.type, params.parentIdsTensor.shape);
const auto maxSeqLen = std::make_shared<op::Parameter>(params.maxSeqLenTensor.type, params.maxSeqLenTensor.shape);
const auto endToken = std::make_shared<op::Parameter>(params.endTokenTensor.type, params.endTokenTensor.shape);
const auto gatherTree = std::make_shared<op::v1::GatherTree>(stepIds, parentIds, maxSeqLen, endToken);
return std::make_shared<Function>(NodeVector {gatherTree}, ParameterVector {stepIds, parentIds, maxSeqLen, endToken});
}
};

TEST_P(ReferenceGatherTreeTest, CompareWithRefs) {
Exec();
}

template <element::Type_t IN_ET>
std::vector<GatherTreeParams> generateGatherTreeParams() {
using T = typename element_type_traits<IN_ET>::value_type;
std::vector<GatherTreeParams> gatherTreeParams {
GatherTreeParams(Shape{4, 1, 3},
std::vector<T>{1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1},
std::vector<T>{0, 0, 0, 0, 1, 1, 2, 1, 2, -1, -1, -1},
std::vector<T>{3},
std::vector<T>{10},
std::vector<T>{2, 2, 2, 6, 5, 6, 7, 8, 9, 10, 10, 10}),
GatherTreeParams(Shape{2, 2, 2},
std::vector<T>{1, 2, 3, 4, 5, 6, 7, 8},
std::vector<T>{0, 0, 0, 0, 0, 0, 0, 0},
std::vector<T>{2, 4},
std::vector<T>{0},
std::vector<T>{1, 1, 3, 3, 5, 6, 7, 8})
};
return gatherTreeParams;
}

std::vector<GatherTreeParams> generateGatherTreeCombinedParams() {
const std::vector<std::vector<GatherTreeParams>> gatherTreeTypeParams {
generateGatherTreeParams<element::Type_t::f32>(),
generateGatherTreeParams<element::Type_t::i32>()};
std::vector<GatherTreeParams> combinedParams;

for (const auto& params : gatherTreeTypeParams) {
combinedParams.insert(combinedParams.end(), params.begin(), params.end());
}
return combinedParams;
}

INSTANTIATE_TEST_SUITE_P(smoke_GatherTree_With_Hardcoded_Refs, ReferenceGatherTreeTest,
testing::ValuesIn(generateGatherTreeCombinedParams()), ReferenceGatherTreeTest::getTestCaseName);
} // namespace
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>

#include "shared_test_classes/single_layer/gather_tree.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;

namespace {

TEST_P(GatherTreeLayerTest, Serialize) {
Serialize();
}

const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::I32
};

const std::vector<std::vector<size_t>> inputShapes = { {5, 1, 10}, {1, 1, 10}, {20, 1, 10}, {20, 20, 10} };

const std::vector<ngraph::helpers::InputLayerType> secondaryInputTypes = {
ngraph::helpers::InputLayerType::CONSTANT,
ngraph::helpers::InputLayerType::PARAMETER
};

INSTANTIATE_TEST_SUITE_P(smoke_GatherTree_Serialization, GatherTreeLayerTest,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(secondaryInputTypes),
::testing::ValuesIn(netPrecisions),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
GatherTreeLayerTest::getTestCaseName);
} // namespace
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ const std::vector<ngraph::helpers::InputLayerType> secondaryInputTypes = {
ngraph::helpers::InputLayerType::PARAMETER
};

INSTANTIATE_TEST_SUITE_P(Basic_smoke, GatherTreeLayerTest,
INSTANTIATE_TEST_SUITE_P(smoke_GatherTree, GatherTreeLayerTest,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(secondaryInputTypes),
33 changes: 19 additions & 14 deletions ngraph/core/reference/src/runtime/reference/gather_tree.cpp
Original file line number Diff line number Diff line change
@@ -72,11 +72,12 @@ void runtime::reference::gather_tree(const char* step_ids,
throw ngraph_error("max_seq_len must have size of BATCH_SIZE");
}

NGRAPH_SUPPRESS_DEPRECATED_START
ngraph::CoordinateTransform cordinate_transform(step_ids_shape);
const auto in_strides = row_major_strides(step_ids_shape);
ngraph::CoordinateTransformBasic cordinate_transform(step_ids_shape);

for (const auto& coord : cordinate_transform) {
memcpy(out + cordinate_transform.index(coord) * elem_size, end_token, elem_size);
const auto out_idx = std::inner_product(coord.begin(), coord.end(), in_strides.begin(), 0);
memcpy(out + out_idx * elem_size, end_token, elem_size);
}

for (size_t batch = 0; batch < batch_size; ++batch) {
@@ -87,31 +88,35 @@ void runtime::reference::gather_tree(const char* step_ids,
continue;
}

auto offset = cordinate_transform.index({max_seq_in_beam - 1, batch, beam}) * elem_size;

const auto coord = Coordinate({max_seq_in_beam - 1, batch, beam});
const auto offset = std::inner_product(coord.begin(), coord.end(), in_strides.begin(), 0) * elem_size;
memcpy(out + offset, step_ids + offset, elem_size);

size_t parent = _asIndex(parent_ids + offset, element_type);

for (size_t level = max_seq_in_beam - 1; level-- > 0;) {
memcpy(out + cordinate_transform.index({level, batch, beam}) * elem_size,
step_ids + cordinate_transform.index({level, batch, parent}) * elem_size,
elem_size);
const auto coord_beam = Coordinate({level, batch, beam});
const auto out_idx = std::inner_product(coord_beam.begin(), coord_beam.end(), in_strides.begin(), 0);

const auto coord_parent = Coordinate({level, batch, parent});
const auto step_ids_idx =
std::inner_product(coord_parent.begin(), coord_parent.end(), in_strides.begin(), 0);

memcpy(out + out_idx * elem_size, step_ids + step_ids_idx * elem_size, elem_size);

parent =
_asIndex(parent_ids + cordinate_transform.index({level, batch, parent}) * elem_size, element_type);
parent = _asIndex(parent_ids + step_ids_idx * elem_size, element_type);
}

bool finished = false;
for (size_t time = 0; time < max_seq_in_beam; ++time) {
const auto out_coord = Coordinate({time, batch, beam});
const auto out_idx = std::inner_product(out_coord.begin(), out_coord.end(), in_strides.begin(), 0);
if (finished) {
memcpy(out + cordinate_transform.index({time, batch, beam}) * elem_size, end_token, elem_size);
} else if (_asIndex(out + cordinate_transform.index({time, batch, beam}) * elem_size, element_type) ==
_asIndex(end_token, element_type)) {
memcpy(out + out_idx * elem_size, end_token, elem_size);
} else if (_asIndex(out + out_idx * elem_size, element_type) == _asIndex(end_token, element_type)) {
finished = true;
}
}
}
}
NGRAPH_SUPPRESS_DEPRECATED_END
}
75 changes: 54 additions & 21 deletions ngraph/core/src/op/gather_tree.cpp
Original file line number Diff line number Diff line change
@@ -33,35 +33,68 @@ bool ngraph::op::v1::GatherTree::visit_attributes(AttributeVisitor& visitor) {

void op::v1::GatherTree::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v1_GatherTree_validate_and_infer_types);
const auto& step_ids_rank = get_input_partial_shape(0);
const auto& parent_idx_rank = get_input_partial_shape(1);
const auto& max_seq_len_rank = get_input_partial_shape(2);
const auto& end_token_rank = get_input_partial_shape(3);

const auto& step_ids_et = get_input_element_type(0);
const auto& parent_idx_et = get_input_element_type(1);
const auto& max_seq_len_et = get_input_element_type(2);
const auto& end_token_et = get_input_element_type(3);

element::Type result_et;
NODE_VALIDATION_CHECK(this,
step_ids_rank.rank().is_dynamic() || step_ids_rank.rank().get_length() == 3,
"step_ids input rank must equal to 3 (step_ids rank: ",
step_ids_rank.rank().get_length(),
element::Type::merge(result_et, step_ids_et, parent_idx_et) &&
element::Type::merge(result_et, result_et, max_seq_len_et) &&
element::Type::merge(result_et, result_et, end_token_et),
"Inputs must have the same element type. Got: step_ids (",
step_ids_et,
"), parent_idx_et (",
parent_idx_et,
"), max_seq_len (",
max_seq_len_et,
"), end_token (",
end_token_et,
")");

NODE_VALIDATION_CHECK(this,
parent_idx_rank.rank().is_dynamic() || parent_idx_rank.rank().get_length() == 3,
"parent_idx input rank must equal to 3 (parent_idx rank: ",
parent_idx_rank.rank().get_length(),
")");
result_et.is_real() || result_et.is_integral_number(),
"Element type of inputs must be numeric. Got: ",
result_et);

const auto& step_ids_pshape = get_input_partial_shape(0);
const auto& parent_idx_pshape = get_input_partial_shape(1);
const auto& max_seq_len_pshape = get_input_partial_shape(2);
const auto& end_token_pshape = get_input_partial_shape(3);

PartialShape result_pshape{PartialShape::dynamic()};
NODE_VALIDATION_CHECK(this,
max_seq_len_rank.rank().is_dynamic() || max_seq_len_rank.rank().get_length() == 1,
"max_seq_len input rank must equal to 1 (max_seq_len rank: ",
max_seq_len_rank.rank().get_length(),
")");
PartialShape::merge_into(result_pshape, step_ids_pshape) &&
PartialShape::merge_into(result_pshape, parent_idx_pshape) &&
result_pshape.rank().compatible(3),
"step_ids and parent_idx inputs must have the same shape with rank 3. Got: ",
step_ids_pshape,
" and ",
parent_idx_pshape,
", respectively");

NODE_VALIDATION_CHECK(this,
end_token_rank.rank().is_dynamic() || end_token_rank.rank().get_length() == 0,
"end_token input rank must be scalar (end_token rank: ",
end_token_rank.rank().get_length(),
")");
max_seq_len_pshape.rank().compatible(1),
"max_seq_len input must have rank 1. Got: ",
max_seq_len_pshape);

const auto& step_ids_et = get_input_element_type(0);
set_output_type(0, step_ids_et, step_ids_rank);
if (result_pshape.rank().is_static() && max_seq_len_pshape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
Dimension::merge(result_pshape[1], result_pshape[1], max_seq_len_pshape[0]),
"Number of elements of max_seq_len input must match BATCH_SIZE dimension of "
"step_ids/parent_idx inputs. Got: ",
result_pshape[1],
" and ",
max_seq_len_pshape[0],
", respectively");
}

NODE_VALIDATION_CHECK(this,
end_token_pshape.rank().compatible(0),
"end_token input must be scalar. Got: ",
end_token_pshape);

set_output_type(0, result_et, result_pshape);
}
1 change: 1 addition & 0 deletions ngraph/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -270,6 +270,7 @@ set(SRC
visitors/op/floor_mod.cpp
visitors/op/floor.cpp
visitors/op/gather.cpp
visitors/op/gather_tree.cpp
visitors/op/gelu.cpp
visitors/op/greater_equal.cpp
visitors/op/greater.cpp
319 changes: 262 additions & 57 deletions ngraph/test/type_prop/gather_tree.cpp

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions ngraph/test/visitors/op/gather_tree.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "util/visitor.hpp"

using namespace ngraph;
using ngraph::test::NodeBuilder;
using ngraph::test::ValueMap;

TEST(attributes, gather_tree_op) {
NodeBuilder::get_ops().register_factory<opset1::GatherTree>();

auto step_ids = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto parent_idx = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto max_seq_len = std::make_shared<op::Parameter>(element::f32, Shape{2});
auto end_token = std::make_shared<op::Parameter>(element::f32, Shape{});

auto gather_tree = std::make_shared<opset1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
NodeBuilder builder(gather_tree);

const auto expected_attr_count = 0;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}

0 comments on commit deeb964

Please sign in to comment.