deconv op implementing ... #4739
New file: paddle/operators/deconv2d_op.cc (+100 lines)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/deconv2d_op.h"
#include "paddle/operators/conv2d_op.h"

namespace paddle {
namespace operators {

void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("Input"),
                 "Input(Input) of Deconv2DOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Filter"),
                 "Input(Filter) of Deconv2DOp should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Output"),
                 "Output(Output) of Deconv2DOp should not be null.");

  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");

  for (int i = 0; i < paddings.size(); ++i) {
    PADDLE_ENFORCE_EQ(paddings[i], 0, "No Padding allowed in deconv op.");
  }
Review: This check should be placed in "Deconv2DOpMaker"; the current attribute checker doesn't support the vector type. @Canpio
Reply: For ...
  PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Deconv2DOp input should be 4-D.");
  PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Deconv2DOp filter should be 4-D.");
  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
                    "input and kernel input dimension should be equal.");

Review: "Deconv2DOp filter should be 4-D." -> "Deconv2DOp filter should be 4-D tensor."
Reply: Done.
  auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2];
  auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3];
  ctx->SetOutputDim("Output",
                    {in_dims[0], filter_dims[1], output_height, output_width});
}
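As a quick sanity check of the output-size formula used above (with no padding, as this op enforces), here is a minimal standalone sketch with hypothetical sizes; it is not part of the patch:

// Standalone illustration (hypothetical values), not part of this diff:
// H_out = (H_in - 1) * stride_h + K_h, and likewise for the width.
#include <cstdio>

int main() {
  int h_in = 5, stride_h = 2, k_h = 3;
  int h_out = (h_in - 1) * stride_h + k_h;  // (5 - 1) * 2 + 3 = 11
  std::printf("output height = %d\n", h_out);
  return 0;
}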
Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto,
                                 framework::OpAttrChecker* op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  AddInput(
      "Input",
      "The input tensor of deconvolution operator. "
      "The format of input tensor is NMHW. Where N is batch size, M is the "
      "number of input channels, H and W is the height and width of image.");

Review: "(Tensor) The input tensor of transposed 2D convolution operator." The () is used to denote the type, same as the following annotations. NMHW -> NCHW
Reply: Done.
AddInput("Filter", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Deconv case is a little different from Conv case. Like in Caffe2, Conv2d use NCHW for input and MCHW for filter and produces a tensor of shape NMHW; Caffe2 Deconv applies NCHW for input, CMHW for filter and produces output tensor with shape NMHW. I will make it clear in my codes. |
||
"The filter tensor of deconvolution operator." | ||
"The format of the filter tensor is MCHW, where M is the number of " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "MCHW" - >"NCHW" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
"input image channels, C is the number of output image channels, " | ||
"H and W is height and width of filter. " | ||
"We enforce groups number == 1 and padding == 0 in our " | ||
"deconvolution Scenario."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We enforce groups number == 1 and padding == 0 in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
AddOutput("Output", | ||
"The output tensor of deconvolution operator." | ||
"The format of output tensor is also NCHW."); | ||
AddAttr<std::vector<int>>("strides", "strides of deconvolution operator.") | ||
.SetDefault({1, 1}); | ||
AddAttr<std::vector<int>>("paddings", "paddings of deconvolution operator.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Attribute checker should be placed here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As @Canpio said, for current version to pass, we temporarily put our check here. |
||
.SetDefault({0, 0}); | ||
AddComment(R"DOC( | ||
The deconvolution operation calculates the output based on the input, filter | ||
and strides, paddings, groups parameters. The size of each dimension of the | ||
parameters is checked in the infer-shape. | ||
)DOC"); | ||
} | ||
|
||
void Deconv2DOpGrad::InferShape(framework::InferShapeContext* ctx) const {
  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  if (ctx->HasOutput(framework::GradVarName("Input"))) {
    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
  }
  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
    ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
  }
}

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(deconv2d, ops::Deconv2DOp, ops::Deconv2DOpMaker, deconv2d_grad,
            ops::Deconv2DOpGrad);

REGISTER_OP_CPU_KERNEL(
    deconv2d, ops::GemmDeconv2DKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    deconv2d_grad,
    ops::GemmDeconvGrad2DKernel<paddle::platform::CPUPlace, float>);
New file: paddle/operators/deconv2d_op.cu (+23 lines)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/deconv2d_op.h"

namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(
    deconv2d, ops::GemmDeconv2DKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    deconv2d_grad,
    ops::GemmDeconvGrad2DKernel<paddle::platform::GPUPlace, float>);
New file: paddle/operators/deconv2d_op.h (+251 lines)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "glog/logging.h"

Review: remove glog.
Reply: Done.
#include "paddle/framework/eigen.h" | ||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/operators/math/im2col.h" | ||
#include "paddle/operators/math/math_function.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
using Tensor = framework::Tensor; | ||
using DDim = framework::DDim; | ||
|
||
// Define Op classes in .h file so that other deconv
// operator implementations can reuse the code.
class Deconv2DOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Deconv2DOpMaker(framework::OpProto* proto,
                  framework::OpAttrChecker* op_checker);
};

class Deconv2DOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override;
};

class Deconv2DOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override;
};
template <typename Place, typename T>
class GemmDeconv2DKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    // filter will be reshaped, so we do not use constant pointer here
    Tensor filter = *context.Input<Tensor>("Filter");

Review: Instead of "The filter will be reshaped in the calculations, so it should not be a constant pointer."?
Reply: Done.
    Tensor* output = context.Output<Tensor>("Output");

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");

    // no paddings and groups allowed in deconv

Review: If to do in next PR, add TODO comments.
Reply: Done.
    int N = input->dims()[0];
    int M = input->dims()[1];
    int H = input->dims()[2];
    int W = input->dims()[3];

    int K_H = filter.dims()[2];
    int K_W = filter.dims()[3];

    int C = output->dims()[1];  // output channels
    int O_H = output->dims()[2];
    int O_W = output->dims()[3];
    paddle::operators::math::Col2ImFunctor<
        paddle::operators::math::ColFormat::kCFO, Place, T>
        col2im;

    // use col_shape in the im2col and col2im calculation
    DDim col_shape = {C, K_H, K_W, H, W};

    // use col_matrix_shape in the gemm calculation
    DDim col_matrix_shape = {C * K_H * K_W, H * W};

    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix = col;
    col_matrix.Resize(col_matrix_shape);
Review: That copy assign works as intended, but it looks a little unnatural to me at first glance, since for e.g. std::vector, copy assign copies the data. However, copy assignment does share data in this case because the data is stored inside a std::shared_ptr inside the Tensor class. Nevertheless, I would suggest the more explicit: ... (I realize this is carried over from conv2d_op.h - maybe you could change it there, too?)
Reply: Great thanks!
Reply: Done.
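The suggested snippet itself was not captured above; it is presumably along these lines (a hedged sketch assuming framework::Tensor::ShareDataWith, not the reviewer's exact words):

    // Share storage with `col` explicitly instead of relying on copy assignment,
    // then view it as the (C * K_H * K_W, H * W) matrix used by the gemm call.
    Tensor col_matrix;
    col_matrix.ShareDataWith(col);
    col_matrix.Resize(col_matrix_shape);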
    DDim output_shape = {C, O_H, O_W};
    DDim input_matrix_shape = {M, H * W};

    DDim filter_matrix_shape = {M, C * K_H * K_W};
    filter.Resize(filter_matrix_shape);

    // deconvolution: gemm + col2im (similar to conv-backward on input)
    output->mutable_data<T>(context.GetPlace());
    auto t = framework::EigenVector<T>::Flatten(*output);
    t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));

    for (int i = 0; i < N; i++) {
      // batch with size (M, H * W)
      Tensor input_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);

Review: Update code, since the ...
Reply: Done.

      // filter size: (M, C * K_H * K_W)
      // output size: (C, O_H, O_W)
      Tensor output_batch = output->Slice<T>(i, i + 1).Resize(output_shape);

      // col_matrix = filter^T * input_batch
      // of shape (C * K_H * K_W, H * W)
      math::matmul<Place, T>(context.device_context(), filter, true,
                             input_batch, false, T(1.0), &col_matrix, T(0.0));
      col2im(context.device_context(), output_batch, col, strides[0],
             strides[1], 0, 0);
    }
  }
};
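For readers new to the gemm + col2im formulation above (and the mirrored im2col + gemm used in the gradient kernel below), the per-batch shape bookkeeping can be checked on its own. A minimal standalone sketch with hypothetical sizes, not part of the patch:

#include <cassert>

int main() {
  // Hypothetical sizes: M input channels, C output channels, K_H x K_W filter,
  // H x W input map, stride 1, no padding (as this op enforces).
  const int M = 4, C = 3, K_H = 3, K_W = 3, H = 5, W = 5, stride = 1;

  // Forward: filter^T (C*K_H*K_W, M) * input_batch (M, H*W)
  //          -> col (C*K_H*K_W, H*W); col2im then folds col into (C, O_H, O_W).
  const int col_rows = C * K_H * K_W;      // 27
  const int col_cols = H * W;              // 25
  const int O_H = (H - 1) * stride + K_H;  // InferShape formula: 7
  const int O_W = (W - 1) * stride + K_W;  // 7
  assert(col_rows == 27 && col_cols == 25);
  assert(O_H == 7 && O_W == 7);

  // Backward w.r.t. the input: im2col turns dOutput (C, O_H, O_W) back into
  // col (C*K_H*K_W, H*W); filter (M, C*K_H*K_W) * col -> dInput (M, H*W).
  const int d_input_rows = M;
  const int d_input_cols = col_cols;
  assert(d_input_rows == 4 && d_input_cols == 25);
  return 0;
}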
template <typename Place, typename T>
class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));

    // For filter, we do not use const pointer b/c we will do reshape
    // but we should avoid modifying its value

Review: Add period.
Reply: Done.
    Tensor filter = *context.Input<Tensor>("Filter");

    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    // Actually, no paddings and groups allowed in deconv

Review: Add period.
Reply: Done.

    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    int N = input->dims()[0];
    int M = input->dims()[1];
    int H = input->dims()[2];
    int W = input->dims()[3];

    int K_H = filter.dims()[2];
    int K_W = filter.dims()[3];

    int C = output_grad->dims()[1];  // output channels
    int O_H = output_grad->dims()[2];
    int O_W = output_grad->dims()[3];
    // Only im2col functor required for bp to get to the right shape
    paddle::operators::math::Im2ColFunctor<
        paddle::operators::math::ColFormat::kCFO, Place, T>
        im2col;

    // use col_shape in the im2col and col2im calculation
    DDim col_shape = {C, K_H, K_W, H, W};

    // use col_matrix_shape in the gemm calculation
    DDim col_matrix_shape_f = {C * H * W, K_H * K_W};

    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.

    DDim output_shape = {C, O_H, O_W};
    DDim input_matrix_shape = {M, H * W};

    DDim filter_matrix_shape = {M, C * K_H * K_W};
    filter.Resize(filter_matrix_shape);
    // deconvolution grad on input:
    // im2col + gemm (similar to conv-forward)
    // input need to compute gradient
    if (input_grad) {
      Tensor col_matrix = col;
      DDim col_matrix_shape = {C * K_H * K_W, H * W};
      col_matrix.Resize(col_matrix_shape);

Review: See the above comment; I would prefer the more explicit form here as well.
Reply: Done.
      input_grad->mutable_data<T>(context.GetPlace());
      auto t = framework::EigenVector<T>::Flatten(*input_grad);
      t.device(context.GetEigenDevice<Place>()) =
          t.constant(static_cast<T>(0));

      for (int i = 0; i < N; i++) {
        // batch with size (C, O_H * O_W)
        Tensor output_grad_batch =
            output_grad->Slice<T>(i, i + 1).Resize(output_shape);
        // filter of size (M, C * K_H * K_W)

        // batch with size (M, H * W)
        Tensor input_grad_batch =
            input_grad->Slice<T>(i, i + 1).Resize(input_matrix_shape);

        // im2col: dy from (C, O_H, O_W) -> (C * K_H * K_W, H * W)
        im2col(context.device_context(), output_grad_batch, col, strides[0],
               strides[1], paddings[0], paddings[1]);

        // gemm: dx = filter * dy
        // (M, C * K_H * K_W) * (C * K_H * K_W, H * W) -> (M, H * W)
        math::matmul<Place, T>(context.device_context(), filter, false,
                               col_matrix, false, T(1.0), &input_grad_batch,
                               T(0.0));
      }
    }
    // filter gradient required
    if (filter_grad) {
      Tensor col_matrix_f = col;
      DDim col_matrix_shape_f = {C * H * W, K_H * K_W};
      col_matrix_f.Resize(col_matrix_shape_f);

      filter_grad->mutable_data<T>(context.GetPlace());
      Tensor filter_grad_ = *filter_grad;
      filter_grad_.Resize(filter_matrix_shape);
      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
      t.device(context.GetEigenDevice<Place>()) =
          t.constant(static_cast<T>(0));

      for (int i = 0; i < N; ++i) {
        // batch with size (C, O_H, O_W)
        Tensor output_grad_batch =
            output_grad->Slice<T>(i, i + 1).Resize(output_shape);
        // input batch
        Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);

        // im2col: (C * H * W, K_H * K_W)
        im2col(context.device_context(), output_grad_batch, col, strides[0],
               strides[1], paddings[0], paddings[1]);

        // gemm: d_filter = x * dy_col^T, accumulated into the
        // (M, C * K_H * K_W) filter-gradient matrix
        math::matmul<Place, T>(context.device_context(), in_batch, false,
                               col_matrix_f, true, T(1.0), &filter_grad_,
                               T(1.0));
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle
Review: From the comments in tensorflow/tensorflow#256 (comment) - how about renaming this to Conv2DTranspose?
Reply: Great suggestion!
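If that rename were adopted, the registration at the end of deconv2d_op.cc would presumably change along these lines (a hypothetical sketch, not part of this PR; the class and kernel names are illustrative only):

    // Hypothetical registration under the suggested name; not in this diff.
    REGISTER_OP(conv2d_transpose, ops::Conv2DTransposeOp,
                ops::Conv2DTransposeOpMaker, conv2d_transpose_grad,
                ops::Conv2DTransposeOpGrad);
    REGISTER_OP_CPU_KERNEL(
        conv2d_transpose,
        ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);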