diff --git a/src/operator/nn/upsampling-inl.h b/src/operator/nn/upsampling-inl.h index bac9709f4e53..91254dad9046 100644 --- a/src/operator/nn/upsampling-inl.h +++ b/src/operator/nn/upsampling-inl.h @@ -23,8 +23,8 @@ * \brief * \author Bing Xu */ -#ifndef MXNET_OPERATOR_UPSAMPLING_INL_H_ -#define MXNET_OPERATOR_UPSAMPLING_INL_H_ +#ifndef MXNET_OPERATOR_NN_UPSAMPLING_INL_H_ +#define MXNET_OPERATOR_NN_UPSAMPLING_INL_H_ #include #include @@ -34,7 +34,8 @@ #include #include #include -#include "./operator_common.h" +#include "../operator_common.h" +#include "./deconvolution-inl.h" namespace mxnet { namespace op { @@ -82,17 +83,16 @@ struct UpSamplingParam : public dmlc::Parameter { }; // struct UpSamplingParam template -class UpSamplingNearestOp : public Operator { +class UpSamplingNearestOp { public: - explicit UpSamplingNearestOp(UpSamplingParam p) { + void Init(UpSamplingParam p) { this->param_ = p; } - virtual void Forward(const OpContext &ctx, + void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { + const std::vector &out_data) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), static_cast(param_.num_args)); @@ -125,19 +125,14 @@ class UpSamplingNearestOp : public Operator { } } - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, + void Backward(const OpContext &ctx, const TBlob &out_grad, const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { + const std::vector &in_grad) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(out_grad.size(), 1U); CHECK_EQ(in_grad.size(), static_cast(param_.num_args)); Stream *s = ctx.get_stream(); - Tensor grad = out_grad[up_enum::kOut].get(s); + Tensor grad = out_grad.get(s); if (param_.num_args > 1) { int begin = 0; for (int i = 0; i < param_.num_args; ++i) { @@ -181,155 +176,68 @@ class UpSamplingNearestOp : public Operator { UpSamplingParam param_; }; // class UpSamplingNearestOp -template -Operator *CreateOp(UpSamplingParam param, int dtype); - - -#if DMLC_USE_CXX11 -class UpSamplingProp : public OperatorProperty { - public: - void Init(const std::vector >& kwargs) override { - param_.Init(kwargs); - } - - std::map GetParams() const override { - return param_.__DICT__(); - } - - std::vector ListArguments() const override { - if (param_.sample_type == up_enum::kNearest) { - std::vector ret; - for (int i = 0; i < param_.num_args; ++i) { - ret.push_back(std::string("arg") + std::to_string(i)); - } - return ret; - } else { - return {"data", "weight"}; - } - } - - bool InferShape(std::vector *in_shape, - std::vector *out_shape, - std::vector *aux_shape) const override { - CHECK_GE(in_shape->size(), 1U); - const TShape &dshape = (*in_shape)[0]; - TShape oshape = dshape; - if (param_.sample_type == up_enum::kNearest) { - CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); - oshape[1] = 0; - for (auto& shape : *in_shape) { - CHECK_EQ(shape.ndim(), 4U) << \ - "UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)"; - int oh = dshape[2]*param_.scale, ow = dshape[3]*param_.scale; - CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \ - "does not divide output height of " << oh; - CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \ - "does not divide output width of " << ow; - if (param_.multi_input_mode == up_enum::kSum) { - CHECK(oshape[1] == 0 || oshape[1] == shape[1]) << \ - "Number of channels must be the same when multi_input_mode==sum"; - oshape[1] = shape[1]; - } else { - oshape[1] += shape[1]; - } - } - } else { - CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]"; - CHECK_EQ(dshape.ndim(), 4U) << \ - "UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)"; - if (dshape.ndim() == 0) return false; - int kernel = 2 * param_.scale - param_.scale % 2; - SHAPE_ASSIGN_CHECK(*in_shape, - up_enum::kWeight, - mshadow::Shape4(dshape[1], 1, kernel, kernel)); - oshape = dshape; - } - oshape[2] = dshape[2] * param_.scale; - oshape[3] = dshape[3] * param_.scale; - out_shape->clear(); - out_shape->push_back(oshape); - return true; - } - - bool InferType(std::vector *in_type, - std::vector *out_type, - std::vector *aux_type) const override { - CHECK_GE(in_type->size(), 1U); - int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; - for (index_t i = 0; i < in_type->size(); ++i) { - if ((*in_type)[i] == -1) { - (*in_type)[i] = dtype; - } else { - UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); - } - } - out_type->clear(); - out_type->push_back(dtype); - return true; - } +static inline DeconvolutionParam GetDeconvolutionParam(const UpSamplingParam& param) { + DeconvolutionParam p = DeconvolutionParam(); + int kernel = 2 * param.scale - param.scale % 2; + int stride = param.scale; + int pad = static_cast(ceil((param.scale - 1) / 2.)); + p.workspace = param.workspace; + p.num_group = param.num_filter; + p.num_filter = param.num_filter; + p.no_bias = true; + int shape[] = {1, 1}; + p.dilate = TShape(shape, shape + 2); + shape[0] = shape[1] = kernel; + p.kernel = TShape(shape, shape + 2); + shape[0] = shape[1] = stride; + p.stride = TShape(shape, shape + 2); + shape[0] = shape[1] = pad; + p.pad = TShape(shape, shape + 2); + return p; +} - OperatorProperty* Copy() const override { - auto ptr = new UpSamplingProp(); - ptr->param_ = this->param_; - return ptr; - } - - std::string TypeString() const override { - return "UpSampling"; - } - - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { - if (param_.sample_type == up_enum::kNearest) { - return {out_grad[up_enum::kOut]}; - } else { - return {out_grad[up_enum::kOut], in_data[up_enum::kData], in_data[up_enum::kWeight]}; - } - } - - std::vector > BackwardInplaceOption( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &in_grad) const override { - return {}; - } - - std::vector ForwardResource( - const std::vector &in_shape) const override { - if (param_.sample_type == up_enum::kNearest) { - return {}; - } else { - return {ResourceRequest::kTempSpace}; - } - } - - std::vector BackwardResource( - const std::vector &in_shape) const override { - if (param_.sample_type == up_enum::kNearest) { - return {}; - } else { - return {ResourceRequest::kTempSpace}; - } - } - - Operator* CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not Implemented"; - return NULL; - } - - Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const override; +template +void UpSamplingCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const UpSamplingParam& param = nnvm::get(attrs.parsed); + if (param.sample_type == up_enum::kNearest) { + MSHADOW_REAL_TYPE_SWITCH(inputs[deconv::kData].type_flag_, DType, { + static thread_local UpSamplingNearestOp op; + op.Init(param); + op.Forward(ctx, inputs, req, outputs); + }); + } else if (param.sample_type == up_enum::kBilinear) { + DeconvolutionParam p = GetDeconvolutionParam(param); + _DeconvolutionCompute(p, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Unknown sample type"; + } +} +template +void UpSamplingGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const UpSamplingParam& param = nnvm::get(attrs.parsed); + if (param.sample_type == up_enum::kNearest) { + MSHADOW_REAL_TYPE_SWITCH(inputs[deconv::kData].type_flag_, DType, { + CHECK_EQ(inputs.size(), 1U); + static thread_local UpSamplingNearestOp op; + op.Init(param); + op.Backward(ctx, inputs[0], req, outputs); + }); + } else if (param.sample_type == up_enum::kBilinear) { + DeconvolutionParam p = GetDeconvolutionParam(param); + _DeconvolutionGradCompute(p, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Unknown sample type"; + } +} - private: - UpSamplingParam param_; -}; // class UpSamplingProp -#endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_UPSAMPLING_INL_H_ +#endif // MXNET_OPERATOR_NN_UPSAMPLING_INL_H_ diff --git a/src/operator/nn/upsampling.cc b/src/operator/nn/upsampling.cc index 8942e35ab325..87316a939718 100644 --- a/src/operator/nn/upsampling.cc +++ b/src/operator/nn/upsampling.cc @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file upsampling_nearest.cc * \brief - * \author Bing Xu + * \author Bing Xu, Da Zheng */ #include "./upsampling-inl.h" @@ -30,51 +30,123 @@ namespace mxnet { namespace op { -template<> -Operator *CreateOp(UpSamplingParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - if (param.sample_type == up_enum::kNearest) { - op = new UpSamplingNearestOp(param); - } else if (param.sample_type == up_enum::kBilinear) { - DeconvolutionParam p = DeconvolutionParam(); - int kernel = 2 * param.scale - param.scale % 2; - int stride = param.scale; - int pad = static_cast(ceil((param.scale - 1) / 2.)); - p.workspace = param.workspace; - p.num_group = param.num_filter; - p.num_filter = param.num_filter; - p.no_bias = true; - int shape[] = {1, 1}; - p.dilate = TShape(shape, shape + 2); - shape[0] = shape[1] = kernel; - p.kernel = TShape(shape, shape + 2); - shape[0] = shape[1] = stride; - p.stride = TShape(shape, shape + 2); - shape[0] = shape[1] = pad; - p.pad = TShape(shape, shape + 2); - op = new DeconvolutionOp(p); - } else { - LOG(FATAL) << "Unknown sample type"; + +static bool UpSamplingShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, std::vector *out_shape) { + const UpSamplingParam& param_ = nnvm::get(attrs.parsed); + CHECK_GE(in_shape->size(), 1U); + const TShape &dshape = (*in_shape)[0]; + TShape oshape = dshape; + if (param_.sample_type == up_enum::kNearest) { + CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); + oshape[1] = 0; + for (auto& shape : *in_shape) { + CHECK_EQ(shape.ndim(), 4U) << \ + "UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)"; + int oh = dshape[2]*param_.scale, ow = dshape[3]*param_.scale; + CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \ + "does not divide output height of " << oh; + CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \ + "does not divide output width of " << ow; + if (param_.multi_input_mode == up_enum::kSum) { + CHECK(oshape[1] == 0 || oshape[1] == shape[1]) << \ + "Number of channels must be the same when multi_input_mode==sum"; + oshape[1] = shape[1]; + } else { + oshape[1] += shape[1]; + } + } + } else { + CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]"; + CHECK_EQ(dshape.ndim(), 4U) << \ + "UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)"; + if (dshape.ndim() == 0) return false; + int kernel = 2 * param_.scale - param_.scale % 2; + SHAPE_ASSIGN_CHECK(*in_shape, + up_enum::kWeight, + mshadow::Shape4(dshape[1], 1, kernel, kernel)); + oshape = dshape; + } + oshape[2] = dshape[2] * param_.scale; + oshape[3] = dshape[3] * param_.scale; + out_shape->clear(); + out_shape->push_back(oshape); + return true; +} + +static inline std::vector ListArguments(const UpSamplingParam& param) { + if (param.sample_type == up_enum::kNearest) { + std::vector ret; + for (int i = 0; i < param.num_args; ++i) { + ret.push_back(std::string("arg") + std::to_string(i)); } - }); - return op; + return ret; + } else { + return {"data", "weight"}; + } } -Operator* UpSamplingProp::CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const { - DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); +static bool UpSamplingType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, std::vector *out_type) { + const UpSamplingParam& param = nnvm::get(attrs.parsed); + CHECK_GE(in_type->size(), 1U); + int dtype = (*in_type)[0]; + CHECK_NE(dtype, -1) << "First input must have specified type"; + for (index_t i = 0; i < in_type->size(); ++i) { + if ((*in_type)[i] == -1) { + (*in_type)[i] = dtype; + } else { + UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param)[i]); + } + } + out_type->clear(); + out_type->push_back(dtype); + return true; } +struct UpSamplingGrad { + const char *op_name; + std::vector operator()(const nnvm::NodePtr& n, + const std::vector& ograds) const { + const UpSamplingParam& param_ = nnvm::get(n->attrs.parsed); + std::vector heads(ograds.begin(), ograds.end()); + if (param_.sample_type != up_enum::kNearest) { + heads.push_back(n->inputs[up_enum::kData]); + heads.push_back(n->inputs[up_enum::kWeight]); + } + return MakeGradNode(op_name, n, heads, n->attrs.dict); + } +}; + DMLC_REGISTER_PARAMETER(UpSamplingParam); -MXNET_REGISTER_OP_PROPERTY(UpSampling, UpSamplingProp) +NNVM_REGISTER_OP(UpSampling) .describe("Performs nearest neighbor/bilinear up sampling to inputs.") +.set_num_inputs([](const NodeAttrs& attrs) { + const UpSamplingParam& params = nnvm::get(attrs.parsed); + return params.sample_type == up_enum::kNearest ? params.num_args : 2; +}) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return ListArguments(nnvm::get(attrs.parsed)); +}) +.set_attr("FInferShape", UpSamplingShape) +.set_attr("FInferType", UpSamplingType) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + const UpSamplingParam& param = nnvm::get(n.parsed); + if (param.sample_type == up_enum::kNearest) { + return std::vector(); + } else { + return std::vector{ResourceRequest::kTempSpace}; + } +}) +.set_attr("FCompute", UpSamplingCompute) +.set_attr("FGradient", UpSamplingGrad{"_backward_UpSampling"}) +.set_attr("key_var_num_args", "num_args") .add_argument("data", "NDArray-or-Symbol[]", "Array of tensors to upsample") .add_arguments(UpSamplingParam::__FIELDS__()) -.set_key_var_num_args("num_args"); - -NNVM_REGISTER_OP(UpSampling) .set_attr("FSetInputVarAttrOnCompose", [](const nnvm::NodeAttrs& attrs, nnvm::NodePtr var, const int index) { if (var->attrs.dict.find("__init__") != var->attrs.dict.end()) return; @@ -82,5 +154,23 @@ NNVM_REGISTER_OP(UpSampling) var->attrs.dict["__init__"] = "[\"bilinear\", {}]"; } }); + +NNVM_REGISTER_OP(_backward_UpSampling) +.set_num_outputs([](const NodeAttrs& attrs) { + const UpSamplingParam& params = nnvm::get(attrs.parsed); + return params.sample_type == up_enum::kNearest ? params.num_args : 2; +}) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + const UpSamplingParam& param = nnvm::get(n.parsed); + if (param.sample_type == up_enum::kNearest) { + return std::vector(); + } else { + return std::vector{ResourceRequest::kTempSpace}; + } +}) +.set_attr_parser(ParamParser) +.set_attr("FCompute", UpSamplingGradCompute); + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/upsampling.cu b/src/operator/nn/upsampling.cu index f83535a2b2e6..c5ff2fafd64a 100644 --- a/src/operator/nn/upsampling.cu +++ b/src/operator/nn/upsampling.cu @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file upsampling_nearest.cc * \brief - * \author Bing Xu + * \author Bing Xu, Da Zheng */ #include "./deconvolution-inl.h" @@ -29,36 +29,12 @@ namespace mxnet { namespace op { -template<> -Operator *CreateOp(UpSamplingParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - if (param.sample_type == up_enum::kNearest) { - op = new UpSamplingNearestOp(param); - } else if (param.sample_type == up_enum::kBilinear) { - DeconvolutionParam p = DeconvolutionParam(); - int kernel = 2 * param.scale - param.scale % 2; - int stride = param.scale; - int pad = static_cast(ceil((param.scale - 1) / 2.)); - p.workspace = param.workspace; - p.num_group = param.num_filter; - p.num_filter = param.num_filter; - p.no_bias = true; - int shape[] = {1, 1}; - p.dilate = TShape(shape, shape + 2); - shape[0] = shape[1] = kernel; - p.kernel = TShape(shape, shape + 2); - shape[0] = shape[1] = stride; - p.stride = TShape(shape, shape + 2); - shape[0] = shape[1] = pad; - p.pad = TShape(shape, shape + 2); - op = new DeconvolutionOp(p); - } else { - LOG(FATAL) << "Unknown sample type"; - } - }); - return op; -} + +NNVM_REGISTER_OP(UpSampling) +.set_attr("FCompute", UpSamplingCompute); + +NNVM_REGISTER_OP(_backward_UpSampling) +.set_attr("FCompute", UpSamplingGradCompute); } // namespace op } // namespace mxnet