diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 73332031aa4d..e8f164c1c579 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -419,7 +419,7 @@ def power(x, y): z : Expr The result. """ - return call_pure_intrin(x.dtype, "pow", x, y) + return _make._OpPow(convert(x), convert(y)) def popcount(x): @@ -482,12 +482,7 @@ def if_then_else(cond, t, f): Unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions. """ - t = convert(t) - f = convert(f) - cond = convert(cond) - if cond.dtype != "bool": - raise TypeError("The condition's data type has to be bool") - return call_pure_intrin(t.dtype, "tvm_if_then_else", cond, t, f) + return _make._OpIfThenElse(convert(cond), convert(t), convert(f)) # Intrinsic rule related code diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 0e86ffb3863a..2216793898e3 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -196,6 +196,7 @@ REGISTER_MAKE_BINARY_OP(_OpDiv, operator/); REGISTER_MAKE_BINARY_OP(_OpMod, operator%); REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv); REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod); +REGISTER_MAKE_BINARY_OP(_OpPow, pow); REGISTER_MAKE_BINARY_OP(_OpMin, min); REGISTER_MAKE_BINARY_OP(_OpMax, max); REGISTER_MAKE_BINARY_OP(_OpEQ, operator==); @@ -211,6 +212,10 @@ REGISTER_MAKE_BIT_OP(bitwise_or, operator|); REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, operator>>); +TVM_REGISTER_API("make._OpIfThenElse") +.set_body_typed([] (Expr cond, Expr true_value, Expr false_value) { + return if_then_else(cond, true_value, false_value); +}); } // namespace ir } // namespace tvm diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 6383b71b3b9b..cd61ccaa0147 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -238,7 +238,7 @@ Expr if_then_else(Expr cond, Expr true_value, Expr false_value) { using ir::IntImm; using ir::UIntImm; CHECK(cond.type() == Bool(1)) - << "if_then_else only accept a single condition"; + << "if_then_else only accept the condition to be boolean type."; BinaryOpMatchTypes(true_value, false_value); if (const UIntImm* op = cond.as()) { if (op->value != 0) { diff --git a/tests/python/unittest/test_lang_container.py b/tests/python/unittest/test_lang_container.py index 999e379ca48a..206e143029cf 100644 --- a/tests/python/unittest/test_lang_container.py +++ b/tests/python/unittest/test_lang_container.py @@ -27,7 +27,7 @@ def test_array_save_load_json(): a = tvm.convert([1,2,3]) json_str = tvm.save_json(a) a_loaded = tvm.load_json(json_str) - assert(a[1].value == 2) + assert(a_loaded[1].value == 2) def test_map(): diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index 8e7dcba3230c..2d30d29dac48 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -16,6 +16,15 @@ # under the License. import tvm +def check_throws(f): + try: + f() + except tvm.TVMError: + pass + else: + raise AssertionError("Should have raised an exception but didn't.") + + def test_const_fold(): def check(f, *args): x = f(*[tvm.const(x, "int32") for x in args]) @@ -47,14 +56,6 @@ def test_const_fold2(): assert isinstance((1 / x), tvm.expr.Div) def test_const_fold3(): - def check_throws(f): - try: - f() - except tvm.TVMError: - pass - else: - raise AssertionError("Should have raised an exception but didn't.") - # Test that using ints with logic operations is forbidden x = tvm.var("x") for val in [0, 1]: @@ -100,8 +101,92 @@ def test_const_fold4(): assert isinstance(y, tvm.expr.IntImm) and y.value == 6 +def test_binary_dtype_match(): + def verify_general_dtype_support(f, is_conditional=False): + rules = [[('bool', 'int32'), 'int32'], + [('int32', 'float32'), 'float32'], + [('int32', 'int64'), 'int64'], + [('uint32', 'int32'), 'int32']] + for (lhs_dtype, rhs_dtype), out_dtype in rules: + lhs = tvm.var('lhs', dtype=lhs_dtype) + rhs = tvm.var('rhs', dtype=rhs_dtype) + out = f(lhs, rhs) + if not is_conditional: + assert out.dtype == out_dtype + else: + assert out.dtype == 'bool' + if hasattr(out, 'a'): + assert out.a.dtype == out_dtype + assert out.b.dtype == out_dtype + elif hasattr(out, 'args'): + # CallOp + assert out.args[0].dtype == out_dtype + assert out.args[1].dtype == out_dtype + else: + raise ValueError('Unknown binary op format!') + + def verify_callop_float_only(f): + for lhs_dtype in ['int32', 'float32', 'float64']: + for rhs_dtype in ['int32', 'float32', 'float64']: + lhs = tvm.var('lhs', dtype=lhs_dtype) + rhs = tvm.var('rhs', dtype=rhs_dtype) + if 'float' not in lhs_dtype and 'float' not in rhs_dtype: + check_throws(lambda: f(lhs, rhs)) + elif 'float' in lhs_dtype and 'float' in rhs_dtype and lhs_dtype != rhs_dtype: + check_throws(lambda: f(lhs, rhs)) + elif 'float' in lhs_dtype: + out = f(lhs, rhs) + assert out.dtype == lhs_dtype + assert out.args[0].dtype == lhs_dtype + assert out.args[1].dtype == lhs_dtype + else: + out = f(lhs, rhs) + assert out.dtype == rhs_dtype + assert out.args[0].dtype == rhs_dtype + assert out.args[1].dtype == rhs_dtype + + verify_general_dtype_support(lambda a, b: a + b) + verify_general_dtype_support(lambda a, b: a * b) + verify_general_dtype_support(lambda a, b: a >= b, is_conditional=True) + verify_general_dtype_support(lambda a, b: a <= b, is_conditional=True) + verify_callop_float_only(lambda a, b: tvm.power(a, b)) + + +def test_if_then_else(): + cases = [[(tvm.var('cond', dtype='bool'), 'bool', 'int32'), 'int32'], + [(True, 'int32', 'float32'), 'float32'], + [(False, 'int32', 'int64'), 'int64'], + [(tvm.var('cond', dtype='bool'), 'uint32', 'int32'), 'int32'], + [(tvm.var('cond', dtype='int32'), 'uint32', 'int32'), 'int32']] + for (cond, lhs_dtype, rhs_dtype), out_dtype in cases: + lhs = tvm.var('lhs', dtype=lhs_dtype) + rhs = tvm.var('rhs', dtype=rhs_dtype) + if cond is True or cond is False: + out = tvm.if_then_else(cond, lhs, rhs) + out2 = tvm.if_then_else(not cond, rhs, lhs) + out3 = tvm.if_then_else(not cond, lhs, rhs) + assert tvm.ir_pass.Equal(out, out2) == 1 + if cond: + assert tvm.ir_pass.Equal(out, lhs.astype(out_dtype)) == 1 + assert tvm.ir_pass.Equal(out3, rhs.astype(out_dtype)) == 1 + else: + assert tvm.ir_pass.Equal(out, rhs.astype(out_dtype)) == 1 + assert tvm.ir_pass.Equal(out3, lhs.astype(out_dtype)) == 1 + elif cond.dtype == 'bool': + out = tvm.if_then_else(cond, lhs, rhs) + assert out.dtype == out_dtype + assert out.args[1].dtype == out_dtype + assert out.args[2].dtype == out_dtype + elif cond.dtype != 'bool': + check_throws(lambda: tvm.if_then_else(cond, lhs, rhs)) + else: + raise ValueError('Unknown combinations') + + if __name__ == "__main__": test_const_fold() test_const_fold2() test_const_fold3() test_const_fold4() + test_binary_dtype_match() + test_if_then_else()