Skip to content

Commit

Permalink
[TF FE] MatrixBandPart operation for TensorFlow Hub models (#23082)
Browse files Browse the repository at this point in the history
**Details:** `MatrixBandPart` is needed to support Keras StableDiffusion
model. This is reserved PR for
#22447

**Ticket:** CVS-133786

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
Co-authored-by: himanshugupta11002 <[email protected]>
  • Loading branch information
rkazants and himanshugupta11002 authored Feb 26, 2024
1 parent a5f6308 commit 089fc0d
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/frontends/tensorflow/docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| MatMul | YES | |
| MatchingFiles | NO | |
| MatchingFilesDataset | NO | |
| MatrixBandPart | NO | |
| MatrixBandPart | YES | |
| MatrixDeterminant | NO | |
| MatrixDiag | YES | |
| MatrixDiagPart | NO | |
Expand Down
1 change: 1 addition & 0 deletions src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"LookupTableInsertV2", CreatorFunction(translate_no_op)},
{"LRN", CreatorFunction(translate_lrn_op)},
{"MatMul", CreatorFunction(translate_mat_mul_op)},
{"MatrixBandPart", CreatorFunction(translate_matrix_band_part_op)},
{"MatrixDiag", CreatorFunction(translate_matrix_diag_op)},
{"MaxPool", CreatorFunction(translate_max_pool_op)},
{"MaxPoolV2", CreatorFunction(translate_max_pool_op)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ OP_CONVERTER(translate_log_1p_op);
OP_CONVERTER(translate_lrn_op);
OP_CONVERTER(translate_mat_mul_op);
OP_CONVERTER(translate_matrix_diag_op);
OP_CONVERTER(translate_matrix_band_part_op);
OP_CONVERTER(translate_max_pool_op);
OP_CONVERTER_NAMED(translate_max_pool_with_argmax);
OP_CONVERTER(translate_mirror_pad_op);
Expand Down
90 changes: 90 additions & 0 deletions src/frontends/tensorflow_common/src/op/matrix_band_part.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/less_eq.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"

using namespace std;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

OutputVector translate_matrix_band_part_op(const NodeContext& node) {
default_op_checks(node, 3, {"MatrixBandPart"});

// Input tensor and parameters
auto input = node.get_input(0);
auto num_lower = node.get_input(1);
auto num_upper = node.get_input(2);

// create scalar auxiliary constants
auto const_zero = make_shared<v0::Constant>(element::i64, Shape{}, 0);
auto const_one = make_shared<v0::Constant>(element::i64, Shape{}, 1);
auto const_two = make_shared<v0::Constant>(element::i64, Shape{}, 2);

// input has a shape [I, J, K, ..., M, N]
// compute sizes of two last dimensions of M and N
auto input_shape = make_shared<v3::ShapeOf>(input, element::i64);
auto input_rank = make_shared<v3::ShapeOf>(input_shape, element::i64);
auto input_rank_minus_one = make_shared<v1::Subtract>(input_rank, const_one);
auto input_rank_minus_two = make_shared<v1::Subtract>(input_rank, const_two);
auto slice_step = make_shared<v0::Constant>(element::i64, Shape{1}, 1);
auto slice_axis = make_shared<v0::Constant>(element::i64, Shape{1}, 0);
auto m = make_shared<v8::Slice>(input_shape, input_rank_minus_two, input_rank_minus_one, slice_step, slice_axis)
->output(0);
auto n = make_shared<v8::Slice>(input_shape, input_rank_minus_one, input_rank, slice_step, slice_axis)->output(0);

// generate ranges [0, M) and [0, N)
auto scalar_shape = make_shared<v0::Constant>(element::i64, Shape{0}, vector<int64_t>{});
m = make_shared<v1::Reshape>(m, scalar_shape, false);
n = make_shared<v1::Reshape>(n, scalar_shape, false);
auto range_m = make_shared<v4::Range>(const_zero, m, const_one, element::i64)->output(0);
auto range_n = make_shared<v4::Range>(const_zero, n, const_one, element::i64)->output(0);
range_m = make_shared<v0::Unsqueeze>(range_m, const_one);
range_n = make_shared<v0::Unsqueeze>(range_n, const_zero);

// adjust num_lower and num_upper to have them of type i64
// the same as M and N
// it is needed for in_band computation
num_lower = make_shared<v0::Convert>(num_lower, element::i64);
num_upper = make_shared<v0::Convert>(num_upper, element::i64);

// compute in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && (num_upper < 0 || (n-m) <= num_upper)
auto num_lower_less_zero = make_shared<v1::Less>(num_lower, const_zero);
auto i_minus_j = make_shared<v1::Subtract>(range_m, range_n);
auto i_minus_j_less_eq_num_lower = make_shared<v1::LessEqual>(i_minus_j, num_lower);
auto num_upper_less_zero = make_shared<v1::Less>(num_upper, const_zero);
auto j_minus_i = make_shared<v1::Subtract>(range_n, range_m);
auto j_minus_i_less_eq_num_upper = make_shared<v1::LessEqual>(j_minus_i, num_upper);
auto in_band1 = make_shared<v1::LogicalOr>(num_lower_less_zero, i_minus_j_less_eq_num_lower);
auto in_band2 = make_shared<v1::LogicalOr>(num_upper_less_zero, j_minus_i_less_eq_num_upper);
auto in_band = make_shared<v1::LogicalAnd>(in_band1, in_band2);

// create zero constant of the same type as input
auto zero = create_same_type_const_scalar<int32_t>(input, 0);

auto result = make_shared<v1::Select>(in_band, input, zero);

set_node_name(node.get_name(), result);
return {result};
}

} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
41 changes: 41 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_MatrixBandPart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest

rng = np.random.default_rng()


class TestMatrixBandPart(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'input:0' in inputs_info
input_shape = inputs_info['input:0']
inputs_data = {}
inputs_data['input:0'] = rng.integers(-8, 8, input_shape).astype(self.input_type)
return inputs_data

def create_matrix_band_part_net(self, input_shape, input_type, num_lower, num_upper):
self.input_type = input_type
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
input_tensor = tf.compat.v1.placeholder(input_type, input_shape, 'input')
tf.raw_ops.MatrixBandPart(input=input_tensor, num_lower=num_lower, num_upper=num_upper)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None

@pytest.mark.parametrize('input_shape', [[5, 5], [3, 4, 4], [1, 2, 5, 5], [3, 5, 4]])
@pytest.mark.parametrize('input_type', [np.float32, np.int32])
@pytest.mark.parametrize('num_lower', [-4, -1, 0, 1, 4])
@pytest.mark.parametrize('num_upper', [-4, -1, 0, 1, 4])
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_matrix_band_part_basic(self, input_shape, input_type, num_lower, num_upper,
ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
self._test(*self.create_matrix_band_part_net(input_shape, input_type, num_lower, num_upper),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

0 comments on commit 089fc0d

Please sign in to comment.