diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index b4a3697ad8f19..c1d02bd56d1b5 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -44,6 +44,7 @@ register_schedule("abs", schedule_broadcast) register_schedule("tanh", schedule_broadcast) register_schedule("logical_not", schedule_broadcast) +register_schedule("bitwise_not", schedule_broadcast) register_schedule("negative", schedule_broadcast) register_schedule("copy", schedule_broadcast) @@ -57,6 +58,9 @@ register_schedule("floor_mod", schedule_broadcast) register_schedule("logical_and", schedule_broadcast) register_schedule("logical_or", schedule_broadcast) +register_schedule("bitwise_and", schedule_broadcast) +register_schedule("bitwise_or", schedule_broadcast) +register_schedule("bitwise_xor", schedule_broadcast) register_schedule("equal", schedule_broadcast) register_schedule("not_equal", schedule_broadcast) register_schedule("less", schedule_broadcast) @@ -194,6 +198,9 @@ def elemwise_shape_func(attrs, inputs, _): register_shape_func("floor_mod", False, broadcast_shape_func) register_shape_func("logical_and", False, broadcast_shape_func) register_shape_func("logical_or", False, broadcast_shape_func) +register_shape_func("bitwise_and", False, broadcast_shape_func) +register_shape_func("bitwise_or", False, broadcast_shape_func) +register_shape_func("bitwise_xor", False, broadcast_shape_func) register_shape_func("equal", False, broadcast_shape_func) register_shape_func("not_equal", False, broadcast_shape_func) register_shape_func("less", False, broadcast_shape_func) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 0d3f09873be85..f1f8dd5a8c904 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -318,6 +318,22 @@ def logical_not(data): return _make.logical_not(data) +def bitwise_not(data): + """Compute element-wise bitwise not of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.bitwise_not(data) + + def add(lhs, rhs): """Addition with numpy-style broadcasting. @@ -506,6 +522,60 @@ def logical_or(lhs, rhs): return _make.logical_or(lhs, rhs) +def bitwise_and(lhs, rhs): + """bitwise AND with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.bitwise_and(lhs, rhs) + + +def bitwise_or(lhs, rhs): + """bitwise OR with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.bitwise_or(lhs, rhs) + + +def bitwise_xor(lhs, rhs): + """bitwise XOR with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.bitwise_xor(lhs, rhs) + + def equal(lhs, rhs): """Broadcasted elementwise test for (lhs == rhs). diff --git a/src/relay/op/tensor/binary.cc b/src/relay/op/tensor/binary.cc index 6561dd16819d1..d1b915cfa1429 100644 --- a/src/relay/op/tensor/binary.cc +++ b/src/relay/op/tensor/binary.cc @@ -124,6 +124,24 @@ RELAY_REGISTER_BINARY_OP("logical_or") .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or)); +RELAY_REGISTER_BINARY_OP("bitwise_and") +.describe("Elementwise bitwise AND with broadcasting") +.set_support_level(4) +.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and)); + + +RELAY_REGISTER_BINARY_OP("bitwise_or") +.describe("Elementwise bitwise OR with broadcasting") +.set_support_level(4) +.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or)); + + +RELAY_REGISTER_BINARY_OP("bitwise_xor") +.describe("Elementwise bitwise XOR with broadcasting") +.set_support_level(4) +.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor)); + + RELAY_REGISTER_CMP_OP("equal") .describe("Elementwise equal compare with broadcasting") .set_support_level(4) diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index d85d316e523e8..7f6db50bf7027 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -266,13 +266,24 @@ RELAY_REGISTER_UNARY_OP("logical_not") .describe(R"code(Returns the logical inverse of input array, computed element-wise. .. math:: - ~(x) + !(x) )code" TVM_ADD_FILELINE) .set_support_level(4) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not)); +RELAY_REGISTER_UNARY_OP("bitwise_not") +.describe(R"code(Returns the bitwise inverse of input array, computed element-wise. + +.. math:: + ~(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(4) +.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not)); + + // shape_of TVM_REGISTER_NODE_TYPE(ShapeOfAttrs); diff --git a/topi/include/topi/broadcast.h b/topi/include/topi/broadcast.h index a56d206e6ba9e..30bc584272e87 100644 --- a/topi/include/topi/broadcast.h +++ b/topi/include/topi/broadcast.h @@ -140,6 +140,48 @@ TOPI_DEFINE_OP_OVERLOAD(operator&&, logical_and); TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; }); TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or); +/*! + * \fn bitwise_and + * \brief Compute A & B with auto-broadcasting. + * + * \param A The first tensor, or Expr + * \param B The second tensor, or Expr + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The result. + */ +TOPI_DEFINE_BCAST_OP(bitwise_and, { return a & b; }); +TOPI_DEFINE_OP_OVERLOAD(operator&, bitwise_and); + +/*! + * \fn bitwise_or + * \brief Compute A | B with auto-broadcasting. + * + * \param A The first tensor, or Expr + * \param B The second tensor, or Expr + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The result. + */ +TOPI_DEFINE_BCAST_OP(bitwise_or, { return a | b; }); +TOPI_DEFINE_OP_OVERLOAD(operator|, bitwise_or); + +/*! + * \fn bitwise_xor + * \brief Compute A ^ B with auto-broadcasting. + * + * \param A The first tensor, or Expr + * \param B The second tensor, or Expr + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The result. + */ +TOPI_DEFINE_BCAST_OP(bitwise_xor, { return a ^ b; }); +TOPI_DEFINE_OP_OVERLOAD(operator^, bitwise_xor); + /*! * \fn add * \brief Compute A + B with auto-broadcasting. diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index b6343f19c076c..e3f4678c11637 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -178,6 +178,23 @@ inline Tensor logical_not(const Tensor& x, }, name, tag); } +/*! +* \brief Creates an operation that returns the bitwise NOT of a given tensor +* +* \param x The input tensor +* \param name The name of the operation +* \param tag The tag to mark the operation +* +* \return A Tensor whose op member is the bitwise NOT operation +*/ +inline Tensor bitwise_not(const Tensor& x, + std::string name = "T_bitwise_not", + std::string tag = kElementWise) { + return compute(x->shape, [&](const Array& i) { + return ~x(i); + }, name, tag); +} + /*! * \brief Returns the sign of the tensor * diff --git a/topi/python/topi/broadcast.py b/topi/python/topi/broadcast.py index 0ed4e0aafc35b..ba39c9aed35b2 100644 --- a/topi/python/topi/broadcast.py +++ b/topi/python/topi/broadcast.py @@ -420,6 +420,63 @@ def logical_or(lhs, rhs): return _cpp.logical_or(lhs, rhs) +def bitwise_and(lhs, rhs): + """Compute element-wise bitwise and of data. + + Parameters + ---------- + lhs : tvm.Tensor or Expr + The left operand + rhs : tvm.Tensor or Expr + The right operand + + Returns + ------- + ret : tvm.Tensor or Expr + Returns Expr if both operands are Expr. + Otherwise returns Tensor. + """ + return _cpp.bitwise_and(lhs, rhs) + + +def bitwise_or(lhs, rhs): + """Compute element-wise bitwise or of data. + + Parameters + ---------- + lhs : tvm.Tensor or Expr + The left operand + rhs : tvm.Tensor or Expr + The right operand + + Returns + ------- + ret : tvm.Tensor or Expr + Returns Expr if both operands are Expr. + Otherwise returns Tensor. + """ + return _cpp.bitwise_or(lhs, rhs) + + +def bitwise_xor(lhs, rhs): + """Compute element-wise bitwise xor of data. + + Parameters + ---------- + lhs : tvm.Tensor or Expr + The left operand + rhs : tvm.Tensor or Expr + The right operand + + Returns + ------- + ret : tvm.Tensor or Expr + Returns Expr if both operands are Expr. + Otherwise returns Tensor. + """ + return _cpp.bitwise_xor(lhs, rhs) + + def logical_not(data): """Compute element-wise logical not of data. @@ -434,3 +491,19 @@ def logical_not(data): Otherwise returns Tensor. """ return _cpp.logical_not(data) + + +def bitwise_not(data): + """Compute element-wise bitwise not of data. + + Parameters + ---------- + data : tvm.Tensor or Expr + + Returns + ------- + ret : tvm.Tensor or Expr + Returns Expr if the operand are Expr. + Otherwise returns Tensor. + """ + return _cpp.bitwise_not(data) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 95422d86974e1..2b2142bb57590 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -133,6 +133,9 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power); TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift); TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and); TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or); +TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and); +TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or); +TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor); TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift); TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater); TOPI_REGISTER_BCAST_OP("topi.less", topi::less); @@ -151,6 +154,11 @@ TVM_REGISTER_GLOBAL("topi.logical_not") *rv = logical_not(args[0]); }); +TVM_REGISTER_GLOBAL("topi.bitwise_not") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = bitwise_not(args[0]); + }); + /* Ops from elemwise.h */ TVM_REGISTER_GLOBAL("topi.exp") .set_body([](TVMArgs args, TVMRetValue *rv) { diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py index 4361f8fb675cd..5a0a940d3d7b2 100644 --- a/topi/tests/python/test_topi_broadcast.py +++ b/topi/tests/python/test_topi_broadcast.py @@ -270,6 +270,47 @@ def check_device(device): test_apply(topi.logical_not, "logical_not", np.logical_not, np.array(np.arange(5) < 3)) +def test_bitwise_not(): + def test_apply( + func, + name, + f_numpy, + shape, + dtype="int32", + ): + # Build the logic and compile the function + A = tvm.placeholder(shape=shape, name="A", dtype=dtype) + B = func(A) + + if isinstance(A, tvm.expr.PrimExpr): + assert (isinstance(B, tvm.expr.PrimExpr)) + return + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_broadcast(B) + foo = tvm.build(s, [A, B], device, name=name) + + data_npy = np.random.uniform(size=shape).astype(A.dtype) + data_nd = tvm.nd.array(data_npy, ctx) + + out_npy = f_numpy(data_npy) + out_nd = tvm.nd.array(np.empty(data_npy.shape).astype(B.dtype), ctx) + foo(data_nd, out_nd) + tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) + + for device in get_all_backend(): + check_device(device) + + test_apply(topi.bitwise_not, "bitwise_not", np.bitwise_not, ()) + test_apply(topi.bitwise_not, "bitwise_not", np.bitwise_not, (2, 1, 2)) + + def test_logical_binary_ele(): def test_apply( func, @@ -314,6 +355,33 @@ def check_device(device): test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False]) +def test_bitwise_and(): + verify_broadcast_binary_ele( + None, None, topi.bitwise_and, np.bitwise_and, + dtype="int32") + verify_broadcast_binary_ele( + (2, 1, 2), (2, 1, 2), topi.bitwise_and, np.bitwise_and, + dtype="int32") + + +def test_bitwise_or(): + verify_broadcast_binary_ele( + None, None, topi.bitwise_or, np.bitwise_or, + dtype="int32") + verify_broadcast_binary_ele( + (2, 1, 2), (2, 1, 2), topi.bitwise_or, np.bitwise_or, + dtype="int32") + + +def test_bitwise_xor(): + verify_broadcast_binary_ele( + None, None, topi.bitwise_xor, np.bitwise_xor, + dtype="int32") + verify_broadcast_binary_ele( + (2, 1, 2), (2, 1, 2), topi.bitwise_xor, np.bitwise_xor, + dtype="int32") + + if __name__ == "__main__": test_add() test_shift() @@ -328,4 +396,8 @@ def check_device(device): test_power() test_broadcast_to() test_logical_single_ele() + test_bitwise_not() test_logical_binary_ele() + test_bitwise_and() + test_bitwise_or() + test_bitwise_xor()