diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 4f6dd743b524..b7fe2cf62b5a 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -877,6 +877,14 @@ def _impl_v1(cls, inputs, attr, params): return _op.logical_not(inputs[0]) +class And(Elemwise): + """ Operator converter for And. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + return _op.logical_and(inputs[0], inputs[1]) + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -993,7 +1001,8 @@ def _get_convert_map(opset): 'Shape': Shape.get_converter(opset), 'Sign': Sign.get_converter(opset), 'Equal': Equal.get_converter(opset), - 'Not': Not.get_converter(opset) + 'Not': Not.get_converter(opset), + 'And': And.get_converter(opset) } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index e4c161dd8908..7e0e11f4686e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1158,6 +1158,53 @@ def test_not(): verify_not(indata=(np.random.randn(3, 4, 5, 6) > 0), dtype=bool) +def verify_and(indata, dtype): + x = indata[0].astype(dtype) + y = indata[1].astype(dtype) + outdata = np.logical_and(x, y) + + node = helper.make_node('And', inputs=['in1', 'in2'], outputs=['out'], ) + + graph = helper.make_graph([node], + 'and_test', + inputs=[helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)), + helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))]) + + model = helper.make_model(graph, producer_name='and_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape) + tvm.testing.assert_allclose(outdata, tvm_out) + + +def test_and(): + # 2d + x = (np.random.randn(3, 4) > 0) + y = (np.random.randn(3, 4) > 0) + verify_and(indata=[x, y], dtype=bool) + + # 3d + x = (np.random.randn(3, 4, 5) > 0) + y = (np.random.randn(3, 4, 5) > 0) + verify_and(indata=[x, y], dtype=bool) + + # 4d + x = (np.random.randn(3, 4, 5, 6) > 0) + y = (np.random.randn(3, 4, 5, 6) > 0) + verify_and(indata=[x, y], dtype=bool) + + # 3d vs 1d + x = (np.random.randn(3, 4, 5) > 0) + y = (np.random.randn(5) > 0) + verify_and(indata=[x, y], dtype=bool) + + # 3d vs 2d + x = (np.random.randn(3, 4, 5) > 0) + y = (np.random.randn(4, 5) > 0) + verify_and(indata=[x, y], dtype=bool) + + if __name__ == '__main__': test_flatten() test_reshape() @@ -1202,3 +1249,4 @@ def test_not(): test_densenet() test_sign() test_not() + test_and()