Skip to content

Commit

Permalink
[TF FE] Support SparseTensorDenseMatMul operation (#26064)
Browse files Browse the repository at this point in the history
**Details:** Support SparseTensorDenseMatMul operation. Required for
customer model.

**Ticket:** 104539

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Aug 14, 2024
1 parent eda21ba commit aa4455a
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/frontends/tensorflow/docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| SparseSparseMinimum | NO | |
| SparseSplit | NO | |
| SparseTensorDenseAdd | NO | |
| SparseTensorDenseMatMul | NO | |
| SparseTensorDenseMatMul | YES | |
| SparseTensorSliceDataset | NO | |
| SparseTensorToCSRSparseMatrix | NO | |
| SparseToDense | YES | |
Expand Down
3 changes: 2 additions & 1 deletion src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"SaveV2", CreatorFunction(translate_no_op)},
{"ScatterNd", CreatorFunction(translate_scatter_nd_op)},
{"SegmentSum", CreatorFunction(translate_segment_sum_op)},
{"SparseToDense", CreatorFunction(translate_sparse_to_dense_op)},
{"Select", CreatorFunction(translate_select_op)},
{"SelectV2", CreatorFunction(translate_select_v2_op)},
{"Shape", CreatorFunction(translate_shape_op)},
Expand All @@ -381,6 +380,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Softmax", CreatorFunction(translate_softmax_op)},
{"SpaceToDepth", CreatorFunction(translate_space_to_depth_op)},
{"SparseReshape", CreatorFunction(translate_sparse_reshape_op)},
{"SparseTensorDenseMatMul", CreatorFunction(translate_sparse_tensor_dense_mat_mul_op)},
{"SparseToDense", CreatorFunction(translate_sparse_to_dense_op)},
{"Split", CreatorFunction(translate_split_op)},
{"SplitV", CreatorFunction(translate_split_v_op)},
{"StopGradient", CreatorFunction(translate_identity_op)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ OP_CONVERTER(translate_rsqrt_op);
OP_CONVERTER(translate_scatter_nd_op);
OP_CONVERTER(translate_segment_sum_op);
OP_CONVERTER(translate_space_to_batch_nd_op);
OP_CONVERTER(translate_sparse_tensor_dense_mat_mul_op);
OP_CONVERTER(translate_sparse_to_dense_op);
OP_CONVERTER(translate_select_op);
OP_CONVERTER(translate_select_v2_op);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "common_op_table.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "utils.hpp"

using namespace std;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_sparse_tensor_dense_mat_mul_op(const NodeContext& node) {
default_op_checks(node, 4, {"SparseTensorDenseMatMul"});
auto a_indices = node.get_input(0);
auto a_values = node.get_input(1);
auto a_shape = node.get_input(2);
auto b = node.get_input(3);
auto adjoint_a = node.get_attribute<bool>("adjoint_a", false);
auto adjoint_b = node.get_attribute<bool>("adjoint_b", false);

// create dense tensor
auto zero_const = create_same_type_const_scalar<int32_t>(a_values, 0);
ov::Output<ov::Node> a = make_shared<v3::Broadcast>(zero_const, a_shape);
a = make_shared<v15::ScatterNDUpdate>(a, a_indices, a_values);
auto res = make_shared<v0::MatMul>(a, b, adjoint_a, adjoint_b);
set_node_name(node.get_name(), res);
return {res};
}

} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
100 changes: 100 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_SparseTensorDenseMatMul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import platform
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest

rng = np.random.default_rng(475912)


class TestSparseTensorDenseMatMul(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'a_indices:0' in inputs_info
assert 'a_values:0' in inputs_info
assert 'b:0' in inputs_info

a_values_shape = inputs_info['a_values:0']
b_shape = inputs_info['b:0']

inputs_data = {}
if np.issubdtype(self.data_type, np.floating):
inputs_data['a_values:0'] = rng.uniform(-5.0, 5.0, a_values_shape).astype(self.data_type)
inputs_data['b:0'] = rng.uniform(-5.0, 5.0, b_shape).astype(self.data_type)
elif np.issubdtype(self.data_type, np.signedinteger):
inputs_data['a_values:0'] = rng.integers(-8, 8, a_values_shape).astype(self.data_type)
inputs_data['b:0'] = rng.integers(-8, 8, b_shape).astype(self.data_type)
else:
inputs_data['a_values:0'] = rng.integers(0, 8, a_values_shape).astype(self.data_type)
inputs_data['b:0'] = rng.integers(0, 8, b_shape).astype(self.data_type)

a_rows_num = self.a_shape[0]
a_cols_num = self.a_shape[1]

# generate all possible indices
all_indices = []
for row_ind in range(0, a_rows_num):
for col_ind in range(0, a_cols_num):
all_indices.append([row_ind, col_ind])
inputs_data['a_indices:0'] = rng.choice(all_indices, self.nnz, replace=False).astype(self.indices_type)

return inputs_data

def create_sparse_tensor_dense_mat_mul_net(self, data_type, indices_type,
adjoint_a, adjoint_b,
a_shape, b_shape, nnz):
a_shape = a_shape.copy()
b_shape = b_shape.copy()
if adjoint_a:
a_shape.reverse()
if adjoint_b:
b_shape.reverse()

self.data_type = data_type
self.indices_type = indices_type
self.a_shape = a_shape
self.nnz = nnz
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
a_indices = tf.compat.v1.placeholder(indices_type, [nnz, 2], 'a_indices')
a_values = tf.compat.v1.placeholder(data_type, [nnz], 'a_values')
a_shape = tf.constant(a_shape, dtype=tf.int64)
b = tf.compat.v1.placeholder(data_type, b_shape, 'b')
tf.raw_ops.SparseTensorDenseMatMul(
a_indices=a_indices,
a_values=a_values,
a_shape=a_shape,
b=b,
adjoint_a=adjoint_a,
adjoint_b=adjoint_b)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

ref_net = None

return tf_net, ref_net

@pytest.mark.parametrize('data_type', [np.float32, np.float64, np.int32])
@pytest.mark.parametrize('indices_type', [np.int32, np.int64])
@pytest.mark.parametrize('adjoint_a', [True, False])
@pytest.mark.parametrize('adjoint_b', [True, False])
@pytest.mark.parametrize('a_shape, b_shape, nnz', [
[[4, 10], [10, 5], 8],
[[5, 5], [5, 5], 3],
])
@pytest.mark.precommit
@pytest.mark.nightly
def test_sparse_tensor_dense_mat_mul(self, data_type, indices_type,
adjoint_a, adjoint_b,
a_shape, b_shape, nnz,
ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
if ie_device == 'GPU':
pytest.skip("149830: ScatterNDUpdate-15 is not supported on GPU")
self._test(*self.create_sparse_tensor_dense_mat_mul_net(data_type, indices_type,
adjoint_a, adjoint_b,
a_shape, b_shape, nnz),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

0 comments on commit aa4455a

Please sign in to comment.