From dafcf321f1b9e7e12f868b460feb019c1e9ae837 Mon Sep 17 00:00:00 2001 From: Balaram Makam Date: Tue, 31 Jan 2023 15:38:24 -0800 Subject: [PATCH] [Frontend][Tensorflow] Update Select to SelectV2. Fixes #13855 --- python/tvm/relay/frontend/tensorflow_ops.py | 22 ++++++++++++++++-- .../frontend/tensorflow/test_forward.py | 23 +++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index e9bb15e1d1c6c..ab773f9a2a8a1 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -2380,6 +2380,24 @@ def _impl(inputs, attr, params, mod): def _where(): + def _impl(inputs, attr, params, mod): + if len(inputs) == 1: + return AttrCvt(op_name="argwhere")(inputs, attr) + cond_shape = _infer_shape(inputs[0], mod) + x_shape = _infer_shape(inputs[1], mod) + # Due to difference in broadcast behavior between Select and SelectV2, + # we adjust condition dimension with expand_dim and then broadcast. + if len(cond_shape) == 1 and cond_shape[0] == x_shape[0]: + for _ in range(len(x_shape) - 1): + inputs[0] = _op.expand_dims(inputs[0], axis=-1) + broadcast_cond = _op.broadcast_to(inputs[0], x_shape) + inputs[0] = _op.cast(broadcast_cond, "bool") + return AttrCvt(op_name="where")(inputs, attr) + + return _impl + + +def _where_v2(): def _impl(inputs, attr, params, mod): if len(inputs) == 1: return AttrCvt(op_name="argwhere")(inputs, attr) @@ -3088,7 +3106,7 @@ def _impl(inputs, attr, params, mod): "Round": AttrCvt("round"), "Rsqrt": _rsqrt(), "Select": _where(), - "SelectV2": _where(), + "SelectV2": _where_v2(), "Selu": _selu(), "Shape": _shape(), "Sigmoid": AttrCvt("sigmoid"), @@ -3142,6 +3160,6 @@ def _impl(inputs, attr, params, mod): "UniqueWithCounts": _unique(True), "Unpack": _unpack(), "UnravelIndex": _unravel_index(), - "Where": _where(), + "Where": _where_v2(), "ZerosLike": AttrCvt("zeros_like"), } diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 2fb7c74f60a1e..1e1bd435d51fc 100755 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1234,6 +1234,29 @@ def test_forward_argwhere(): _test_forward_where((5, 5, 5, 5, 5)) +def _test_forward_where_with_broadcast(in_shape, cond_shape): + choice_list = list(np.arange(10).astype("float32")) + t1 = np.random.choice(choice_list, size=cond_shape) + t2 = np.random.choice(choice_list, size=cond_shape) + x = np.random.choice(choice_list, size=in_shape) + y = np.random.choice(choice_list, size=in_shape) + + with tf.Graph().as_default(): + in1 = tf.placeholder(shape=cond_shape, dtype="float32", name="in1") + in2 = tf.placeholder(shape=cond_shape, dtype="float32", name="in2") + condition = math_ops.less(in1, in2, name="less") + lhs = tf.placeholder(shape=in_shape, dtype="float32", name="x") + rhs = tf.placeholder(shape=in_shape, dtype="float32", name="y") + out = tf.where(condition, lhs, rhs) + compare_tf_with_tvm([t1, t2, x, y], ["in1:0", "in2:0", "x:0", "y:0"], out.name) + + +def test_forward_where_with_broadcast(): + _test_forward_where_with_broadcast((5, 2), (5,)) + _test_forward_where_with_broadcast((5, 7), (5,)) + _test_forward_where_with_broadcast((3, 2, 5), (3,)) + + ####################################################################### # SpaceToBatchND # --------------