diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc index cf1c24c4c7cd..4ad244ff02b2 100644 --- a/src/tir/ir/op.cc +++ b/src/tir/ir/op.cc @@ -469,6 +469,9 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); + if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) << + "Shift amount must be non-negative and less than " << rtype.bits() + << " for type " << rtype; if (pa && pb) return IntImm(rtype, (pa->value >> pb->value)); if (pb) { if (pb->value == 0) return a; @@ -484,6 +487,9 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); + if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) << + "Shift amount must be non-negative and less than " << rtype.bits() + << " for type " << rtype; if (pa && pb) return IntImm(rtype, (pa->value << pb->value)); if (pb) { if (pb->value == 0) return a; diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 7e2c8b55a69b..290495381283 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -207,6 +207,23 @@ def test_float_bitwise(): pass +def test_shift_bounds(): + x = te.var('x') + for test in [lambda lhs, rhs : lhs << rhs, + lambda lhs, rhs : lhs >> rhs]: + #negative case + for testcase in [(x,-1), (x,32)]: + try: + test(*testcase) + assert False + except tvm.TVMError: + pass + + #positive case + for testcase in [(x,0), (x,16), (x,31)]: + test(*testcase) + + def test_divide_by_zero(): for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs), lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs), @@ -293,6 +310,7 @@ def test_vars(): test_all() test_bitwise() test_float_bitwise() + test_shift_bounds() test_divide_by_zero() test_isnan() test_equality()