Skip to content

Commit

Permalink
[Frontend][Tensorflow] Update Select to SelectV2 (#13884)
Browse files Browse the repository at this point in the history
Fixes #13855
  • Loading branch information
balaram-cadence authored Feb 9, 2023
1 parent 1de5c72 commit 5cf3405
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
22 changes: 20 additions & 2 deletions python/tvm/relay/frontend/tensorflow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2380,6 +2380,24 @@ def _impl(inputs, attr, params, mod):


def _where():
def _impl(inputs, attr, params, mod):
if len(inputs) == 1:
return AttrCvt(op_name="argwhere")(inputs, attr)
cond_shape = _infer_shape(inputs[0], mod)
x_shape = _infer_shape(inputs[1], mod)
# Due to difference in broadcast behavior between Select and SelectV2,
# we adjust condition dimension with expand_dim and then broadcast.
if len(cond_shape) == 1 and cond_shape[0] == x_shape[0]:
for _ in range(len(x_shape) - 1):
inputs[0] = _op.expand_dims(inputs[0], axis=-1)
broadcast_cond = _op.broadcast_to(inputs[0], x_shape)
inputs[0] = _op.cast(broadcast_cond, "bool")
return AttrCvt(op_name="where")(inputs, attr)

return _impl


def _where_v2():
def _impl(inputs, attr, params, mod):
if len(inputs) == 1:
return AttrCvt(op_name="argwhere")(inputs, attr)
Expand Down Expand Up @@ -3088,7 +3106,7 @@ def _impl(inputs, attr, params, mod):
"Round": AttrCvt("round"),
"Rsqrt": _rsqrt(),
"Select": _where(),
"SelectV2": _where(),
"SelectV2": _where_v2(),
"Selu": _selu(),
"Shape": _shape(),
"Sigmoid": AttrCvt("sigmoid"),
Expand Down Expand Up @@ -3142,6 +3160,6 @@ def _impl(inputs, attr, params, mod):
"UniqueWithCounts": _unique(True),
"Unpack": _unpack(),
"UnravelIndex": _unravel_index(),
"Where": _where(),
"Where": _where_v2(),
"ZerosLike": AttrCvt("zeros_like"),
}
23 changes: 23 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,29 @@ def test_forward_argwhere():
_test_forward_where((5, 5, 5, 5, 5))


def _test_forward_where_with_broadcast(in_shape, cond_shape):
choice_list = list(np.arange(10).astype("float32"))
t1 = np.random.choice(choice_list, size=cond_shape)
t2 = np.random.choice(choice_list, size=cond_shape)
x = np.random.choice(choice_list, size=in_shape)
y = np.random.choice(choice_list, size=in_shape)

with tf.Graph().as_default():
in1 = tf.placeholder(shape=cond_shape, dtype="float32", name="in1")
in2 = tf.placeholder(shape=cond_shape, dtype="float32", name="in2")
condition = math_ops.less(in1, in2, name="less")
lhs = tf.placeholder(shape=in_shape, dtype="float32", name="x")
rhs = tf.placeholder(shape=in_shape, dtype="float32", name="y")
out = tf.where(condition, lhs, rhs)
compare_tf_with_tvm([t1, t2, x, y], ["in1:0", "in2:0", "x:0", "y:0"], out.name)


def test_forward_where_with_broadcast():
_test_forward_where_with_broadcast((5, 2), (5,))
_test_forward_where_with_broadcast((5, 7), (5,))
_test_forward_where_with_broadcast((3, 2, 5), (3,))


#######################################################################
# SpaceToBatchND
# --------------
Expand Down

0 comments on commit 5cf3405

Please sign in to comment.