Skip to content

Commit

Permalink
[OP] Improve bitwise op type checks
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jul 10, 2018
1 parent f0ae174 commit 77b5238
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 23 deletions.
9 changes: 9 additions & 0 deletions include/tvm/ir_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ TVM_DLL Expr max(Expr source, Array<IterVar> axis);
*/
TVM_DLL Expr min(Expr source, Array<IterVar> axis);


// Unary intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
} \


TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
Expand All @@ -58,7 +60,14 @@ TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc);

/*!
* \brief Calculate power(x, y)
* \param x The left operand.
* \param y The right operand.
*/
inline Expr pow(Expr x, Expr y) {
match_types(x, y);
CHECK(x.type().is_float()) << "power only applies to float";
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
}

Expand Down
10 changes: 5 additions & 5 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,19 @@ def __neg__(self):
return self.__mul__(neg_one)

def __lshift__(self, other):
return _make.Call(self.dtype, "shift_left", [self, other], Call.PureIntrinsic, None, 0)
return _make.left_shift(self, other)

def __rshift__(self, other):
return _make.Call(self.dtype, "shift_right", [self, other], Call.PureIntrinsic, None, 0)
return _make.right_shift(self, other)

def __and__(self, other):
return _make.Call(self.dtype, "bitwise_and", [self, other], Call.PureIntrinsic, None, 0)
return _make.bitwise_and(self, other)

def __or__(self, other):
return _make.Call(self.dtype, "bitwise_or", [self, other], Call.PureIntrinsic, None, 0)
return _make.bitwise_or(self, other)

def __xor__(self, other):
return _make.Call(self.dtype, "bitwise_xor", [self, other], Call.PureIntrinsic, None, 0)
return _make.bitwise_xor(self, other)

def __invert__(self):
return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)
Expand Down
41 changes: 23 additions & 18 deletions src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,36 +111,41 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
}) \

#define REGISTER_MAKE_BINARY_OP(Node) \
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
Expr a = args[0], b = args[1]; \
match_types(a, b); \
*ret = Node::make(a, b); \
*ret = (Func(a, b)); \
})


REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);

REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);
REGISTER_MAKE_BINARY_OP(Add);
REGISTER_MAKE_BINARY_OP(Sub);
REGISTER_MAKE_BINARY_OP(Mul);
REGISTER_MAKE_BINARY_OP(Div);
REGISTER_MAKE_BINARY_OP(Mod);
REGISTER_MAKE_BINARY_OP(Min);
REGISTER_MAKE_BINARY_OP(Max);
REGISTER_MAKE_BINARY_OP(EQ);
REGISTER_MAKE_BINARY_OP(NE);
REGISTER_MAKE_BINARY_OP(LT);
REGISTER_MAKE_BINARY_OP(LE);
REGISTER_MAKE_BINARY_OP(GT);
REGISTER_MAKE_BINARY_OP(GE);
REGISTER_MAKE_BINARY_OP(And);
REGISTER_MAKE_BINARY_OP(Or);
REGISTER_MAKE_BINARY_OP(Add, operator+);
REGISTER_MAKE_BINARY_OP(Sub, operator-);
REGISTER_MAKE_BINARY_OP(Mul, operator*);
REGISTER_MAKE_BINARY_OP(Div, operator/);
REGISTER_MAKE_BINARY_OP(Mod, operator%);
REGISTER_MAKE_BINARY_OP(Min, min);
REGISTER_MAKE_BINARY_OP(Max, max);
REGISTER_MAKE_BINARY_OP(EQ, operator==);
REGISTER_MAKE_BINARY_OP(NE, operator!=);
REGISTER_MAKE_BINARY_OP(LT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(LE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(GE, operator>=);
REGISTER_MAKE_BINARY_OP(And, operator&&);
REGISTER_MAKE_BINARY_OP(Or, operator||);
REGISTER_MAKE_BINARY_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(right_shift, operator>>);
REGISTER_MAKE_BINARY_OP(bitwise_and, operator&);
REGISTER_MAKE_BINARY_OP(bitwise_or, operator|);
REGISTER_MAKE_BINARY_OP(bitwise_xor, operator^);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
Expand Down
4 changes: 4 additions & 0 deletions tests/python/unittest/test_lang_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ def test_make():
x = tvm.const(1)
y = tvm.make.IntImm('int32', 1)
z = x + y
assert isinstance(tvm.max(x, y), tvm.expr.Max)
assert isinstance(tvm.min(x, y), tvm.expr.Min)

def test_ir():
x = tvm.const(1)
Expand Down Expand Up @@ -132,6 +134,8 @@ def test_bitwise():
assert str(x | y) == 'bitwise_or(x, y)'
assert str(x ^ y) == 'bitwise_xor(x, y)'
assert str(~x) == 'bitwise_not(x)'
assert(x >> tvm.const(1, "int32x2")).dtype == "int32x2"
assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2"


def test_equality():
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_pass_inject_copy_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def cb(src, dst, pad_before, pad_after, pad_value):
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)

def assert_expr_equal(a, b):
print(a, b)
assert tvm.ir_pass.Simplify(a - b).value == 0

def test_copy_pad_split():
Expand All @@ -87,6 +88,7 @@ def test_copy_pad_split():
def cb(src, dst, pad_before, pad_after, pad_value):
assert(dst.elem_offset.value == 0)
assert_expr_equal(src.elem_offset, tvm.max(xo * 4, 1) - 1)

rpad_before = tvm.max(1 - xo * 4, 0)
rpad_after = tvm.max(xo * 4 - 7, 0)
assert_expr_equal(pad_before[0], rpad_before)
Expand Down

0 comments on commit 77b5238

Please sign in to comment.