[TF FE] Refactor RandomUniform support and provide more test coverage (#14847)

Signed-off-by: Kazantsev, Roman <[email protected]>

rkazants authored Dec 29, 2022
1 parent 1ef17c5 commit 36a16c8
Showing 2 changed files with 50 additions and 30 deletions.
30 changes: 20 additions & 10 deletions src/frontends/tensorflow/src/op/random_uniform.cpp
@@ -13,27 +13,37 @@ namespace frontend {
 namespace tensorflow {
 namespace op {
 ov::OutputVector translate_random_uniform_op(const NodeContext& node) {
+    default_op_checks(node, 1, {"RandomUniform"});
     auto shape = node.get_input(0);
+
+    // retrieve attributes
     auto seed = node.get_attribute<int64_t>("seed", 0);
     auto seed2 = node.get_attribute<int64_t>("seed2", 0);
-    auto minval_const = make_shared<Constant>(element::f32, Shape{}, 0);
-    auto maxval_const = make_shared<Constant>(element::f32, Shape{}, 1);
-    auto ng_et = node.get_attribute<ov::element::Type>("dtype");
-    auto res = std::make_shared<RandomUniform>(shape, minval_const, maxval_const, ng_et, seed, seed2);
-    set_node_name(node.get_name(), res);
-    return res->outputs();
+    auto output_type = node.get_attribute<ov::element::Type>("dtype");
+
+    auto minval = make_shared<Constant>(output_type, Shape{}, 0);
+    auto maxval = make_shared<Constant>(output_type, Shape{}, 1);
+    auto random = std::make_shared<RandomUniform>(shape, minval, maxval, output_type, seed, seed2);
+
+    set_node_name(node.get_name(), random);
+    return random->outputs();
 }
 
 ov::OutputVector translate_random_uniform_int_op(const NodeContext& node) {
+    default_op_checks(node, 3, {"RandomUniformInt"});
     auto shape = node.get_input(0);
     auto minval = node.get_input(1);
     auto maxval = node.get_input(2);
+
+    // retrieve attributes
     auto seed = node.get_attribute<int64_t>("seed", 0);
     auto seed2 = node.get_attribute<int64_t>("seed2", 0);
-    auto ng_et = minval.get_element_type();
-    auto res = std::make_shared<RandomUniform>(shape, minval, maxval, ng_et, seed, seed2);
-    set_node_name(node.get_name(), res);
-    return res->outputs();
+
+    auto output_type = minval.get_element_type();
+    auto random = std::make_shared<RandomUniform>(shape, minval, maxval, output_type, seed, seed2);
+
+    set_node_name(node.get_name(), random);
+    return random->outputs();
 }
 }  // namespace op
 }  // namespace tensorflow
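The functional change in this file: the min/max constants for RandomUniform are now created with the requested dtype instead of hardcoded f32, and both translators validate their inputs via default_op_checks. For orientation, a minimal sketch of the two TensorFlow ops these translators cover, written against tf.raw_ops with shapes and seed values borrowed from the test data below:

    import tensorflow as tf

    # RandomUniform: the shape is an input; dtype, seed and seed2 are attributes.
    # The sampling range is fixed to [0, 1), which is why the translator
    # synthesizes the 0 and 1 constants itself.
    x = tf.raw_ops.RandomUniform(shape=[3, 7], dtype=tf.float32, seed=32465, seed2=48971)

    # RandomUniformInt: minval and maxval arrive as inputs (here -200 and -50),
    # so the translator forwards them and takes the output type from minval.
    y = tf.raw_ops.RandomUniformInt(shape=[5, 8], minval=-200, maxval=-50, seed=78132, seed2=0)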
50 changes: 30 additions & 20 deletions tests/layer_tests/tensorflow_tests/test_tf_RandomUniform.py
@@ -5,38 +5,32 @@
 import tensorflow as tf
 from common.layer_test_class import check_ir_version
 from common.tf_layer_test_class import CommonTFLayerTest
-from common.utils.tf_utils import permute_nchw_to_nhwc
-from openvino.tools.mo.front.common.partial_infer.utils import int64_array
+
+from openvino.tools.mo.front.common.partial_infer.utils import int64_array
 from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, connect, \
     shaped_data, connect_front
 
 
-class TestTFRandomUniform(CommonTFLayerTest):
+class TestRandomUniform(CommonTFLayerTest):
     def create_tf_random_uniform_net(self, global_seed, op_seed, x_shape, min_val, max_val,
                                      input_type, precision,
                                      ir_version, use_new_frontend):
         tf.compat.v1.reset_default_graph()
 
         # Create the graph and model
         with tf.compat.v1.Session() as sess:
-            tf_x_shape = x_shape.copy()
-
-            tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend)
-
-            x = tf.compat.v1.placeholder(input_type, tf_x_shape, 'Input')
+            x = tf.compat.v1.placeholder(input_type, x_shape, 'Input')
             if global_seed is not None:
                 tf.compat.v1.random.set_random_seed(global_seed)
-            random_uniform = tf.random.uniform(tf_x_shape, seed=op_seed, dtype=input_type,
-                                               minval=min_val,
-                                               maxval=max_val) + x
+            tf.random.uniform(x_shape, seed=op_seed, dtype=input_type,
+                              minval=min_val,
+                              maxval=max_val) + x
 
             tf.compat.v1.global_variables_initializer()
             tf_net = sess.graph_def
 
             ref_net = None
-            if check_ir_version(10, None, ir_version) and not use_new_frontend:
-
+            if check_ir_version(10, None, ir_version):
                 const_for_layer_tests = lambda name, value, shape, shape1: {
                     **{name + '_dd': {'kind': 'data', 'value': value, 'shape': shape1}},
                     **{name: {'kind': 'op', 'type': 'Const'}},
@@ -83,25 +77,41 @@ def create_tf_random_uniform_net(self, global_seed, op_seed, x_shape, min_val, m
 
         return tf_net, ref_net
 
-    test_data = [pytest.param(
+    test_data_basic = [
         dict(global_seed=32465, op_seed=48971, min_val=0.0, max_val=1.0, x_shape=[3, 7],
              input_type=tf.float32),
-        marks=pytest.mark.precommit),
+        dict(global_seed=78132, op_seed=None, min_val=-200, max_val=-50, x_shape=[5, 8],
+             input_type=tf.int32)
+    ]
+
+    @pytest.mark.parametrize("params", test_data_basic)
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.precommit_tf_fe
+    def test_random_uniform_basic(self, params, ie_device, precision, ir_version, temp_dir,
+                                  use_new_frontend, use_old_api):
+        if ie_device == 'GPU':
+            pytest.skip("RandomUniform is not supported on GPU")
+        self._test(
+            *self.create_tf_random_uniform_net(**params, precision=precision, ir_version=ir_version,
+                                               use_new_frontend=use_new_frontend), ie_device,
+            precision, temp_dir=temp_dir, ir_version=ir_version, use_new_frontend=use_new_frontend,
+            use_old_api=use_old_api, **params)
+
+    test_data_other = [
         dict(global_seed=None, op_seed=56197, min_val=-100, max_val=100, x_shape=[6],
              input_type=tf.float32),
         dict(global_seed=None, op_seed=56197, min_val=-100, max_val=100, x_shape=[1, 2, 1, 1],
              input_type=tf.float32),
-        pytest.param(dict(global_seed=78132, op_seed=None, min_val=-200, max_val=-50, x_shape=[5, 8],
-                          input_type=tf.int32), marks=pytest.mark.precommit_tf_fe),
         dict(global_seed=4571, op_seed=48971, min_val=1.5, max_val=2.3, x_shape=[7],
              input_type=tf.float32),
         dict(global_seed=32465, op_seed=12335, min_val=-150, max_val=-100, x_shape=[18],
              input_type=tf.int32)]
 
-    @pytest.mark.parametrize("params", test_data)
+    @pytest.mark.parametrize("params", test_data_other)
     @pytest.mark.nightly
-    def test_tf_random_uniform(self, params, ie_device, precision, ir_version, temp_dir,
-                               use_new_frontend, use_old_api):
+    def test_random_uniform_other(self, params, ie_device, precision, ir_version, temp_dir,
+                                  use_new_frontend, use_old_api):
         if ie_device == 'GPU':
             pytest.skip("RandomUniform is not supported on GPU")
         self._test(
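A note on running the new tests: the layer tests are selected by pytest markers, so an invocation along these lines should pick up the precommit TF-FE subset (a sketch only; the exact working directory and environment setup follow the OpenVINO layer-tests instructions):

    pytest tests/layer_tests/tensorflow_tests/test_tf_RandomUniform.py -m precommit_tf_fe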