Skip to content

Commit

Permalink
[TOPI][Relay] Add bitwise ops (apache#4815)
Browse files Browse the repository at this point in the history
* Add bitwise ops to topi

* Add the bitwise ops to relay.
  • Loading branch information
abergeron authored and alexwong committed Feb 26, 2020
1 parent 3268a0a commit dab7bcd
Show file tree
Hide file tree
Showing 9 changed files with 319 additions and 1 deletion.
7 changes: 7 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
register_schedule("abs", schedule_broadcast)
register_schedule("tanh", schedule_broadcast)
register_schedule("logical_not", schedule_broadcast)
register_schedule("bitwise_not", schedule_broadcast)
register_schedule("negative", schedule_broadcast)
register_schedule("copy", schedule_broadcast)

Expand All @@ -57,6 +58,9 @@
register_schedule("floor_mod", schedule_broadcast)
register_schedule("logical_and", schedule_broadcast)
register_schedule("logical_or", schedule_broadcast)
register_schedule("bitwise_and", schedule_broadcast)
register_schedule("bitwise_or", schedule_broadcast)
register_schedule("bitwise_xor", schedule_broadcast)
register_schedule("equal", schedule_broadcast)
register_schedule("not_equal", schedule_broadcast)
register_schedule("less", schedule_broadcast)
Expand Down Expand Up @@ -194,6 +198,9 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("floor_mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("bitwise_and", False, broadcast_shape_func)
register_shape_func("bitwise_or", False, broadcast_shape_func)
register_shape_func("bitwise_xor", False, broadcast_shape_func)
register_shape_func("equal", False, broadcast_shape_func)
register_shape_func("not_equal", False, broadcast_shape_func)
register_shape_func("less", False, broadcast_shape_func)
Expand Down
70 changes: 70 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,22 @@ def logical_not(data):
return _make.logical_not(data)


def bitwise_not(data):
"""Compute element-wise bitwise not of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_not(data)


def add(lhs, rhs):
"""Addition with numpy-style broadcasting.
Expand Down Expand Up @@ -506,6 +522,60 @@ def logical_or(lhs, rhs):
return _make.logical_or(lhs, rhs)


def bitwise_and(lhs, rhs):
"""bitwise AND with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_and(lhs, rhs)


def bitwise_or(lhs, rhs):
"""bitwise OR with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_or(lhs, rhs)


def bitwise_xor(lhs, rhs):
"""bitwise XOR with numpy-style broadcasting.
Parameters
----------
lhs : relay.Expr
The left hand side input data
rhs : relay.Expr
The right hand side input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.bitwise_xor(lhs, rhs)


def equal(lhs, rhs):
"""Broadcasted elementwise test for (lhs == rhs).
Expand Down
18 changes: 18 additions & 0 deletions src/relay/op/tensor/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,24 @@ RELAY_REGISTER_BINARY_OP("logical_or")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));


RELAY_REGISTER_BINARY_OP("bitwise_and")
.describe("Elementwise bitwise AND with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and));


RELAY_REGISTER_BINARY_OP("bitwise_or")
.describe("Elementwise bitwise OR with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or));


RELAY_REGISTER_BINARY_OP("bitwise_xor")
.describe("Elementwise bitwise XOR with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor));


RELAY_REGISTER_CMP_OP("equal")
.describe("Elementwise equal compare with broadcasting")
.set_support_level(4)
Expand Down
13 changes: 12 additions & 1 deletion src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,24 @@ RELAY_REGISTER_UNARY_OP("logical_not")
.describe(R"code(Returns the logical inverse of input array, computed element-wise.
.. math::
~(x)
!(x)
)code" TVM_ADD_FILELINE)
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not));


RELAY_REGISTER_UNARY_OP("bitwise_not")
.describe(R"code(Returns the bitwise inverse of input array, computed element-wise.
.. math::
~(x)
)code" TVM_ADD_FILELINE)
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not));


// shape_of
TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);

Expand Down
42 changes: 42 additions & 0 deletions topi/include/topi/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,48 @@ TOPI_DEFINE_OP_OVERLOAD(operator&&, logical_and);
TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or);

/*!
* \fn bitwise_and
* \brief Compute A & B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(bitwise_and, { return a & b; });
TOPI_DEFINE_OP_OVERLOAD(operator&, bitwise_and);

/*!
* \fn bitwise_or
* \brief Compute A | B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(bitwise_or, { return a | b; });
TOPI_DEFINE_OP_OVERLOAD(operator|, bitwise_or);

/*!
* \fn bitwise_xor
* \brief Compute A ^ B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(bitwise_xor, { return a ^ b; });
TOPI_DEFINE_OP_OVERLOAD(operator^, bitwise_xor);

/*!
* \fn add
* \brief Compute A + B with auto-broadcasting.
Expand Down
17 changes: 17 additions & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,23 @@ inline Tensor logical_not(const Tensor& x,
}, name, tag);
}

/*!
* \brief Creates an operation that returns the bitwise NOT of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the bitwise NOT operation
*/
inline Tensor bitwise_not(const Tensor& x,
std::string name = "T_bitwise_not",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
return ~x(i);
}, name, tag);
}

/*!
* \brief Returns the sign of the tensor
*
Expand Down
73 changes: 73 additions & 0 deletions topi/python/topi/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,63 @@ def logical_or(lhs, rhs):
return _cpp.logical_or(lhs, rhs)


def bitwise_and(lhs, rhs):
"""Compute element-wise bitwise and of data.
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_and(lhs, rhs)


def bitwise_or(lhs, rhs):
"""Compute element-wise bitwise or of data.
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_or(lhs, rhs)


def bitwise_xor(lhs, rhs):
"""Compute element-wise bitwise xor of data.
Parameters
----------
lhs : tvm.Tensor or Expr
The left operand
rhs : tvm.Tensor or Expr
The right operand
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_xor(lhs, rhs)


def logical_not(data):
"""Compute element-wise logical not of data.
Expand All @@ -434,3 +491,19 @@ def logical_not(data):
Otherwise returns Tensor.
"""
return _cpp.logical_not(data)


def bitwise_not(data):
"""Compute element-wise bitwise not of data.
Parameters
----------
data : tvm.Tensor or Expr
Returns
-------
ret : tvm.Tensor or Expr
Returns Expr if the operand are Expr.
Otherwise returns Tensor.
"""
return _cpp.bitwise_not(data)
8 changes: 8 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);
TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
Expand All @@ -151,6 +154,11 @@ TVM_REGISTER_GLOBAL("topi.logical_not")
*rv = logical_not(args[0]);
});

TVM_REGISTER_GLOBAL("topi.bitwise_not")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = bitwise_not(args[0]);
});

/* Ops from elemwise.h */
TVM_REGISTER_GLOBAL("topi.exp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
Expand Down
Loading

0 comments on commit dab7bcd

Please sign in to comment.