
Use NNVM interface for upsampling.
zheng-da committed Nov 27, 2017
1 parent 94e8ee4 commit 563f7f8
Showing 3 changed files with 203 additions and 229 deletions.
src/operator/nn/upsampling-inl.h (232 changes: 70 additions & 162 deletions)
@@ -23,8 +23,8 @@
  * \brief
  * \author Bing Xu
  */
-#ifndef MXNET_OPERATOR_UPSAMPLING_INL_H_
-#define MXNET_OPERATOR_UPSAMPLING_INL_H_
+#ifndef MXNET_OPERATOR_NN_UPSAMPLING_INL_H_
+#define MXNET_OPERATOR_NN_UPSAMPLING_INL_H_
 
 #include <dmlc/logging.h>
 #include <dmlc/parameter.h>
@@ -34,7 +34,8 @@
 #include <vector>
 #include <string>
 #include <utility>
-#include "./operator_common.h"
+#include "../operator_common.h"
+#include "./deconvolution-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -82,17 +83,16 @@ struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
 }; // struct UpSamplingParam
 
 template<typename xpu, typename DType>
-class UpSamplingNearestOp : public Operator {
+class UpSamplingNearestOp {
  public:
-  explicit UpSamplingNearestOp(UpSamplingParam p) {
+  void Init(UpSamplingParam p) {
     this->param_ = p;
   }
 
-  virtual void Forward(const OpContext &ctx,
+  void Forward(const OpContext &ctx,
                        const std::vector<TBlob> &in_data,
                        const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
+                       const std::vector<TBlob> &out_data) {
     using namespace mshadow;
     using namespace mshadow::expr;
     CHECK_EQ(in_data.size(), static_cast<size_t>(param_.num_args));
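
Note on the interface change above: UpSamplingNearestOp no longer inherits from Operator, and the aux_args parameter is gone; the class is now a plain helper configured via Init() and driven by NNVM-style FCompute entry points (UpSamplingCompute / UpSamplingGradCompute further down in this diff). For reference, those entry points conform to MXNet's FCompute signature; a sketch, consistent with the function definitions below:

using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
                                     const OpContext& ctx,
                                     const std::vector<TBlob>& inputs,
                                     const std::vector<OpReqType>& req,
                                     const std::vector<TBlob>& outputs)>;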
@@ -125,19 +125,14 @@ class UpSamplingNearestOp : public Operator {
     }
   }
 
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
+  void Backward(const OpContext &ctx, const TBlob &out_grad,
                         const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
+                        const std::vector<TBlob> &in_grad) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), 1U);
     CHECK_EQ(in_grad.size(), static_cast<size_t>(param_.num_args));
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> grad = out_grad[up_enum::kOut].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, DType> grad = out_grad.get<xpu, 4, DType>(s);
     if (param_.num_args > 1) {
       int begin = 0;
       for (int i = 0; i < param_.num_args; ++i) {
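
Note on the Backward signature: out_grad shrinks from a vector of TBlobs to a single TBlob because nearest-neighbor upsampling needs only the output gradient; the removed DeclareBackwardDependency below returns just {out_grad[up_enum::kOut]} for kNearest. Mathematically, each input pixel receives the sum over its scale x scale output block. A minimal scalar reference for one channel, with hypothetical names (illustration only, not the mshadow kernel used here):

// Gradient of nearest-neighbor upsampling for a single in_h x in_w channel.
static void NearestUpsampleBackward(const float* out_grad, float* in_grad,
                                    int in_h, int in_w, int scale) {
  const int out_w = in_w * scale;
  for (int y = 0; y < in_h; ++y) {
    for (int x = 0; x < in_w; ++x) {
      float g = 0.f;
      // Each input pixel was copied to a scale x scale output block,
      // so its gradient is the sum over that block.
      for (int dy = 0; dy < scale; ++dy)
        for (int dx = 0; dx < scale; ++dx)
          g += out_grad[(y * scale + dy) * out_w + (x * scale + dx)];
      in_grad[y * in_w + x] = g;
    }
  }
}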
@@ -181,155 +176,68 @@ UpSamplingParam param_;
   UpSamplingParam param_;
 }; // class UpSamplingNearestOp
 
-template<typename xpu>
-Operator *CreateOp(UpSamplingParam param, int dtype);
-
-
-#if DMLC_USE_CXX11
-class UpSamplingProp : public OperatorProperty {
- public:
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  std::vector<std::string> ListArguments() const override {
-    if (param_.sample_type == up_enum::kNearest) {
-      std::vector<std::string> ret;
-      for (int i = 0; i < param_.num_args; ++i) {
-        ret.push_back(std::string("arg") + std::to_string(i));
-      }
-      return ret;
-    } else {
-      return {"data", "weight"};
-    }
-  }
-
-  bool InferShape(std::vector<TShape> *in_shape,
-                  std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    CHECK_GE(in_shape->size(), 1U);
-    const TShape &dshape = (*in_shape)[0];
-    TShape oshape = dshape;
-    if (param_.sample_type == up_enum::kNearest) {
-      CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
-      oshape[1] = 0;
-      for (auto& shape : *in_shape) {
-        CHECK_EQ(shape.ndim(), 4U) << \
-          "UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)";
-        int oh = dshape[2]*param_.scale, ow = dshape[3]*param_.scale;
-        CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \
-          "does not divide output height of " << oh;
-        CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \
-          "does not divide output width of " << ow;
-        if (param_.multi_input_mode == up_enum::kSum) {
-          CHECK(oshape[1] == 0 || oshape[1] == shape[1]) << \
-            "Number of channels must be the same when multi_input_mode==sum";
-          oshape[1] = shape[1];
-        } else {
-          oshape[1] += shape[1];
-        }
-      }
-    } else {
-      CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
-      CHECK_EQ(dshape.ndim(), 4U) << \
-        "UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)";
-      if (dshape.ndim() == 0) return false;
-      int kernel = 2 * param_.scale - param_.scale % 2;
-      SHAPE_ASSIGN_CHECK(*in_shape,
-                         up_enum::kWeight,
-                         mshadow::Shape4(dshape[1], 1, kernel, kernel));
-      oshape = dshape;
-    }
-    oshape[2] = dshape[2] * param_.scale;
-    oshape[3] = dshape[3] * param_.scale;
-    out_shape->clear();
-    out_shape->push_back(oshape);
-    return true;
-  }
-
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    CHECK_GE(in_type->size(), 1U);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
-      }
-    }
-    out_type->clear();
-    out_type->push_back(dtype);
-    return true;
-  }
+static inline DeconvolutionParam GetDeconvolutionParam(const UpSamplingParam& param) {
+  DeconvolutionParam p = DeconvolutionParam();
+  int kernel = 2 * param.scale - param.scale % 2;
+  int stride = param.scale;
+  int pad = static_cast<int>(ceil((param.scale - 1) / 2.));
+  p.workspace = param.workspace;
+  p.num_group = param.num_filter;
+  p.num_filter = param.num_filter;
+  p.no_bias = true;
+  int shape[] = {1, 1};
+  p.dilate = TShape(shape, shape + 2);
+  shape[0] = shape[1] = kernel;
+  p.kernel = TShape(shape, shape + 2);
+  shape[0] = shape[1] = stride;
+  p.stride = TShape(shape, shape + 2);
+  shape[0] = shape[1] = pad;
+  p.pad = TShape(shape, shape + 2);
+  return p;
+}
+
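A sanity check on this parameter mapping (worked arithmetic, not code from the commit): with kernel = 2s - (s mod 2), stride = s, and pad = ceil((s - 1) / 2), the deconvolution output size (in - 1) * stride - 2 * pad + kernel works out to exactly s times the input in each spatial dimension. Note also that num_group == num_filter makes this a per-channel deconvolution, so each channel is upsampled independently with its own bilinear kernel.

// Worked example of GetDeconvolutionParam's formulas (illustrative only):
//   scale = 2: kernel = 2*2 - 0 = 4, stride = 2, pad = ceil(1/2) = 1
//   scale = 3: kernel = 2*3 - 1 = 5, stride = 3, pad = ceil(2/2) = 1
// Output size check, out = (in - 1) * stride - 2 * pad + kernel:
//   even s: out = (H - 1)*s - 2*(s/2) + 2*s     = s*H
//   odd  s: out = (H - 1)*s - (s - 1) + 2*s - 1 = s*H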
-  OperatorProperty* Copy() const override {
-    auto ptr = new UpSamplingProp();
-    ptr->param_ = this->param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    return "UpSampling";
-  }
-
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    if (param_.sample_type == up_enum::kNearest) {
-      return {out_grad[up_enum::kOut]};
-    } else {
-      return {out_grad[up_enum::kOut], in_data[up_enum::kData], in_data[up_enum::kWeight]};
-    }
-  }
-
-  std::vector<std::pair<int, void*> > BackwardInplaceOption(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data,
-    const std::vector<void*> &in_grad) const override {
-    return {};
-  }
-
-  std::vector<ResourceRequest> ForwardResource(
-      const std::vector<TShape> &in_shape) const override {
-    if (param_.sample_type == up_enum::kNearest) {
-      return {};
-    } else {
-      return {ResourceRequest::kTempSpace};
-    }
-  }
-
-  std::vector<ResourceRequest> BackwardResource(
-      const std::vector<TShape> &in_shape) const override {
-    if (param_.sample_type == up_enum::kNearest) {
-      return {};
-    } else {
-      return {ResourceRequest::kTempSpace};
-    }
-  }
-
-  Operator* CreateOperator(Context ctx) const override {
-    LOG(FATAL) << "Not Implemented";
-    return NULL;
-  }
-
-  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                             std::vector<int> *in_type) const override;
+template<typename xpu>
+void UpSamplingCompute(const nnvm::NodeAttrs& attrs,
+                       const OpContext& ctx, const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
+  if (param.sample_type == up_enum::kNearest) {
+    MSHADOW_REAL_TYPE_SWITCH(inputs[deconv::kData].type_flag_, DType, {
+      static thread_local UpSamplingNearestOp<xpu, DType> op;
+      op.Init(param);
+      op.Forward(ctx, inputs, req, outputs);
+    });
+  } else if (param.sample_type == up_enum::kBilinear) {
+    DeconvolutionParam p = GetDeconvolutionParam(param);
+    _DeconvolutionCompute<xpu>(p, ctx, inputs, req, outputs);
+  } else {
+    LOG(FATAL) << "Unknown sample type";
+  }
+}
+
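A note on the dispatch above: since FCompute entry points are plain functions, the state that used to live in an Operator instance is now a function-local static thread_local object, re-initialized from the parsed params on each call (cheap, because Init only copies the parameter struct). The pattern in isolation, as a hedged sketch (GetCachedOp is a hypothetical helper, not MXNet API):

// One cached op instance per thread, reused across calls.
template<typename OpT, typename ParamT>
OpT& GetCachedOp(const ParamT& param) {
  static thread_local OpT op;
  op.Init(param);  // just stores the parameter struct
  return op;
}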
+template<typename xpu>
+void UpSamplingGradCompute(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx, const std::vector<TBlob>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
+  const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
+  if (param.sample_type == up_enum::kNearest) {
+    MSHADOW_REAL_TYPE_SWITCH(inputs[deconv::kData].type_flag_, DType, {
+      CHECK_EQ(inputs.size(), 1U);
+      static thread_local UpSamplingNearestOp<xpu, DType> op;
+      op.Init(param);
+      op.Backward(ctx, inputs[0], req, outputs);
+    });
+  } else if (param.sample_type == up_enum::kBilinear) {
+    DeconvolutionParam p = GetDeconvolutionParam(param);
+    _DeconvolutionGradCompute<xpu>(p, ctx, inputs, req, outputs);
+  } else {
+    LOG(FATAL) << "Unknown sample type";
+  }
+}
+
- private:
-  UpSamplingParam param_;
-}; // class UpSamplingProp
-#endif // DMLC_USE_CXX11
 } // namespace op
 } // namespace mxnet
 
-#endif // MXNET_OPERATOR_UPSAMPLING_INL_H_
+#endif // MXNET_OPERATOR_NN_UPSAMPLING_INL_H_
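
For completeness: these FCompute functions get attached to the operator in the corresponding .cc/.cu files (among this commit's other changed files, not shown here). A hedged sketch of the usual NNVM registration; the backward op name is an assumption rather than something visible in this diff:

NNVM_REGISTER_OP(UpSampling)
.set_attr<FCompute>("FCompute<cpu>", UpSamplingCompute<cpu>);

NNVM_REGISTER_OP(_backward_UpSampling)
.set_attr<FCompute>("FCompute<cpu>", UpSamplingGradCompute<cpu>);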