Skip to content

Commit

Permalink
[Relay][Frontend][ONNX] Broadcast condition, x, and y for Where op (#…
Browse files Browse the repository at this point in the history
…4774)

* ONNX frontend broadcast condition

* fix

* fix style

Co-authored-by: Jon Soifer <[email protected]>
  • Loading branch information
2 people authored and jroesch committed Jan 27, 2020
1 parent f71a10c commit de919cb
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
29 changes: 24 additions & 5 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,14 +1105,33 @@ class Where(OnnxOpConverter):
"""
@classmethod
def _impl_v9(cls, inputs, attr, params):
# x and y can be broadcasted
condition_shape = infer_shape(inputs[0])
x_shape = infer_shape(inputs[1])
y_shape = infer_shape(inputs[2])
if len(condition_shape) > len(x_shape):
inputs[1] = _op.broadcast_to(inputs[1], condition_shape)
if len(condition_shape) > len(y_shape):
inputs[2] = _op.broadcast_to(inputs[2], condition_shape)

# condition, x, and y can all be broadcasted.
# broadcast each of them to the longest shape.
# if two shapes have the same number of dimensions,
# try to choose the one that doesn't have "1" as
# a dimension.
shapes = [condition_shape, x_shape, y_shape]
shape_lens = [len(shape) for shape in shapes]
max_size = max(shape_lens)
max_size_idxs = [i for i, x in enumerate(shape_lens) if x == max_size]
broadcast_idx = max_size_idxs[0]
if len(max_size_idxs) > 1:
for idx in max_size_idxs:
if 1 not in shapes[idx]:
broadcast_idx = idx

broadcast_shape = shapes[broadcast_idx]

if condition_shape != broadcast_shape:
inputs[0] = _op.broadcast_to(inputs[0], broadcast_shape)
if x_shape != broadcast_shape:
inputs[1] = _op.broadcast_to(inputs[1], broadcast_shape)
if y_shape != broadcast_shape:
inputs[2] = _op.broadcast_to(inputs[2], broadcast_shape)
return _op.where(inputs[0], inputs[1], inputs[2])

class Or(Elemwise):
Expand Down
16 changes: 16 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,6 +1684,22 @@ def test_where():
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)

x = np.array([2], dtype=np.float32)
y = np.array(1, dtype=np.float32)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)

condition = np.array(1, dtype=np.bool)
x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = np.array([[5, 6], [7, 8]], dtype=np.float32)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)

x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = np.array([[1], [7]], dtype=np.float32)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)


def verify_or(indata, dtype):
x = indata[0].astype(dtype)
Expand Down

0 comments on commit de919cb

Please sign in to comment.