Skip to content

Commit

Permalink
[Tensorflow, NNVM, TOPI] Support for logical operators
Browse files Browse the repository at this point in the history
fixes
  • Loading branch information
ashutoshparkhi committed Jan 17, 2019
1 parent d0f8366 commit f2a2ff5
Show file tree
Hide file tree
Showing 11 changed files with 161 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ List of operators
topi.not_equal
topi.greater_equal
topi.less_equal
topi.logical_and
topi.logical_or
topi.logical_not
topi.image.resize


Expand Down
6 changes: 6 additions & 0 deletions docs/nnvm_top.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.exp
nnvm.symbol.log
nnvm.symbol.sqrt
nnvm.symbol.logical_and
nnvm.symbol.logical_or
nnvm.symbol.logical_not
nnvm.symbol.elemwise_add
nnvm.symbol.elemwise_sub
nnvm.symbol.elemwise_mul
Expand Down Expand Up @@ -172,6 +175,9 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.exp
.. autofunction:: nnvm.symbol.log
.. autofunction:: nnvm.symbol.sqrt
.. autofunction:: nnvm.symbol.logical_and
.. autofunction:: nnvm.symbol.logical_or
.. autofunction:: nnvm.symbol.logical_not
.. autofunction:: nnvm.symbol.elemwise_add
.. autofunction:: nnvm.symbol.elemwise_sub
.. autofunction:: nnvm.symbol.elemwise_mul
Expand Down
2 changes: 2 additions & 0 deletions nnvm/python/nnvm/compiler/graph_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def set_shape_inputs(g, shape):
"uint16": 8,
"uint32": 9,
"uint64": 10,
"bool": 11,
}

TCODE_TO_DTYPE = {
Expand All @@ -54,6 +55,7 @@ def set_shape_inputs(g, shape):
8: "uint16",
9: "uint32",
10: "uint64",
11: "bool",
}

def set_dtype_inputs(g, dtype):
Expand Down
8 changes: 8 additions & 0 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,11 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):

return _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis)

def _logical(name):
def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr)
return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -924,6 +929,9 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
'Transpose' : _transpose(),
'Tanh' : AttrCvt('tanh'),
'Mean' : _mean(),
'LogicalAnd' : _logical('logical_and'),
'LogicalOr' : _logical('logical_or'),
'LogicalNot' : _logical('logical_not'),
'Less' : _broadcast('less'),
'Greater' : _broadcast('greater'),
'LessEqual' : _broadcast('less_equal'),
Expand Down
12 changes: 12 additions & 0 deletions nnvm/python/nnvm/top/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,18 @@ def _compute(attrs, x, _):
reg.register_pattern("__rshift_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rshift_scalar__", _fschedule_broadcast)

# logical_and
reg.register_pattern("logical_and", OpPattern.ELEMWISE)
reg.register_schedule("logical_and", _fschedule_broadcast)

# logical_or
reg.register_pattern("logical_or", OpPattern.ELEMWISE)
reg.register_schedule("logical_or", _fschedule_broadcast)

# logical_not
reg.register_pattern("logical_not", OpPattern.ELEMWISE)
reg.register_schedule("logical_not", _fschedule_broadcast)

# elemwise_add
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
reg.register_schedule("elemwise_add", _fschedule_broadcast)
Expand Down
3 changes: 3 additions & 0 deletions nnvm/src/compiler/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ int GetTypeFlag(tvm::Type type) {
if (type == tvm::UInt(16)) return 8;
if (type == tvm::UInt(32)) return 9;
if (type == tvm::UInt(64)) return 10;
if (type == tvm::UInt(1)) return 11;
LOG(FATAL) << "cannot convert " << type;
return 0;
}
Expand Down Expand Up @@ -68,6 +69,8 @@ Type GetTVMType(int type_flag) {
return tvm::UInt(32);
case 10:
return tvm::UInt(64);
case 11:
return tvm::UInt(1);
default:
LOG(FATAL) << "unknown type_flag=" << type_flag;
return Float(32);
Expand Down
38 changes: 38 additions & 0 deletions nnvm/src/top/tensor/elemwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,31 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_pow)
return Array<Tensor>{ topi::power(inputs[0], inputs[1]) };
});

// logical
NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_and)
.describe(R"code(Elementwise compute the logical AND
)code")
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::logical_and(inputs[0], inputs[1]) };
});

NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_or)
.describe(R"code(Elementwise compute the logical OR
)code")
.set_support_level(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::logical_or(inputs[0], inputs[1]) };
});

// negative
NNVM_REGISTER_ELEMWISE_UNARY_OP(negative)
.describe(R"code(Elemenwise numeric negative
Expand All @@ -383,6 +408,19 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(negative)
};
});

// logical NOT
NNVM_REGISTER_ELEMWISE_UNARY_OP(logical_not)
.describe(R"code(Elementwise compute the logical NOT
)code" NNVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
return Array<Tensor>{ topi::logical_not(inputs[0]) };
});

// copy
NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
.describe(R"code(Copy tensor to another one.
Expand Down
43 changes: 43 additions & 0 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,48 @@ def test_forward_pad():
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT")
_test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT", constant_values=1.0)

#######################################################################
# Logical operators
# --------------------
def test_logical_and():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_and(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')

def test_logical_or():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_or(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')

def test_logical_xor():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2')
out = tf.logical_xor(in1, in2, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0')

def test_logical_not():
with tf.Graph().as_default():
in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1')
out = tf.logical_not(in1, name='out')
in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool')
compare_tf_with_tvm(in_data1, 'in1:0', 'out:0')

def test_forward_logical():
test_logical_and()
test_logical_or()
test_logical_xor()
test_logical_not()

#######################################################################
# Inception V3
Expand Down Expand Up @@ -1205,3 +1247,4 @@ def test_forward_rel_ops():

# Relational ops
test_forward_rel_ops()
test_forward_logical()
27 changes: 27 additions & 0 deletions topi/include/topi/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,33 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
return topi::OpName(A, B); \
}

/*!
* \fn logical_and
* \brief Compute A && B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(logical_and, { return a && b; });
TOPI_DEFINE_OP_OVERLOAD(operator&&, logical_and);

/*!
* \fn logical_or
* \brief Compute A || B with auto-broadcasting.
*
* \param A The first tensor, or Expr
* \param B The second tensor, or Expr
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or);

/*!
* \fn add
Expand Down
17 changes: 17 additions & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ inline Tensor negative(const Tensor& x,
}, name, tag);
}

/*!
* \brief Creates an operation that returns the logical NOT of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the logical NOT operation
*/
inline Tensor logical_not(const Tensor& x,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
return !x(i);
}, name, tag);
}

/*!
* \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max]
Expand Down
2 changes: 2 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum);
TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum);
TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
Expand Down

0 comments on commit f2a2ff5

Please sign in to comment.