diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py
index a9ff8c8fe4c3..a781d17c6b11 100644
--- a/nnvm/python/nnvm/top/tensor.py
+++ b/nnvm/python/nnvm/top/tensor.py
@@ -182,69 +182,30 @@ def compute_cast(attrs, inputs, _):
 reg.register_schedule("clip", _fschedule_elemwise)
 
 # elemwise sum
-@reg.register_compute("elemwise_sum")
-def compute_elemwise_sum(attrs, inputs, _):
-    """Compute definition of elemwise sum"""
-    num_args = attrs.get_int("num_args")
-    assert num_args == len(inputs), "Number of tensors does not match num_args."
-    return topi.tensor.elemwise_sum(inputs)
 reg.register_pattern("elemwise_sum", OpPattern.ELEMWISE)
 reg.register_schedule("elemwise_sum", _fschedule_elemwise)
 
 # full
-@reg.register_compute("full")
-def compute_full(attrs, inputs, _):
-    """Compute definition of full"""
-    shape = attrs.get_int_tuple("shape")
-    dtype = attrs.get_string("dtype")
-    fill_value = attrs.get_float("fill_value")
-    return topi.tensor.full(shape, dtype, fill_value)
 reg.register_pattern("full", OpPattern.OUT_ELEMWISE_FUSABLE)
 reg.register_schedule("full", _fschedule_elemwise)
 
 # full_like
-@reg.register_compute("full_like")
-def compute_full_like(attrs, inputs, _):
-    """Compute definition of full_like"""
-    fill_value = attrs.get_float("fill_value")
-    return topi.tensor.full_like(inputs[0], fill_value)
 reg.register_pattern("full_like", OpPattern.ELEMWISE)
 reg.register_schedule("full_like", _fschedule_elemwise)
 
 # zeros
-@reg.register_compute("zeros")
-def compute_zeros(attrs, inputs, _):
-    """Compute definition of zeros"""
-    shape = attrs.get_int_tuple("shape")
-    dtype = attrs.get_string("dtype")
-    return topi.tensor.full(shape, dtype, 0)
 reg.register_pattern("zeros", OpPattern.OUT_ELEMWISE_FUSABLE)
 reg.register_schedule("zeros", _fschedule_elemwise)
 
 # zeros_like
-@reg.register_compute("zeros_like")
-def compute_zeros_like(_, inputs, out_info):
-    """Compute definition of zeros_like"""
-    return topi.tensor.full_like(inputs[0], 0)
 reg.register_pattern("zeros_like", OpPattern.ELEMWISE)
 reg.register_schedule("zeros_like", _fschedule_elemwise)
 
 # ones
-@reg.register_compute("ones")
-def compute_ones(attrs, inputs, _):
-    """Compute definition of ones"""
-    shape = attrs.get_int_tuple("shape")
-    dtype = attrs.get_string("dtype")
-    #tvm.tensor.Tensor()
-    return topi.tensor.full(shape, dtype, 1)
 reg.register_pattern("ones", OpPattern.OUT_ELEMWISE_FUSABLE)
 reg.register_schedule("ones", _fschedule_elemwise)
 
 # ones_like
-@reg.register_compute("ones_like")
-def compute_ones_like(_, inputs, out_info):
-    """Compute definition of ones_like"""
-    return topi.tensor.full_like(inputs[0], 1)
 reg.register_pattern("ones_like", OpPattern.ELEMWISE)
 reg.register_schedule("ones_like", _fschedule_elemwise)
 
diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc
index cd38817b51a5..1aaab4725c90 100644
--- a/nnvm/src/top/tensor/elemwise.cc
+++ b/nnvm/src/top/tensor/elemwise.cc
@@ -7,6 +7,7 @@
 #include <nnvm/node.h>
 #include <nnvm/op_attr_types.h>
 #include <nnvm/compiler/op_attr_types.h>
+#include <nnvm/compiler/util.h>
 #include <nnvm/top/tensor.h>
 #include <cmath>
 #include "../op_common.h"
@@ -14,6 +15,7 @@
 #include "topi/broadcast.h"
 #include "topi/elemwise.h"
 #include "topi/tags.h"
+#include "../../compiler/compile_engine.h"
 
 namespace nnvm {
 namespace top {
@@ -382,6 +384,16 @@ NNVM_REGISTER_INIT_OP(full)
 .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>)
 .set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>)
 .set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const InitOpWithScalarParam& param = nnvm::get<InitOpWithScalarParam>(attrs.parsed);
+    Array<Expr> shape = ShapeToArray(param.shape);
+    Type dtype = GetTVMType(param.dtype);
+    Expr fill_value = tvm::make_const(dtype, param.fill_value);
+    return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
+})
 .set_support_level(4);
 
 NNVM_REGISTER_INIT_OP(zeros)
@@ -395,6 +407,16 @@ NNVM_REGISTER_INIT_OP(zeros)
 .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
 .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
 .set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const InitOpParam& param = nnvm::get<InitOpParam>(attrs.parsed);
+    Array<Expr> shape = ShapeToArray(param.shape);
+    Type dtype = GetTVMType(param.dtype);
+    Expr fill_value = tvm::make_const(dtype, 0);
+    return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
+})
 .set_support_level(4);
 
 NNVM_REGISTER_INIT_OP(ones)
@@ -408,6 +430,16 @@ NNVM_REGISTER_INIT_OP(ones)
 .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
 .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
 .set_attr<FCorrectLayout>("FCorrectLayout", ZeroLayout)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const InitOpParam& param = nnvm::get<InitOpParam>(attrs.parsed);
+    Array<Expr> shape = ShapeToArray(param.shape);
+    Type dtype = GetTVMType(param.dtype);
+    Expr fill_value = tvm::make_const(dtype, 1);
+    return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
+})
 .set_support_level(4);
 
 // full_like
@@ -419,6 +451,14 @@ as the input array
 .add_arguments(FillValueParam::__FIELDS__())
 .set_attr_parser(ParamParser<FillValueParam>)
 .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<FillValueParam>)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const FillValueParam& param = nnvm::get<FillValueParam>(attrs.parsed);
+    const Expr fill_value = tvm::make_const(out_info[0]->dtype, param.fill_value);
+    return Array<Tensor> { topi::full_like(inputs[0], fill_value) };
+})
 .set_support_level(4);
 
 NNVM_REGISTER_INIT_LIKE_OP(zeros_like)
@@ -426,6 +466,13 @@
 as the input array.
 
 )code")
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor> { topi::full_like(inputs[0],
+                                            tvm::make_const(out_info[0]->dtype, 0)) };
+})
 .set_support_level(4);
 
 NNVM_REGISTER_INIT_LIKE_OP(ones_like)
@@ -433,6 +480,13 @@
 as the input array.
 
 )code")
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor> { topi::full_like(inputs[0],
+                                            tvm::make_const(out_info[0]->dtype, 1)) };
+})
 .set_support_level(4);
 
 // unary scalar op
@@ -684,6 +738,14 @@ NNVM_REGISTER_ELEMWISE_REDUCE_OP(elemwise_sum)
 .describe(R"code(Adds all input arguments element-wise.
 
 )code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ElementWiseReduceParam& param = nnvm::get<ElementWiseReduceParam>(attrs.parsed);
+    CHECK_EQ(param.num_args, inputs.size()) << "Number of tensors does not match num_args.";
+    return Array<Tensor>{ topi::elemwise_sum(inputs) };
+})
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
                   const std::vector<NodeEntry>& ograds){
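
Not part of the patch, but a useful companion: a minimal smoke-test sketch in the style of the existing nnvm Python tests, showing one way to exercise the new C++ FTVMCompute registrations end to end. It assumes the usual nnvm.compiler.build / tvm.contrib.graph_runtime flow on an "llvm" target; the graph, shapes, and the final assert are illustrative only.

import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler
from tvm.contrib import graph_runtime

# Tiny graph whose zeros_like/ones_like nodes are now computed by the
# C++ FTVMCompute registrations instead of the removed Python ones.
x = sym.Variable("x")
y = sym.elemwise_add(sym.zeros_like(x), sym.ones_like(x))

shape = (2, 3)
graph, lib, _ = nnvm.compiler.build(
    y, target="llvm", shape={"x": shape}, dtype="float32")

# Run the compiled module; zeros_like(x) + ones_like(x) should be all ones.
module = graph_runtime.create(graph, lib, tvm.cpu(0))
module.run(x=np.random.uniform(size=shape).astype("float32"))
out = module.get_output(0, tvm.nd.empty(shape, "float32")).asnumpy()
np.testing.assert_allclose(out, np.ones(shape, dtype="float32"))

The standalone full/zeros/ones registrations follow the same path whenever those ops appear in a compiled graph.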