Skip to content

Commit

Permalink
[PT FE]: support aten::index (#15544)
Browse files Browse the repository at this point in the history
* [PT FE]: support aten::index

* bool indexing testing

* more tests, fix nonzero case

* apply code review
  • Loading branch information
eaidova authored Feb 24, 2023
1 parent ba45c99 commit 5c6ef54
Show file tree
Hide file tree
Showing 4 changed files with 372 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "transforms/append_list_unpack_replacer.hpp"
#include "transforms/aten_cat_replacer.hpp"
#include "transforms/aten_getitem_replacer.hpp"
#include "transforms/aten_index_replacer.hpp"
#include "transforms/aten_stack_list_construct_replacer.hpp"
#include "transforms/einsum_list_construct.hpp"
#include "transforms/listconstruct_replacer.hpp"
Expand Down Expand Up @@ -97,6 +98,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::PrimListUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenGetItemReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::ListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();
manager.register_pass<ov::frontend::pytorch::pass::PrimListConstructPadReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
Expand Down
271 changes: 271 additions & 0 deletions src/frontends/pytorch/src/transforms/aten_index_replacer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "aten_index_replacer.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/frontend/pytorch/visibility.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/gather_elements.hpp"
#include "openvino/op/gather_nd.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/non_zero.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

using namespace ov::op;
namespace {

/// Reshapes `value` into a 2D tensor split at `axis`.
///
/// The first output dimension is the product of the input dimensions [d_0, ..., d_{axis-1}],
/// the second one is the product of the remaining dimensions [d_{axis}, ..., d_n].
///
/// \param value  Tensor to flatten.
/// \param axis   Number of leading dimensions folded into the first output dimension.
/// \return Reshape node producing the 2D tensor.
std::shared_ptr<Node> flatten(const Output<Node>& value, size_t axis) {
    Output<Node> output_shape;
    if (axis == 0) {
        // Nothing to fold in front: [1, everything].
        output_shape = v0::Constant::create(element::i64, Shape{2}, {1, -1});
    } else if (axis == 1) {
        // Reshape is created with special_zero=true, so 0 keeps the first dimension as is.
        output_shape = v0::Constant::create(element::i64, Shape{2}, {0, -1});
    } else {
        const auto value_shape = std::make_shared<v3::ShapeOf>(value);
        // Slice-8 expects 1D start/stop/step inputs, so these constants must have Shape{1}.
        const auto start = v0::Constant::create(element::i64, Shape{1}, {0});
        const auto stop = v0::Constant::create(element::i64, Shape{1}, {axis});
        const auto step = v0::Constant::create(element::i64, Shape{1}, {1});
        const auto first_part_dims = std::make_shared<v8::Slice>(value_shape, start, stop, step);

        // keep_dims=true leaves a 1D single-element tensor that can be concatenated below.
        const auto reduce_axis = v0::Constant::create(element::i64, Shape{}, {0});
        const auto first_part_dims_length = std::make_shared<v1::ReduceProd>(first_part_dims, reduce_axis, true);

        const auto remaining_part_length = v0::Constant::create(element::i64, Shape{1}, {-1});

        output_shape = std::make_shared<v0::Concat>(OutputVector{first_part_dims_length, remaining_part_length}, 0);
    }
    return std::make_shared<v1::Reshape>(value, output_shape, true);
}
}; // namespace

/// Registers a matcher that replaces an `aten::index` framework node with Gather / GatherND based
/// sub-graphs. Both forms of the index input are handled: a `prim::ListConstruct` of per-axis
/// indices (where a `prim::Constant` carrying `none_value` stands for ":") and a single index tensor.
AtenIndexToSelect::AtenIndexToSelect() {
    auto index_op = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();

    ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
        auto index_op = cast_fw_node(m.get_match_root(), "aten::index");
        if (!index_op) {
            return false;
        }
        // Keep the exact producer outputs: taking only the producer node would silently switch to
        // its first output when the indexed tensor comes from a multi-output node.
        const Output<Node> input = index_op->input_value(0);
        const Output<Node> indices_value = index_op->input_value(1);
        auto input_node = input.get_node_shared_ptr();
        auto indicies = indices_value.get_node_shared_ptr();

        // True when the node is a prim::Constant that represents Python None.
        const auto is_none_constant = [](const std::shared_ptr<Node>& node) {
            const auto const_input = cast_fw_node(node, "prim::Constant");
            if (!const_input) {
                return false;
            }
            const auto& attrs = const_input->get_attrs();
            return attrs.find("none_value") != attrs.end();
        };

        auto list_indicies = cast_fw_node(indicies, "prim::ListConstruct");
        if (list_indicies) {
            // Multiple tensors as indices. Each tensor could either be
            //   1. prim::Constant()
            //           representing ":" in python indexing. E.g. tensor[:, :]
            //   2. prim::Constant[value=...] or tensor output
            //           representing advanced indexing. E.g. tensor[[0, 1], [2, 0]].
            // For more info on advanced indexing,
            // check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

            // Consider a general case of
            //       t: [x_1, y_1, y_2, ..., x_m, ..., y_n]
            // where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are
            // axes for ":". Same results can be achieved through transposing t into
            //       t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n]
            // and use gather
            //       t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n]
            //       tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j))
            // After gather, reshape and transpose back.
            const auto ids = list_indicies->input_values();
            // Axes that carry a real (non-None) index, in increasing order.
            std::vector<size_t> advanced_ids;
            // Both vectors below hold one entry per list element, i.e. they are addressed by axis.
            std::vector<bool> is_masked_bool;
            OutputVector masked_indicies;
            for (size_t i = 0; i < ids.size(); i++) {
                // skip dimensions where index is None
                if (is_none_constant(ids[i].get_node_shared_ptr())) {
                    masked_indicies.push_back(ids[i]);
                    is_masked_bool.push_back(false);
                    continue;
                }
                // Read the type from the output itself; the producer may have several outputs.
                const auto id_dtype = ids[i].get_element_type();
                if (id_dtype == element::boolean || id_dtype == element::u8) {
                    // for case when index is bool e.g. x[x>0], replace index with non_zero
                    auto idx = std::make_shared<ov::op::v0::Convert>(ids[i], element::u8);
                    auto nonzero = std::make_shared<ov::op::v3::NonZero>(idx);
                    auto input_order = v0::Constant::create(element::i64, Shape{2}, {1, 0});
                    auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
                    masked_indicies.push_back(masked_id);
                    is_masked_bool.push_back(true);
                } else {
                    masked_indicies.push_back(ids[i]);
                    is_masked_bool.push_back(false);
                }
                advanced_ids.push_back(i);
            }

            // all indicies prim::Constant(None), return input as is
            if (advanced_ids.empty()) {
                copy_runtime_info({index_op, input_node}, input_node);
                replace_node(index_op, OutputVector{input});
                return true;
            }
            // perform gather for single element case
            if (advanced_ids.size() == 1) {
                const size_t axis = advanced_ids[0];
                Output<Node> index = masked_indicies[axis];
                index = std::make_shared<v0::Convert>(index, element::i64);
                std::shared_ptr<Node> gather;
                if (is_masked_bool[axis]) {
                    gather = std::make_shared<v8::GatherND>(input, index);
                } else {
                    auto dim = v0::Constant::create(element::i64, Shape{}, {axis});
                    gather = std::make_shared<v8::Gather>(input, index, dim);
                }
                copy_runtime_info({index_op, input_node, indicies}, gather);
                replace_node(index_op, gather);
                return true;
            }

            const size_t adv_idx_count = advanced_ids.size();
            // Rank of the tensor being indexed (not of the producer's own input).
            const auto rank = input.get_partial_shape().rank();
            FRONT_END_CHECK_IMPLEMENTED(rank.is_static(), "indexing for tensor with dynamic rank is not implemented ");
            const size_t rank_length = static_cast<size_t>(rank.get_length());
            // More index entries than dimensions cannot be mapped; leave the node untouched.
            if (advanced_ids.back() >= rank_length) {
                return false;
            }

            auto input_shape = std::make_shared<v3::ShapeOf>(input);
            auto zero = v0::Constant::create(element::i64, Shape{}, {0});
            // One single-element 1D tensor per input dimension.
            auto input_dims = std::make_shared<v1::Split>(input_shape, zero, rank_length);
            std::vector<size_t> non_used_dims;
            for (size_t i = 0; i < rank_length; i++) {
                if (std::find(advanced_ids.begin(), advanced_ids.end(), i) == advanced_ids.end()) {
                    non_used_dims.push_back(i);
                }
            }
            // Bring indexed axes to the front, keep the remaining ones in their original order.
            std::vector<size_t> permutation_dims(advanced_ids);
            permutation_dims.insert(permutation_dims.end(), non_used_dims.begin(), non_used_dims.end());
            auto transpose_dims = v0::Constant::create(element::i64, Shape{permutation_dims.size()}, permutation_dims);
            auto transposed_input = std::make_shared<v1::Transpose>(input, transpose_dims);
            auto flatten_input = flatten(transposed_input, adv_idx_count);

            // Linearise the per-axis indices, starting from the innermost indexed axis. Indices are
            // converted to i64 so they can be combined with the i64 dimensions coming from ShapeOf.
            Output<Node> cum_adv_index =
                std::make_shared<v0::Convert>(masked_indicies[advanced_ids[adv_idx_count - 1]], element::i64);
            Output<Node> multiplier = input_dims->output(advanced_ids[adv_idx_count - 1]);
            for (size_t pos = adv_idx_count - 1; pos-- > 0;) {
                const size_t axis = advanced_ids[pos];
                auto axis_index = std::make_shared<v0::Convert>(masked_indicies[axis], element::i64);
                auto adv_index = std::make_shared<v1::Multiply>(axis_index, multiplier);
                cum_adv_index = std::make_shared<v1::Add>(cum_adv_index, adv_index);
                multiplier = std::make_shared<v1::Multiply>(multiplier, input_dims->output(axis));
            }
            std::shared_ptr<Node> gather = std::make_shared<v8::Gather>(flatten_input, cum_adv_index, zero);
            OutputVector concat_dims;
            auto cum_adv_index_shape_tensor = std::make_shared<v3::ShapeOf>(cum_adv_index);
            // check if all advanced indices are consecutive (advanced_ids is strictly increasing).
            const size_t first_adv_id = advanced_ids.front();
            const bool is_consecutive = (advanced_ids.back() - first_adv_id + 1) == adv_idx_count;
            if (is_consecutive) {
                // unfold regular index axes
                OutputVector folded_adv_idx_shape_vector;
                auto minus_one = v0::Constant::create(element::i64, Shape{1}, {-1});
                folded_adv_idx_shape_vector.push_back(minus_one);
                for (auto i : non_used_dims) {
                    folded_adv_idx_shape_vector.push_back(input_dims->output(i));
                }
                auto folded_adv_idx_shape = std::make_shared<v0::Concat>(folded_adv_idx_shape_vector, 0);
                gather = std::make_shared<v1::Reshape>(gather, folded_adv_idx_shape, false);
                std::vector<size_t> adv_idx_permute;
                for (size_t i = 1; i < first_adv_id + 1; i++) {
                    adv_idx_permute.push_back(i);
                }
                adv_idx_permute.push_back(0);
                for (size_t i = first_adv_id + 1; i < rank_length - adv_idx_count + 1; i++) {
                    adv_idx_permute.push_back(i);
                }
                // Transpose folded advanced indexed axis to its original location.
                auto permute_indicies =
                    v0::Constant::create(element::i64, Shape{adv_idx_permute.size()}, adv_idx_permute);
                gather = std::make_shared<v1::Transpose>(gather, permute_indicies);
                // unfold advanced index axes: dimensions before the first indexed axis, then the
                // shape of the index itself, then the remaining non-indexed dimensions.
                for (size_t i = 0; i < first_adv_id; i++) {
                    concat_dims.push_back(input_dims->output(i));
                }
                concat_dims.push_back(cum_adv_index_shape_tensor);
                for (auto i : non_used_dims) {
                    if (i < first_adv_id) {
                        continue;
                    }
                    concat_dims.push_back(input_dims->output(i));
                }
            } else {
                // Non-adjacent advanced indices: the index dimensions go first in the result.
                concat_dims.push_back(cum_adv_index_shape_tensor);
                for (auto i : non_used_dims) {
                    concat_dims.push_back(input_dims->output(i));
                }
            }
            auto final_shape = std::make_shared<v0::Concat>(concat_dims, 0);
            gather = std::make_shared<v1::Reshape>(gather, final_shape, false);
            copy_runtime_info({index_op, input_node, indicies}, gather);
            replace_node(index_op, gather);
            return true;
        }

        // Single index input (not a list).
        if (is_none_constant(indicies)) {
            // index is None, stay input as is
            copy_runtime_info({index_op, input_node, indicies}, input_node);
            replace_node(index_op, OutputVector{input});
            return true;
        }
        const auto index_dtype = indices_value.get_element_type();
        if (index_dtype == element::boolean || index_dtype == element::u8) {
            // Boolean mask: gather the elements at the coordinates of the non-zero mask entries.
            auto nonzero = std::make_shared<v3::NonZero>(indices_value);
            auto input_order = v0::Constant::create(element::i64, Shape{2}, {1, 0});
            auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
            auto gather = std::make_shared<v8::GatherND>(input, masked_id);
            copy_runtime_info({index_op, input_node, indicies}, gather);
            replace_node(index_op, gather);
            return true;
        }
        Output<Node> index = indices_value;
        if (index_dtype != element::i32 && index_dtype != element::i64) {
            index = std::make_shared<ov::op::v0::Convert>(index, element::i64);
        }
        auto dim = v0::Constant::create(element::i64, Shape{}, {0});
        auto gather = std::make_shared<v8::Gather>(input, index, dim);
        copy_runtime_info({index_op, input_node, indicies}, gather);
        replace_node(index_op, gather);
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(index_op, "ov::frontend::pytorch::pass::AtenIndexToSelect");
    this->register_matcher(m, callback);
}

} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov
26 changes: 26 additions & 0 deletions src/frontends/pytorch/src/transforms/aten_index_replacer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/frontend/pytorch/visibility.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

