Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF FE] MatrixBandPart operation for TensorFlow Hub models #23082

Merged
merged 32 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
8165924
Create MatrixBandPart.cpp
himanshugupta11002 Jan 15, 2024
d1abc7d
Delete src/frontends/tensorflow_common/src/op/MatrixBandPart.cpp
himanshugupta11002 Jan 17, 2024
a0af6d6
Create matrix_band_part.cpp
himanshugupta11002 Jan 17, 2024
c76cdbf
Update op_table.cpp
himanshugupta11002 Jan 17, 2024
d3b0f9c
Update matrix_band_part.cpp
himanshugupta11002 Jan 19, 2024
57a7629
Update matrix_band_part.cpp
himanshugupta11002 Jan 23, 2024
1efcaaa
Create test_tf_MatrixBandPart.py
himanshugupta11002 Jan 23, 2024
dd1dbd1
Merge branch 'master' into master
himanshugupta11002 Feb 1, 2024
74300ea
Update supported_ops.md
himanshugupta11002 Feb 1, 2024
c4c7e2e
Merge branch 'master' into master
himanshugupta11002 Feb 1, 2024
4507c8a
add declaration
himanshugupta11002 Feb 11, 2024
2824540
Update common_op_table.hpp
himanshugupta11002 Feb 11, 2024
24d3ed9
Apply suggestions from code review
rkazants Feb 12, 2024
184b0a5
Update matrix_band_part.cpp
himanshugupta11002 Feb 12, 2024
97d62ab
added in_band(m,n) conditions
himanshugupta11002 Feb 13, 2024
95b62b6
Update src/frontends/tensorflow_common/src/op/matrix_band_part.cpp
rkazants Feb 14, 2024
08d4983
Update matrix_band_part.cpp
himanshugupta11002 Feb 15, 2024
bea6081
update constant part
himanshugupta11002 Feb 18, 2024
9f2881b
created sepeerate variable for zero
himanshugupta11002 Feb 23, 2024
bcd97c2
Merge branch 'master' into master
himanshugupta11002 Feb 25, 2024
17dc0db
Restore test models
rkazants Feb 25, 2024
ec23766
Fix translator
rkazants Feb 26, 2024
a12584d
Merge remote-tracking branch 'upstream/master' into himanshugupta1100…
rkazants Feb 26, 2024
96d5fe3
Merge branch 'master' into himanshugupta11002_master
rkazants Feb 26, 2024
1c16f90
Update tests/layer_tests/tensorflow_tests/test_tf_MatrixBandPart.py
rkazants Feb 26, 2024
d96f515
Update tests/layer_tests/tensorflow_tests/test_tf_MatrixBandPart.py
rkazants Feb 26, 2024
d4c0f2b
Merge branch 'master' into himanshugupta11002_master
rkazants Feb 26, 2024
aab0a39
Update tests/layer_tests/tensorflow_tests/test_tf_MatrixBandPart.py
rkazants Feb 26, 2024
152ba00
Update tests/layer_tests/tensorflow_tests/test_tf_MatrixBandPart.py
rkazants Feb 26, 2024
09feb1d
Update tests/layer_tests/tensorflow_tests/test_tf_MatrixBandPart.py
rkazants Feb 26, 2024
7a7d224
Merge remote-tracking branch 'rkazants/himanshugupta11002_master' int…
rkazants Feb 26, 2024
01b0c43
Extend layer tests for MatrixBandPart
rkazants Feb 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/frontends/tensorflow/docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| MatMul | YES | |
| MatchingFiles | NO | |
| MatchingFilesDataset | NO | |
| MatrixBandPart | NO | |
| MatrixBandPart | YES | |
| MatrixDeterminant | NO | |
| MatrixDiag | YES | |
| MatrixDiagPart | NO | |
Expand Down
1 change: 1 addition & 0 deletions src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"LookupTableInsertV2", CreatorFunction(translate_no_op)},
{"LRN", CreatorFunction(translate_lrn_op)},
{"MatMul", CreatorFunction(translate_mat_mul_op)},
{"MatrixBandPart", CreatorFunction(translate_matrix_band_part_op)},
{"MatrixDiag", CreatorFunction(translate_matrix_diag_op)},
{"MaxPool", CreatorFunction(translate_max_pool_op)},
{"MaxPoolV2", CreatorFunction(translate_max_pool_op)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ OP_CONVERTER(translate_log_1p_op);
OP_CONVERTER(translate_lrn_op);
OP_CONVERTER(translate_mat_mul_op);
OP_CONVERTER(translate_matrix_diag_op);
OP_CONVERTER(translate_matrix_band_part_op);
OP_CONVERTER(translate_max_pool_op);
OP_CONVERTER_NAMED(translate_max_pool_with_argmax);
OP_CONVERTER(translate_mirror_pad_op);
Expand Down
90 changes: 90 additions & 0 deletions src/frontends/tensorflow_common/src/op/matrix_band_part.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/less_eq.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"

using namespace std;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

OutputVector translate_matrix_band_part_op(const NodeContext& node) {
default_op_checks(node, 3, {"MatrixBandPart"});

// Input tensor and parameters
auto input = node.get_input(0);
auto num_lower = node.get_input(1);
auto num_upper = node.get_input(2);

// create scalar auxiliary constants
auto const_zero = make_shared<v0::Constant>(element::i64, Shape{}, 0);
auto const_one = make_shared<v0::Constant>(element::i64, Shape{}, 1);
auto const_two = make_shared<v0::Constant>(element::i64, Shape{}, 2);

// input has a shape [I, J, K, ..., M, N]
// compute sizes of two last dimensions of M and N
auto input_shape = make_shared<v3::ShapeOf>(input, element::i64);
auto input_rank = make_shared<v3::ShapeOf>(input_shape, element::i64);
auto input_rank_minus_one = make_shared<v1::Subtract>(input_rank, const_one);
auto input_rank_minus_two = make_shared<v1::Subtract>(input_rank, const_two);
auto slice_step = make_shared<v0::Constant>(element::i64, Shape{1}, 1);
auto slice_axis = make_shared<v0::Constant>(element::i64, Shape{1}, 0);
auto m = make_shared<v8::Slice>(input_shape, input_rank_minus_two, input_rank_minus_one, slice_step, slice_axis)
->output(0);
auto n = make_shared<v8::Slice>(input_shape, input_rank_minus_one, input_rank, slice_step, slice_axis)->output(0);

// generate ranges [0, M) and [0, N)
auto scalar_shape = make_shared<v0::Constant>(element::i64, Shape{0}, vector<int64_t>{});
m = make_shared<v1::Reshape>(m, scalar_shape, false);
n = make_shared<v1::Reshape>(n, scalar_shape, false);
auto range_m = make_shared<v4::Range>(const_zero, m, const_one, element::i64)->output(0);
auto range_n = make_shared<v4::Range>(const_zero, n, const_one, element::i64)->output(0);
range_m = make_shared<v0::Unsqueeze>(range_m, const_one);
range_n = make_shared<v0::Unsqueeze>(range_n, const_zero);

// adjust num_lower and num_upper to have them of type i64
// the same as M and N
// it is needed for in_band computation
num_lower = make_shared<v0::Convert>(num_lower, element::i64);
num_upper = make_shared<v0::Convert>(num_upper, element::i64);

// compute in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && (num_upper < 0 || (n-m) <= num_upper)
auto num_lower_less_zero = make_shared<v1::Less>(num_lower, const_zero);
auto i_minus_j = make_shared<v1::Subtract>(range_m, range_n);
auto i_minus_j_less_eq_num_lower = make_shared<v1::LessEqual>(i_minus_j, num_lower);
auto num_upper_less_zero = make_shared<v1::Less>(num_upper, const_zero);
auto j_minus_i = make_shared<v1::Subtract>(range_n, range_m);
auto j_minus_i_less_eq_num_upper = make_shared<v1::LessEqual>(j_minus_i, num_upper);
auto in_band1 = make_shared<v1::LogicalOr>(num_lower_less_zero, i_minus_j_less_eq_num_lower);
auto in_band2 = make_shared<v1::LogicalOr>(num_upper_less_zero, j_minus_i_less_eq_num_upper);
auto in_band = make_shared<v1::LogicalAnd>(in_band1, in_band2);

// create zero constant of the same type as input
auto zero = create_same_type_const_scalar<int32_t>(input, 0);

auto result = make_shared<v1::Select>(in_band, input, zero);

set_node_name(node.get_name(), result);
return {result};
}

} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
41 changes: 41 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_MatrixBandPart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest

rng = np.random.default_rng()


class TestMatrixBandPart(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'input:0' in inputs_info
input_shape = inputs_info['input:0']
inputs_data = {}
inputs_data['input:0'] = rng.integers(-8, 8, input_shape).astype(self.input_type)
return inputs_data

def create_matrix_band_part_net(self, input_shape, input_type, num_lower, num_upper):
self.input_type = input_type
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
input_tensor = tf.compat.v1.placeholder(input_type, input_shape, 'input')
tf.raw_ops.MatrixBandPart(input=input_tensor, num_lower=num_lower, num_upper=num_upper)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None

@pytest.mark.parametrize('input_shape', [[5, 5], [3, 4, 4], [1, 2, 5, 5], [3, 5, 4]])
@pytest.mark.parametrize('input_type', [np.float32, np.int32])
@pytest.mark.parametrize('num_lower', [-4, -1, 0, 1, 4])
@pytest.mark.parametrize('num_upper', [-4, -1, 0, 1, 4])
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_matrix_band_part_basic(self, input_shape, input_type, num_lower, num_upper,
ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
self._test(*self.create_matrix_band_part_net(input_shape, input_type, num_lower, num_upper),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)
Loading