Move the compute definitions of `cast`, `greater` and `less` from Python to C++.

* nnvm/python/nnvm/top/tensor.py: drop the `@reg.register_compute` functions
  for "cast", "greater" and "less"; their pattern and schedule registrations
  are kept.
* nnvm/src/top/tensor/elemwise.cc: register "FTVMCompute" for `greater` and
  `less`. The comparison result is cast to `out_info[0]->dtype`, where the
  removed Python code used a hard-coded 'float32'.
* nnvm/src/top/tensor/transform.cc: register "FTVMCompute" for `cast`, reading
  the target dtype from `CastParam`; add the two includes this needs.

NOTE(review): this patch was recovered from a copy in which all line breaks
were collapsed and the C++ template arguments were missing. Line breaks,
indentation, blank context lines and the template arguments
(`set_attr<FInferShape>`, `set_attr<FCorrectLayout>`, `set_attr<FTVMCompute>`,
`Array<Tensor>`, `nnvm::get<CastParam>`) are reconstructed -- confirm against
the upstream commit before applying.

diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py
index a781d17c6b110..bd486287abb38 100644
--- a/nnvm/python/nnvm/top/tensor.py
+++ b/nnvm/python/nnvm/top/tensor.py
@@ -53,11 +53,6 @@ def _compute(attrs, x, _):
 reg.register_schedule("copy", _fschedule_broadcast)
 
 # cast
-@reg.register_compute("cast")
-def compute_cast(attrs, inputs, _):
-    """Compute definition of cast"""
-    dtype = attrs.get_string("dtype")
-    return topi.cast(inputs[0], dtype)
 reg.register_pattern("cast", OpPattern.ELEMWISE)
 reg.register_schedule("cast", _fschedule_broadcast)
 
@@ -210,18 +205,10 @@ def compute_cast(attrs, inputs, _):
 reg.register_schedule("ones_like", _fschedule_elemwise)
 
 # greater
-@reg.register_compute("greater")
-def compute_greater(_, inputs, out_info):
-    """Compute definition of greater"""
-    return topi.greater(inputs[0], inputs[1]).astype('float32')
 reg.register_pattern("greater", OpPattern.ELEMWISE)
 reg.register_schedule("greater", _fschedule_elemwise)
 
 # less
-@reg.register_compute("less")
-def compute_less(_, inputs, out_info):
-    """Compute definition of less"""
-    return topi.less(inputs[0], inputs[1]).astype('float32')
 reg.register_pattern("less", OpPattern.ELEMWISE)
 reg.register_schedule("less", _fschedule_elemwise)
 
diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc
index 1aaab4725c90d..57e7137f51516 100644
--- a/nnvm/src/top/tensor/elemwise.cc
+++ b/nnvm/src/top/tensor/elemwise.cc
@@ -781,6 +781,12 @@ with 1.0 if (left > right), otherwise 0.0 element-wise.
 .add_argument("rhs", "Tensor", "Second input")
 .set_num_inputs(2)
 .set_attr<FInferShape>("FInferShape", ElemwiseShape<2, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::cast(topi::greater(inputs[0], inputs[1]), out_info[0]->dtype) };
+})
 .set_support_level(4);
 
 
@@ -793,6 +799,12 @@ with 1.0 if (left < right), otherwise 0.0 element-wise.
 .add_argument("rhs", "Tensor", "Second input")
 .set_num_inputs(2)
 .set_attr<FInferShape>("FInferShape", ElemwiseShape<2, 1>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::cast(topi::less(inputs[0], inputs[1]), out_info[0]->dtype) };
+})
 .set_support_level(4);
 
 NNVM_REGISTER_INDICATOR_OP(_max_mask)
diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc
index 0b0beb5b73a75..72e49a040efee 100644
--- a/nnvm/src/top/tensor/transform.cc
+++ b/nnvm/src/top/tensor/transform.cc
@@ -15,7 +15,9 @@
 #include "../elemwise_op_common.h"
 #include "topi/nn/flatten.h"
 #include "topi/transform.h"
+#include "topi/elemwise.h"
 #include "topi/detail/constant_utils.h"
+#include "../../compiler/compile_engine.h"
 
 namespace nnvm {
 namespace top {
@@ -413,6 +415,14 @@ NNVM_REGISTER_OP(cast)
 .set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
 .set_num_inputs(1)
 .set_num_outputs(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const CastParam& param = nnvm::get<CastParam>(attrs.parsed);
+    Type dtype = GetTVMType(param.dtype);
+    return Array<Tensor>{ topi::cast(inputs[0], dtype) };
+})
 .set_support_level(1);
 
 