// Matcher pass that replaces an aten::index framework node with Gather / GatherND based
// sub-graphs. The index input may be a prim::ListConstruct of per-axis indices (a prim::Constant
// with the "none_value" attribute stands for ":"), or a single index tensor; boolean / u8
// indices are turned into coordinates via NonZero.
class PYTORCH_API AtenIndexToSelect : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::AtenIndexToSelect");
// Builds the pattern and registers the replacement callback.
AtenIndexToSelect();
};

} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov
73 changes: 73 additions & 0 deletions tests/layer_tests/pytorch_tests/test_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import numpy as np

from pytorch_layer_test_class import PytorchLayerTest


class TestIndex(PytorchLayerTest):
    """Layer tests for ``aten::index`` with integer and boolean index tensors."""

    def _prepare_input(self, input_shape, idx):
        """Return a random float32 tensor of ``input_shape`` together with the given index."""
        return (np.random.randn(*input_shape).astype(np.float32), idx)

    def create_model(self, model="list"):
        """Build the torch module for the requested indexing flavour.

        ``model`` selects one of: ``"list"`` (``x[idx]``), ``"getitem"``
        (``x.__getitem__(idx)``) and their ``*_with_bool`` variants, which cast
        the index to ``torch.bool`` first.
        """
        import torch

        class aten_index_list(torch.nn.Module):

            def forward(self, x, idx):
                return x[idx]

        class aten_index_getitem(torch.nn.Module):

            def forward(self, x, idx):
                return x.__getitem__(idx)

        class aten_index_list_bool(torch.nn.Module):

            def forward(self, x, idx):
                return x[idx.to(torch.bool)]

        class aten_index_getitem_bool(torch.nn.Module):

            def forward(self, x, idx):
                return x.__getitem__(idx.to(torch.bool))

        cases = {
            "list": aten_index_list,
            "getitem": aten_index_getitem,
            "list_with_bool": aten_index_list_bool,
            "getitem_with_bool": aten_index_getitem_bool
        }

        aten_index = cases[model]

        ref_net = None

        return aten_index(), ref_net, "aten::index"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("case", ["list", "getitem"])
    @pytest.mark.parametrize(("input_shape", "idx"), [
        ((1,), np.array(0).astype(int)),
        ([2, 3], np.array(-1).astype(int)),
        ([4, 5, 6], np.array((1, 2)).astype(int)),
        ([7, 8, 9], np.array((-1, 2, -3)).astype(int)),
        ([2, 2, 3, 4], np.array((1,)).astype(int))])
    def test_index(self, input_shape, idx, case, ie_device, precision, ir_version):
        self._test(*self.create_model(case), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("case", ["getitem_with_bool", "list_with_bool"])
    @pytest.mark.parametrize(("input_shape", "idx"), [
        ((1, 2), np.array([[1, 0]]).astype(bool)),
        ((2, 2, 5), np.zeros([2, 2, 5]).astype(bool)),
        ((2, 2, 5), np.ones([2, 2, 5]).astype(bool)),
        # Threshold at 0.5: rand() samples [0, 1), so "> 0" would (almost surely) be an
        # all-True mask and merely repeat the np.ones case above.
        ((2, 2, 5), np.random.rand(2, 2, 5) > 0.5)
    ])
    def test_index_bool(self, input_shape, idx, case, ie_device, precision, ir_version):
        self._test(*self.create_model(case), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})

0 comments on commit 5c6ef54

Please sign in to comment.