Skip to content

Commit

Permalink
[Relay][Frontend][ONNX] Add support for op Where (#4184)
Browse files Browse the repository at this point in the history
* Add support for op Where

* Update impl version
  • Loading branch information
soiferj authored and jroesch committed Oct 27, 2019
1 parent 9cc7874 commit 07606e4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
10 changes: 9 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,13 @@ class Erf(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
return _op.erf(inputs[0])

class Where(OnnxOpConverter):
"""Operator converter for Where
"""
@classmethod
def _impl_v9(cls, inputs, attr, params):
return _op.where(inputs[0], inputs[1], inputs[2])


# compatible operators that do NOT require any conversion.
_identity_list = []
Expand Down Expand Up @@ -1042,7 +1049,8 @@ def _get_convert_map(opset):
'Not': Not.get_converter(opset),
'And': And.get_converter(opset),
'Tile': Tile.get_converter(opset),
'Erf': Erf.get_converter(opset)
'Erf': Erf.get_converter(opset),
'Where': Where.get_converter(opset)
}


Expand Down
27 changes: 27 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,6 +1299,32 @@ def test_erf():
z = scipy.special.erf(x)
verify_erf(x, z)

def verify_where(condition, x, y, dtype, outdata):
node = helper.make_node('Where', inputs=['condition', 'x', 'y'], outputs=['out'])
graph = helper.make_graph([node],
'where_test',
inputs=[helper.make_tensor_value_info('condition', TensorProto.BOOL, list(condition.shape)),
helper.make_tensor_value_info('x', dtype, list(x.shape)),
helper.make_tensor_value_info('y', dtype, list(y.shape))],
outputs=[helper.make_tensor_value_info('out', dtype, list(outdata.shape))])
model = helper.make_model(graph, producer_name='where_test')

for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape)
tvm.testing.assert_allclose(outdata, tvm_out)

def test_where():
condition = np.array([[1, 0], [1, 1]], dtype=np.bool)
x = np.array([[1, 2], [3, 4]], dtype=np.int64)
y = np.array([[9, 8], [7, 6]], dtype=np.int64)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.INT64, outdata)

x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = np.array([[9, 8], [7, 6]], dtype=np.float32)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)


if __name__ == '__main__':
test_flatten()
Expand Down Expand Up @@ -1347,3 +1373,4 @@ def test_erf():
test_and()
test_tile()
test_erf()
test_where()

0 comments on commit 07606e4

Please sign in to comment.