diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h index 1324255037cbc..268205ecb07ba 100644 --- a/include/tvm/ir_operator.h +++ b/include/tvm/ir_operator.h @@ -42,12 +42,14 @@ TVM_DLL Expr max(Expr source, Array axis); */ TVM_DLL Expr min(Expr source, Array axis); + // Unary intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline Expr OpName(Expr x) { \ return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \ } \ + TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(sigmoid); @@ -58,7 +60,14 @@ TVM_DECLARE_INTRIN_UNARY(ceil); TVM_DECLARE_INTRIN_UNARY(round); TVM_DECLARE_INTRIN_UNARY(trunc); +/*! + * \brief Calculate power(x, y) + * \param x The left operand. + * \param y The right operand. + */ inline Expr pow(Expr x, Expr y) { + match_types(x, y); + CHECK(x.type().is_float()) << "power only applies to float"; return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic); } diff --git a/python/tvm/expr.py b/python/tvm/expr.py index df5e51a53be9b..8bf46b7eee625 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -67,19 +67,19 @@ def __neg__(self): return self.__mul__(neg_one) def __lshift__(self, other): - return _make.Call(self.dtype, "shift_left", [self, other], Call.PureIntrinsic, None, 0) + return _make.left_shift(self, other) def __rshift__(self, other): - return _make.Call(self.dtype, "shift_right", [self, other], Call.PureIntrinsic, None, 0) + return _make.right_shift(self, other) def __and__(self, other): - return _make.Call(self.dtype, "bitwise_and", [self, other], Call.PureIntrinsic, None, 0) + return _make.bitwise_and(self, other) def __or__(self, other): - return _make.Call(self.dtype, "bitwise_or", [self, other], Call.PureIntrinsic, None, 0) + return _make.bitwise_or(self, other) def __xor__(self, other): - return _make.Call(self.dtype, "bitwise_xor", [self, other], Call.PureIntrinsic, None, 0) + return _make.bitwise_xor(self, other) def __invert__(self): return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0) diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index d0847aceceb38..f8fbe902ca0bc 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -111,12 +111,25 @@ TVM_REGISTER_API("make.CommReducer") *ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \ }) \ -#define REGISTER_MAKE_BINARY_OP(Node) \ +#define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_API("make."#Node) \ .set_body([](TVMArgs args, TVMRetValue *ret) { \ Expr a = args[0], b = args[1]; \ - match_types(a, b); \ - *ret = Node::make(a, b); \ + *ret = (Func(a, b)); \ + }) + +#define REGISTER_MAKE_BIT_OP(Node, Func) \ + TVM_REGISTER_API("make."#Node) \ + .set_body([](TVMArgs args, TVMRetValue *ret) { \ + bool lhs_is_int = args[0].type_code() == kDLInt; \ + bool rhs_is_int = args[1].type_code() == kDLInt; \ + if (lhs_is_int) { \ + *ret = (Func(args[0].operator int(), args[1].operator Expr())); \ + } else if (rhs_is_int) { \ + *ret = (Func(args[0].operator Expr(), args[1].operator int())); \ + } else { \ + *ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \ + } \ }) REGISTER_MAKE5(Reduce); @@ -126,21 +139,26 @@ REGISTER_MAKE2(IntImm); REGISTER_MAKE2(UIntImm); REGISTER_MAKE2(FloatImm); REGISTER_MAKE1(StringImm); -REGISTER_MAKE_BINARY_OP(Add); -REGISTER_MAKE_BINARY_OP(Sub); -REGISTER_MAKE_BINARY_OP(Mul); -REGISTER_MAKE_BINARY_OP(Div); -REGISTER_MAKE_BINARY_OP(Mod); -REGISTER_MAKE_BINARY_OP(Min); -REGISTER_MAKE_BINARY_OP(Max); -REGISTER_MAKE_BINARY_OP(EQ); -REGISTER_MAKE_BINARY_OP(NE); -REGISTER_MAKE_BINARY_OP(LT); -REGISTER_MAKE_BINARY_OP(LE); -REGISTER_MAKE_BINARY_OP(GT); -REGISTER_MAKE_BINARY_OP(GE); -REGISTER_MAKE_BINARY_OP(And); -REGISTER_MAKE_BINARY_OP(Or); +REGISTER_MAKE_BINARY_OP(Add, operator+); +REGISTER_MAKE_BINARY_OP(Sub, operator-); +REGISTER_MAKE_BINARY_OP(Mul, operator*); +REGISTER_MAKE_BINARY_OP(Div, operator/); +REGISTER_MAKE_BINARY_OP(Mod, operator%); +REGISTER_MAKE_BINARY_OP(Min, min); +REGISTER_MAKE_BINARY_OP(Max, max); +REGISTER_MAKE_BINARY_OP(EQ, operator==); +REGISTER_MAKE_BINARY_OP(NE, operator!=); +REGISTER_MAKE_BINARY_OP(LT, operator<); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(LE, operator<=); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(GT, operator>); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(GE, operator>=); +REGISTER_MAKE_BINARY_OP(And, operator&&); +REGISTER_MAKE_BINARY_OP(Or, operator||); +REGISTER_MAKE_BIT_OP(bitwise_and, operator&); +REGISTER_MAKE_BIT_OP(bitwise_or, operator|); +REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); +REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) +REGISTER_MAKE_BIT_OP(right_shift, operator>>); REGISTER_MAKE1(Not); REGISTER_MAKE3(Select); REGISTER_MAKE3(Ramp); diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py index 1461ecec100f0..c9a04747b56db 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_lang_basic.py @@ -10,6 +10,8 @@ def test_make(): x = tvm.const(1) y = tvm.make.IntImm('int32', 1) z = x + y + assert isinstance(tvm.max(x, y), tvm.expr.Max) + assert isinstance(tvm.min(x, y), tvm.expr.Min) def test_ir(): x = tvm.const(1) @@ -132,6 +134,9 @@ def test_bitwise(): assert str(x | y) == 'bitwise_or(x, y)' assert str(x ^ y) == 'bitwise_xor(x, y)' assert str(~x) == 'bitwise_not(x)' + assert(tvm.const(1, "int8x2") >> 1).dtype == "int8x2" + assert(x >> tvm.const(1, "int32x2")).dtype == "int32x2" + assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2" def test_equality(): diff --git a/tests/python/unittest/test_pass_inject_copy_intrin.py b/tests/python/unittest/test_pass_inject_copy_intrin.py index a44f3899c282c..370044d85c036 100644 --- a/tests/python/unittest/test_pass_inject_copy_intrin.py +++ b/tests/python/unittest/test_pass_inject_copy_intrin.py @@ -64,6 +64,7 @@ def cb(src, dst, pad_before, pad_after, pad_value): stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) def assert_expr_equal(a, b): + print(a, b) assert tvm.ir_pass.Simplify(a - b).value == 0 def test_copy_pad_split(): @@ -87,6 +88,7 @@ def test_copy_pad_split(): def cb(src, dst, pad_before, pad_after, pad_value): assert(dst.elem_offset.value == 0) assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1) + rpad_before = tvm.max(1 - xo * 4, 0) rpad_after = tvm.max(xo * 4 - 7, 0) assert_expr_equal(pad_before[0], rpad_before)