From 5c6ef5412728a7c57970e12d75026fdbf9f8ea82 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Fri, 24 Feb 2023 17:33:00 +0400 Subject: [PATCH] [PT FE]: support aten::index (#15544) * [PT FE]: support aten::index * bool indexing testing * more tests, fix nonzero case * apply code review --- src/frontends/pytorch/src/frontend.cpp | 2 + .../src/transforms/aten_index_replacer.cpp | 271 ++++++++++++++++++ .../src/transforms/aten_index_replacer.hpp | 26 ++ tests/layer_tests/pytorch_tests/test_index.py | 73 +++++ 4 files changed, 372 insertions(+) create mode 100644 src/frontends/pytorch/src/transforms/aten_index_replacer.cpp create mode 100644 src/frontends/pytorch/src/transforms/aten_index_replacer.hpp create mode 100644 tests/layer_tests/pytorch_tests/test_index.py diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp index 7fb3b7ce064def..90f93267524bfc 100644 --- a/src/frontends/pytorch/src/frontend.cpp +++ b/src/frontends/pytorch/src/frontend.cpp @@ -18,6 +18,7 @@ #include "transforms/append_list_unpack_replacer.hpp" #include "transforms/aten_cat_replacer.hpp" #include "transforms/aten_getitem_replacer.hpp" +#include "transforms/aten_index_replacer.hpp" #include "transforms/aten_stack_list_construct_replacer.hpp" #include "transforms/einsum_list_construct.hpp" #include "transforms/listconstruct_replacer.hpp" @@ -97,6 +98,7 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/src/frontends/pytorch/src/transforms/aten_index_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_index_replacer.cpp new file mode 100644 index 00000000000000..7affc4511d028a --- /dev/null +++ b/src/frontends/pytorch/src/transforms/aten_index_replacer.cpp @@ -0,0 +1,271 @@ +// Copyright (C) 2018-2023 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0
+//
+
+#include "aten_index_replacer.hpp"
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "openvino/core/rt_info.hpp"
+#include "openvino/frontend/pytorch/visibility.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/gather_elements.hpp"
+#include "openvino/op/gather_nd.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/non_zero.hpp"
+#include "openvino/op/reduce_prod.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/slice.hpp"
+#include "openvino/op/split.hpp"
+#include "openvino/op/squeeze.hpp"
+#include "openvino/op/transpose.hpp"
+#include "openvino/op/util/framework_node.hpp"
+#include "openvino/pass/pattern/matcher.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+using namespace ov::op;
+namespace {
+
+// Reshapes `value` into a 2D tensor split at `axis`.
+// The first output dimension is the product of the input dimensions [d_0, ..., d_{axis-1}],
+// the second one is the product of the remaining dimensions [d_{axis}, ..., d_n].
+std::shared_ptr<Node> flatten(const Output<Node>& value, size_t axis) {
+    Output<Node> output_shape;
+    if (axis == 0) {
+        output_shape = v0::Constant::create(element::i64, Shape{2}, {1, -1});
+    } else if (axis == 1) {
+        // Reshape below is created with special_zero = true, so 0 keeps d_0 as is.
+        output_shape = v0::Constant::create(element::i64, Shape{2}, {0, -1});
+    } else {
+        const auto value_shape = std::make_shared<v3::ShapeOf>(value);
+        // Slice takes 1D start/stop/step inputs, so the bounds are created with Shape{1}.
+        const auto start = v0::Constant::create(element::i64, Shape{1}, {0});
+        const auto stop = v0::Constant::create(element::i64, Shape{1}, {axis});
+        const auto step = v0::Constant::create(element::i64, Shape{1}, {1});
+        const auto first_part_dims = std::make_shared<v8::Slice>(value_shape, start, stop, step);
+        const auto reduce_axis = v0::Constant::create(element::i64, Shape{}, {0});
+        // keep_dims = true leaves a 1-element tensor that can be concatenated with the -1 placeholder.
+        const auto first_part_dims_length = std::make_shared<v1::ReduceProd>(first_part_dims, reduce_axis, true);
+        const auto remaining_part_length = v0::Constant::create(element::i64, Shape{1}, {-1});
+
+        output_shape = std::make_shared<v0::Concat>(OutputVector{first_part_dims_length, remaining_part_length}, 0);
+    }
+    return std::make_shared<v1::Reshape>(value, output_shape, true);
+}
+}  // namespace
+
+// Registers a matcher that rewrites aten::index framework nodes into Gather / GatherND based subgraphs.
+// The index input is either a single tensor or a prim::ListConstruct holding one entry per indexed axis.
+AtenIndexToSelect::AtenIndexToSelect() {
+    auto index_op = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
+
+    ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
+        auto index_op = cast_fw_node(m.get_match_root(), "aten::index");
+        if (!index_op) {
+            return false;
+        }
+        auto input_node = index_op->input_value(0).get_node_shared_ptr();
+        auto indices = index_op->input_value(1).get_node_shared_ptr();
+        auto list_indices = cast_fw_node(indices, "prim::ListConstruct");
+        if (list_indices) {
+            // Multiple tensors as indices. Each tensor could either be
+            //   1. prim::Constant()
+            //      representing ":" in python indexing. E.g. tensor[:, :]
+            //   2. prim::Constant[value=...] or tensor output
+            //      representing advanced indexing. E.g. tensor[[0, 1], [2, 0]].
+            // For more info on advanced indexing,
+            // check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
+
+            // Consider a general case of
+            //   t: [x_1, y_1, y_2, ..., x_m, ..., y_n]
+            // where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+            // for ":". Same results can be achieved through transposing t into
+            //   t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n]
+            // and use gather
+            //   t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n]
+            //   tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j))
+            // After gather, reshape and transpose back.
+            const auto ids = list_indices->input_values();
+            // Axes that carry an actual (non-None) index.
+            std::vector<size_t> advanced_ids;
+            // Per axis: true when the index was a mask that has been replaced by NonZero coordinates.
+            std::vector<bool> is_masked_bool;
+            // Per axis: the index to use (None entries are kept as is, masks are replaced).
+            OutputVector masked_indices;
+            // for case when index is bool e.g. x[x>0], replace index with non_zero
+            for (size_t i = 0; i < ids.size(); i++) {
+                const auto const_input = cast_fw_node(ids[i].get_node_shared_ptr(), "prim::Constant");
+
+                // skip dimensions where index is None
+                if (const_input) {
+                    const auto& attrs = const_input->get_attrs();
+                    if (attrs.find("none_value") != attrs.end()) {
+                        masked_indices.push_back(ids[i]);
+                        is_masked_bool.push_back(false);
+                        continue;
+                    }
+                }
+                // Query the type of the exact output feeding the list, not of the producer node.
+                const auto id_dtype = ids[i].get_element_type();
+                if (id_dtype == element::boolean || id_dtype == element::u8) {
+                    auto idx = std::make_shared<v0::Convert>(ids[i], element::u8);
+                    auto nonzero = std::make_shared<v3::NonZero>(idx);
+                    // NonZero returns [mask_rank, N]; transpose to one coordinate tuple per row.
+                    auto input_order = v0::Constant::create(element::i64, Shape{2}, {1, 0});
+                    auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
+                    masked_indices.push_back(masked_id);
+                    is_masked_bool.push_back(true);
+                } else {
+                    masked_indices.push_back(ids[i]);
+                    is_masked_bool.push_back(false);
+                }
+                advanced_ids.push_back(i);
+            }
+
+            // all indices prim::Constant(None), return input as is
+            if (advanced_ids.empty()) {
+                copy_runtime_info({index_op, input_node}, input_node);
+                replace_node(index_op, input_node);
+                return true;
+            }
+            // perform gather for single element case
+            if (advanced_ids.size() == 1) {
+                Output<Node> index = masked_indices[advanced_ids[0]];
+                index = std::make_shared<v0::Convert>(index, element::i64);
+                if (is_masked_bool[advanced_ids[0]]) {
+                    auto gather = std::make_shared<v8::GatherND>(input_node, index);
+                    copy_runtime_info({index_op, input_node, indices}, gather);
+                    replace_node(index_op, gather);
+                    return true;
+                }
+                auto dim = v0::Constant::create(element::i64, Shape{}, {advanced_ids[0]});
+                auto gather = std::make_shared<v8::Gather>(input_node, index, dim);
+                copy_runtime_info({index_op, input_node, indices}, gather);
+                replace_node(index_op, gather);
+                return true;
+            }
+            const auto adv_idx_count = advanced_ids.size();
+            // Rank of the tensor being indexed, i.e. of input 0 of aten::index itself.
+            const auto rank = index_op->get_input_partial_shape(0).rank();
+            FRONT_END_CHECK_IMPLEMENTED(rank.is_static(),
+                                        "indexing for tensor with dynamic rank is not implemented ");
+            const auto rank_len = static_cast<size_t>(rank.get_length());
+            auto input_shape = std::make_shared<v3::ShapeOf>(input_node);
+            auto zero = v0::Constant::create(element::i64, Shape{}, {0});
+            // One 1-element output per input dimension.
+            auto input_dims = std::make_shared<v1::Split>(input_shape, zero, rank_len);
+            std::vector<size_t> non_used_dims;
+            for (size_t i = 0; i < rank_len; i++) {
+                if (std::find(advanced_ids.begin(), advanced_ids.end(), i) == advanced_ids.end()) {
+                    non_used_dims.push_back(i);
+                }
+            }
+            // Move indexed axes to the front, keep the remaining axes in their original order.
+            std::vector<size_t> permutation_dims;
+            permutation_dims.insert(permutation_dims.end(), advanced_ids.begin(), advanced_ids.end());
+            permutation_dims.insert(permutation_dims.end(), non_used_dims.begin(), non_used_dims.end());
+            auto transpose_dims =
+                v0::Constant::create(element::i64, Shape{permutation_dims.size()}, permutation_dims);
+            auto transposed_input = std::make_shared<v1::Transpose>(input_node, transpose_dims);
+            auto flatten_input = flatten(transposed_input, adv_idx_count);
+            // Accumulate the linear index from the last indexed axis to the first one:
+            // every index is scaled by the product of the sizes of all later indexed axes.
+            Output<Node> cum_adv_index = masked_indices[advanced_ids[adv_idx_count - 1]];
+            Output<Node> multiplier = input_dims->output(advanced_ids[adv_idx_count - 1]);
+            for (int i = static_cast<int>(adv_idx_count) - 2; i >= 0; i--) {
+                const auto input_id = advanced_ids[i];
+                auto adv_index = std::make_shared<v1::Multiply>(masked_indices[input_id], multiplier);
+                cum_adv_index = std::make_shared<v1::Add>(cum_adv_index, adv_index);
+                multiplier = std::make_shared<v1::Multiply>(multiplier, input_dims->output(input_id));
+            }
+            std::shared_ptr<Node> gather = std::make_shared<v8::Gather>(flatten_input, cum_adv_index, zero);
+            OutputVector concat_dims;
+            // check if all advanced indices are consecutive.
+            std::vector<size_t> consequence_dims;
+            auto cum_adv_index_shape_tensor = std::make_shared<v3::ShapeOf>(cum_adv_index);
+            for (size_t i = advanced_ids.front(); i <= advanced_ids.back(); i++) {
+                consequence_dims.push_back(i);
+            }
+            // unfold regular index axes
+            if (advanced_ids == consequence_dims) {
+                OutputVector folded_adv_idx_shape_vector;
+                auto minus_one = v0::Constant::create(element::i64, Shape{1}, {-1});
+                folded_adv_idx_shape_vector.push_back(minus_one);
+                for (auto i : non_used_dims) {
+                    folded_adv_idx_shape_vector.push_back(input_dims->output(i));
+                }
+                auto folded_adv_idx_shape = std::make_shared<v0::Concat>(folded_adv_idx_shape_vector, 0);
+                gather = std::make_shared<v1::Reshape>(gather, folded_adv_idx_shape, false);
+                std::vector<size_t> adv_idx_permute;
+                for (size_t i = 1; i < advanced_ids[0] + 1; i++) {
+                    adv_idx_permute.push_back(i);
+                }
+                adv_idx_permute.push_back(0);
+                for (size_t i = advanced_ids[0] + 1; i < (rank_len - adv_idx_count + 1); i++) {
+                    adv_idx_permute.push_back(i);
+                }
+                // Transpose folded advanced indexed axis to its original location.
+                auto permute_indices =
+                    v0::Constant::create(element::i64, Shape{adv_idx_permute.size()}, adv_idx_permute);
+                gather = std::make_shared<v1::Transpose>(gather, permute_indices);
+                // unfold advanced index axes:
+                // dims before the first indexed axis, then the index shape, then the remaining dims.
+                for (size_t i = 0; i < advanced_ids[0]; i++) {
+                    concat_dims.push_back(input_dims->output(i));
+                }
+                concat_dims.push_back(cum_adv_index_shape_tensor);
+                for (auto i : non_used_dims) {
+                    // Dims preceding the first indexed axis were already emitted above.
+                    if (i < advanced_ids[0]) {
+                        continue;
+                    }
+                    concat_dims.push_back(input_dims->output(i));
+                }
+            } else {
+                // Non-consecutive indexed axes: the index shape goes first, followed by the untouched dims.
+                concat_dims.push_back(cum_adv_index_shape_tensor);
+                for (auto i : non_used_dims) {
+                    concat_dims.push_back(input_dims->output(i));
+                }
+            }
+            auto final_shape = std::make_shared<v0::Concat>(concat_dims, 0);
+            gather = std::make_shared<v1::Reshape>(gather, final_shape, false);
+            copy_runtime_info({index_op, input_node, indices}, gather);
+            replace_node(index_op, gather);
+            return true;
+        }
+
+        // Index is a single value rather than a list.
+        auto const_input = cast_fw_node(indices, "prim::Constant");
+        if (const_input) {
+            // index is None, stay input as is
+            const auto& attrs = const_input->get_attrs();
+            if (attrs.find("none_value") != attrs.end()) {
+                copy_runtime_info({index_op, input_node, indices}, input_node);
+                replace_node(index_op, input_node);
+                return true;
+            }
+        }
+        auto index_dtype = indices->get_output_element_type(0);
+        if (index_dtype == element::boolean || index_dtype == element::u8) {
+            // Mask indexing: select the elements at the coordinates where the mask is non-zero.
+            auto nonzero = std::make_shared<v3::NonZero>(indices);
+            auto input_order = v0::Constant::create(element::i64, Shape{2}, {1, 0});
+            auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
+            auto gather = std::make_shared<v8::GatherND>(input_node, masked_id);
+            copy_runtime_info({index_op, input_node, indices}, gather);
+            replace_node(index_op, gather);
+            return true;
+        }
+        if (index_dtype != element::i32 && index_dtype != element::i64) {
+            indices = std::make_shared<v0::Convert>(indices, element::i64);
+        }
+        auto dim = v0::Constant::create(element::i64, Shape{}, {0});
+        auto gather = std::make_shared<v8::Gather>(input_node, indices, dim);
+        copy_runtime_info({index_op, input_node, indices}, gather);
+        replace_node(index_op, gather);
+        return true;
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(index_op,
+                                                          "ov::frontend::pytorch::pass::AtenIndexToSelect");
+    this->register_matcher(m, callback);
+}
+
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/transforms/aten_index_replacer.hpp b/src/frontends/pytorch/src/transforms/aten_index_replacer.hpp
new file mode 100644
index 00000000000000..84f6133253aea6
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/aten_index_replacer.hpp
@@ -0,0 +1,26 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/frontend/pytorch/visibility.hpp"
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+// Replaces aten::index (index given as a tensor or as prim::ListConstruct) with Gather-based subgraphs
+class PYTORCH_API AtenIndexToSelect : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::AtenIndexToSelect");
+    AtenIndexToSelect();
+};
+
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/tests/layer_tests/pytorch_tests/test_index.py b/tests/layer_tests/pytorch_tests/test_index.py
new file mode 100644
index 00000000000000..967ef4c98afb6e
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_index.py
@@ -0,0 +1,73 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import numpy as np
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestIndex(PytorchLayerTest):
+    def _prepare_input(self, input_shape, idx):
+        import numpy as np
+        return (np.random.randn(*input_shape).astype(np.float32), idx)
+
+    def create_model(self, model="list"):
+        import torch
+
+        class aten_index_list(torch.nn.Module):
+
+            def forward(self, x, idx):
+                return x[idx]
+
+        class aten_index_getitem(torch.nn.Module):
+
+            def forward(self, x, idx):
+                return x.__getitem__(idx)
+
+
+        class aten_index_list_bool(torch.nn.Module):
+
+            def forward(self, x, idx):
+                return x[idx.to(torch.bool)]
+
+        class aten_index_getitem_bool(torch.nn.Module):
+
+            def forward(self, x, idx):
+                return x.__getitem__(idx.to(torch.bool))
+        cases = {
+            "list": aten_index_list,
+            "getitem": aten_index_getitem,
+            "list_with_bool": aten_index_list_bool,
+            "getitem_with_bool": aten_index_getitem_bool
+        }
+
+        aten_index = cases[model]
+
+        ref_net = None
+
+        return aten_index(), ref_net, "aten::index"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize("case", ["list", "getitem"])
+    @pytest.mark.parametrize(("input_shape", "idx"), [
+        ((1,), np.array(0).astype(int)),
+        ([2, 3], np.array(-1).astype(int)),
+        ([4, 5, 6], np.array((1, 2)).astype(int)),
+        ([7, 8, 9], np.array((-1, 2, -3)).astype(int)),
+        ([2, 2, 3, 4], np.array((1,)).astype(int))])
+    def test_index(self, input_shape, idx, case, ie_device, precision, ir_version):
+        self._test(*self.create_model(case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize("case", ["getitem_with_bool", "list_with_bool"])
+    @pytest.mark.parametrize(("input_shape", "idx"), [
+        ((1, 2), np.array([[1, 0]]).astype(bool)),
+        ((2, 2, 5), np.zeros([2, 2, 5]).astype(bool)),
+        ((2, 2, 5), np.ones([2, 2, 5]).astype(bool)),
+        ((2, 2, 5), np.random.rand(2, 2, 5) > 0)
+    ])
+    def test_index_bool(self, input_shape, idx, case, ie_device, precision, ir_version):
+        self._test(*self.create_model(case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
\ No newline at end of file