diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a7f787484b2c..b007b41e61fe 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -922,6 +922,13 @@ class Erf(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): return _op.erf(inputs[0]) +class Where(OnnxOpConverter): + """Operator converter for Where + """ + @classmethod + def _impl_v9(cls, inputs, attr, params): + return _op.where(inputs[0], inputs[1], inputs[2]) + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1042,7 +1049,8 @@ def _get_convert_map(opset): 'Not': Not.get_converter(opset), 'And': And.get_converter(opset), 'Tile': Tile.get_converter(opset), - 'Erf': Erf.get_converter(opset) + 'Erf': Erf.get_converter(opset), + 'Where': Where.get_converter(opset) } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 16e717401174..3d1262f436bb 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1299,6 +1299,32 @@ def test_erf(): z = scipy.special.erf(x) verify_erf(x, z) +def verify_where(condition, x, y, dtype, outdata): + node = helper.make_node('Where', inputs=['condition', 'x', 'y'], outputs=['out']) + graph = helper.make_graph([node], + 'where_test', + inputs=[helper.make_tensor_value_info('condition', TensorProto.BOOL, list(condition.shape)), + helper.make_tensor_value_info('x', dtype, list(x.shape)), + helper.make_tensor_value_info('y', dtype, list(y.shape))], + outputs=[helper.make_tensor_value_info('out', dtype, list(outdata.shape))]) + model = helper.make_model(graph, producer_name='where_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape) + tvm.testing.assert_allclose(outdata, tvm_out) + +def test_where(): + condition = np.array([[1, 0], [1, 1]], dtype=np.bool) + x = np.array([[1, 2], [3, 4]], dtype=np.int64) + y = np.array([[9, 8], [7, 6]], dtype=np.int64) + outdata = np.where(condition, x, y) + verify_where(condition, x, y, TensorProto.INT64, outdata) + + x = np.array([[1, 2], [3, 4]], dtype=np.float32) + y = np.array([[9, 8], [7, 6]], dtype=np.float32) + outdata = np.where(condition, x, y) + verify_where(condition, x, y, TensorProto.FLOAT, outdata) + if __name__ == '__main__': test_flatten() @@ -1347,3 +1373,4 @@ def test_erf(): test_and() test_tile() test_erf() + test_where()