diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 450573e4c524..68fb8d60e8a4 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -67,6 +67,9 @@ List of operators topi.not_equal topi.greater_equal topi.less_equal + topi.logical_and + topi.logical_or + topi.logical_not topi.image.resize diff --git a/docs/nnvm_top.rst b/docs/nnvm_top.rst index 717ce985e002..f05eed3308b3 100644 --- a/docs/nnvm_top.rst +++ b/docs/nnvm_top.rst @@ -35,6 +35,9 @@ This level enables fully connected multi-layer perceptron. nnvm.symbol.exp nnvm.symbol.log nnvm.symbol.sqrt + nnvm.symbol.logical_and + nnvm.symbol.logical_or + nnvm.symbol.logical_not nnvm.symbol.elemwise_add nnvm.symbol.elemwise_sub nnvm.symbol.elemwise_mul @@ -172,6 +175,9 @@ Detailed Definitions .. autofunction:: nnvm.symbol.exp .. autofunction:: nnvm.symbol.log .. autofunction:: nnvm.symbol.sqrt +.. autofunction:: nnvm.symbol.logical_and +.. autofunction:: nnvm.symbol.logical_or +.. autofunction:: nnvm.symbol.logical_not .. autofunction:: nnvm.symbol.elemwise_add .. autofunction:: nnvm.symbol.elemwise_sub .. autofunction:: nnvm.symbol.elemwise_mul diff --git a/nnvm/python/nnvm/compiler/graph_attr.py b/nnvm/python/nnvm/compiler/graph_attr.py index 3ce6c4b53239..2f1f0350d71b 100644 --- a/nnvm/python/nnvm/compiler/graph_attr.py +++ b/nnvm/python/nnvm/compiler/graph_attr.py @@ -39,6 +39,7 @@ def set_shape_inputs(g, shape): "uint16": 8, "uint32": 9, "uint64": 10, + "bool": 11, } TCODE_TO_DTYPE = { @@ -54,6 +55,7 @@ def set_shape_inputs(g, shape): 8: "uint16", 9: "uint32", 10: "uint64", + 11: "bool", } def set_dtype_inputs(g, dtype): diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 9a302da72ae6..7e0b41fe9c0a 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -862,6 +862,11 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1): return _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis) +def _logical(name): + def _impl(inputs, attr, params): + return AttrCvt(op_name=name)(inputs, attr) + return _impl + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -924,6 +929,9 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1): 'Transpose' : _transpose(), 'Tanh' : AttrCvt('tanh'), 'Mean' : _mean(), + 'LogicalAnd' : _logical('logical_and'), + 'LogicalOr' : _logical('logical_or'), + 'LogicalNot' : _logical('logical_not'), 'Less' : _broadcast('less'), 'Greater' : _broadcast('greater'), 'LessEqual' : _broadcast('less_equal'), diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py index e0214d6ddf16..5dae01695e3a 100644 --- a/nnvm/python/nnvm/top/tensor.py +++ b/nnvm/python/nnvm/top/tensor.py @@ -140,6 +140,18 @@ def _compute(attrs, x, _): reg.register_pattern("__rshift_scalar__", OpPattern.ELEMWISE) reg.register_schedule("__rshift_scalar__", _fschedule_broadcast) +# logical_and +reg.register_pattern("logical_and", OpPattern.ELEMWISE) +reg.register_schedule("logical_and", _fschedule_broadcast) + +# logical_or +reg.register_pattern("logical_or", OpPattern.ELEMWISE) +reg.register_schedule("logical_or", _fschedule_broadcast) + +# logical_not +reg.register_pattern("logical_not", OpPattern.ELEMWISE) +reg.register_schedule("logical_not", _fschedule_broadcast) + # elemwise_add reg.register_pattern("elemwise_add", OpPattern.BROADCAST) reg.register_schedule("elemwise_add", _fschedule_broadcast) diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc index 6df70b53ccae..ee0926d9885f 100644 --- a/nnvm/src/compiler/compile_engine.cc +++ b/nnvm/src/compiler/compile_engine.cc @@ -40,6 +40,7 @@ int GetTypeFlag(tvm::Type type) { if (type == tvm::UInt(16)) return 8; if (type == tvm::UInt(32)) return 9; if (type == tvm::UInt(64)) return 10; + if (type == tvm::UInt(1)) return 11; LOG(FATAL) << "cannot convert " << type; return 0; } @@ -68,6 +69,8 @@ Type GetTVMType(int type_flag) { return tvm::UInt(32); case 10: return tvm::UInt(64); + case 11: + return tvm::UInt(1); default: LOG(FATAL) << "unknown type_flag=" << type_flag; return Float(32); diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc index 3ee52008eb1c..52d9aa4456ed 100644 --- a/nnvm/src/top/tensor/elemwise.cc +++ b/nnvm/src/top/tensor/elemwise.cc @@ -361,6 +361,31 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_pow) return Array{ topi::power(inputs[0], inputs[1]) }; }); +// logical +NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_and) +.describe(R"code(Elementwise compute the logical AND + +)code") +.set_support_level(1) +.set_attr( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + return Array{ topi::logical_and(inputs[0], inputs[1]) }; +}); + +NNVM_REGISTER_ELEMWISE_BINARY_OP(logical_or) +.describe(R"code(Elementwise compute the logical OR + +)code") +.set_support_level(1) +.set_attr( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + return Array{ topi::logical_or(inputs[0], inputs[1]) }; +}); + // negative NNVM_REGISTER_ELEMWISE_UNARY_OP(negative) .describe(R"code(Elemenwise numeric negative @@ -383,6 +408,19 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(negative) }; }); +// logical NOT +NNVM_REGISTER_ELEMWISE_UNARY_OP(logical_not) +.describe(R"code(Elementwise compute the logical NOT + +)code" NNVM_ADD_FILELINE) +.set_support_level(3) +.set_attr( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info) { + return Array{ topi::logical_not(inputs[0]) }; +}); + // copy NNVM_REGISTER_ELEMWISE_UNARY_OP(copy) .describe(R"code(Copy tensor to another one. diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 0ea92248f0f5..2de577a6e8b3 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -777,6 +777,48 @@ def test_forward_pad(): _test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT") _test_pad((2, 3), [[1,1], [2,2]], mode="CONSTANT", constant_values=1.0) +####################################################################### +# Logical operators +# -------------------- +def test_logical_and(): + with tf.Graph().as_default(): + in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') + in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2') + out = tf.logical_and(in1, in2, name='out') + in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') + in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') + compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0') + +def test_logical_or(): + with tf.Graph().as_default(): + in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') + in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2') + out = tf.logical_or(in1, in2, name='out') + in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') + in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') + compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0') + +def test_logical_xor(): + with tf.Graph().as_default(): + in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') + in2 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in2') + out = tf.logical_xor(in1, in2, name='out') + in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') + in_data2 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') + compare_tf_with_tvm([in_data1, in_data2], ['in1:0', 'in2:0'], 'out:0') + +def test_logical_not(): + with tf.Graph().as_default(): + in1 = tf.placeholder(tf.bool, shape=[1, 4, 4, 3], name='in1') + out = tf.logical_not(in1, name='out') + in_data1 = np.random.choice(a=[False, True],size=(1, 4, 4, 3)).astype('bool') + compare_tf_with_tvm(in_data1, 'in1:0', 'out:0') + +def test_forward_logical(): + test_logical_and() + test_logical_or() + test_logical_xor() + test_logical_not() ####################################################################### # Inception V3 @@ -1205,3 +1247,4 @@ def test_forward_rel_ops(): # Relational ops test_forward_rel_ops() + test_forward_logical() diff --git a/topi/include/topi/broadcast.h b/topi/include/topi/broadcast.h index ad1c04ae1327..88007ee94e85 100644 --- a/topi/include/topi/broadcast.h +++ b/topi/include/topi/broadcast.h @@ -93,6 +93,33 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t, return topi::OpName(A, B); \ } +/*! + * \fn logical_and + * \brief Compute A && B with auto-broadcasting. + * + * \param A The first tensor, or Expr + * \param B The second tensor, or Expr + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The result. + */ +TOPI_DEFINE_BCAST_OP(logical_and, { return a && b; }); +TOPI_DEFINE_OP_OVERLOAD(operator&&, logical_and); + +/*! + * \fn logical_or + * \brief Compute A || B with auto-broadcasting. + * + * \param A The first tensor, or Expr + * \param B The second tensor, or Expr + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The result. + */ +TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; }); +TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or); /*! * \fn add diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index 02bc51515159..40dffa09a9bf 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -71,6 +71,23 @@ inline Tensor negative(const Tensor& x, }, name, tag); } +/*! +* \brief Creates an operation that returns the logical NOT of a given tensor +* +* \param x The input tensor +* \param name The name of the operation +* \param tag The tag to mark the operation +* +* \return A Tensor whose op member is the logical NOT operation +*/ +inline Tensor logical_not(const Tensor& x, + std::string name = "tensor", + std::string tag = kElementWise) { + return compute(x->shape, [&](const Array& i) { + return !x(i); + }, name, tag); +} + /*! * \brief Creates an operation that clips each element of a tensor to * the interval [a_min, a_max] diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 7adcb11c5656..a06a026b12ed 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -115,6 +115,8 @@ TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum); TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum); TOPI_REGISTER_BCAST_OP("topi.power", topi::power); TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift); +TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and); +TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or); TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift); TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater); TOPI_REGISTER_BCAST_OP("topi.less", topi::less);