-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 2】29、为 Paddle 新增 PixelUnshuffle 组网 API #40728
Merged
jeff41404
merged 40 commits into
PaddlePaddle:develop
from
BrilliantYuKaimin:pixel_unshuffle
Apr 26, 2022
Merged
Changes from 39 commits
Commits
Show all changes
40 commits
Select commit
Hold shift + click to select a range
995bdad
增加PixelUnshuffle的形状推断
BrilliantYuKaimin 8c401fb
增加PixelUnshuffle的算子注册
BrilliantYuKaimin 7a62b6e
增加PixelUnshuffle及其梯度的核函数
BrilliantYuKaimin 9591a48
增加PixelUnshuffle算子的描述
BrilliantYuKaimin f6ad365
增加PixelUnshuffle算子的签名
BrilliantYuKaimin 73aed02
在Python层面增加PixelUnshuffle
BrilliantYuKaimin 8a259c0
增加PixelUnshuffle的单测
BrilliantYuKaimin b28157e
Update test_pixel_unshuffle.py
BrilliantYuKaimin d16b545
test=document_fix
BrilliantYuKaimin 89e36a0
Update test_pixel_unshuffle.py
BrilliantYuKaimin 3792573
修正代码格式
BrilliantYuKaimin 3388c97
Update test_pixel_unshuffle.py
BrilliantYuKaimin 9e28fef
修改pixel_unshuffle核函数的实现位置
BrilliantYuKaimin ef6f8ea
修正代码格式
BrilliantYuKaimin 51bb6f8
完善对输入的检查
BrilliantYuKaimin cf80ace
Update test_pixel_unshuffle.py
BrilliantYuKaimin 4ca1ab4
完善pixel_unshuffle的输入检查
BrilliantYuKaimin ea07a17
Update pixel_unshuffle_op.cc
BrilliantYuKaimin 41c0705
Merge branch 'develop' into pixel_unshuffle
BrilliantYuKaimin b0cc19a
Update unary.cc
BrilliantYuKaimin e96d98a
add pixel_unshuffle
BrilliantYuKaimin fc1ff53
Update test_pixel_unshuffle.py
BrilliantYuKaimin b3c084b
Update vision.py
BrilliantYuKaimin bcb06dd
调整代码格式
BrilliantYuKaimin b3126ed
Update vision.py
BrilliantYuKaimin be2f4ad
Merge branch 'PaddlePaddle:develop' into pixel_unshuffle
BrilliantYuKaimin 034c17f
Merge branch 'develop' into pixel_unshuffle
BrilliantYuKaimin 8653519
Delete extra spaces
BrilliantYuKaimin b13f476
Merge branch 'PaddlePaddle:develop' into pixel_unshuffle
BrilliantYuKaimin 9aeb8b8
Merge branch 'PaddlePaddle:develop' into pixel_unshuffle
BrilliantYuKaimin c3fbce6
Update pixel_unshuffle_sig.cc
BrilliantYuKaimin d0e0351
Merge branch 'PaddlePaddle:develop' into pixel_unshuffle
BrilliantYuKaimin 2310fc8
Update vision.py
BrilliantYuKaimin 777c9f8
Update vision.py
BrilliantYuKaimin b937d0e
add PixelUnshuffleGradInferMeta
BrilliantYuKaimin 73fabcb
remove PixelUnshuffleOpArgumentMapping
BrilliantYuKaimin d5f6874
Update pixel_unshuffle_op.cc
BrilliantYuKaimin e871227
调整pixel_unshuffle及其梯度的核函数的实现位置
BrilliantYuKaimin 948f32b
Update pixel_unshuffle_op.cc
BrilliantYuKaimin e270ab7
Merge branch 'develop' into pixel_unshuffle
BrilliantYuKaimin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/infershape_utils.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/op_version_registry.h" | ||
#include "paddle/phi/core/infermeta_utils.h" | ||
#include "paddle/phi/infermeta/backward.h" | ||
#include "paddle/phi/infermeta/unary.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class PixelUnshuffleOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
}; | ||
|
||
class PixelUnshuffleOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("X", | ||
"(Tensor, default Tensor<float>), " | ||
"the input feature data of PixelUnshuffleOp, the layout is " | ||
"[N, C, H, W] or [N, H, W, C]."); | ||
AddOutput("Out", | ||
"(Tensor, default Tensor<float>), the output of " | ||
"PixelUnshuffleOp. The layout is [N, C*factor^2, H/factor, " | ||
"W/factor] or [N, H/factor, W/factor, C*factor^2]."); | ||
AddAttr<int>("downscale_factor", | ||
"the factor to decrease spatial resolution by.") | ||
.SetDefault(1); | ||
AddAttr<std::string>( | ||
"data_format", | ||
"An optional string from: \"NHWC\", \"NCHW\". " | ||
"Defaults to \"NHWC\", Specify the data format of the input data.") | ||
.SetDefault("NCHW"); | ||
|
||
AddComment(R"DOC( | ||
Pixel Unshuffle operator | ||
This operator rearranges elements in a tensor of shape :math:`(*, C, H, W)` | ||
to a tensor of shape :math:`(*, C\times r^2, H / r, W / r)`. | ||
|
||
This operation is the reversion of PixelShuffle operation. | ||
|
||
Please refer to the paper: | ||
`Real-Time Single Image and Video Super-Resolution Using an Efficient | ||
Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158v2>`_ | ||
by Shi et. al (2016) for more details. | ||
|
||
)DOC"); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class PixelUnshuffleGradOpMaker : public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
protected: | ||
void Apply(GradOpPtr<T> op) const override { | ||
op->SetType("pixel_unshuffle_grad"); | ||
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
op->SetAttrMap(this->Attrs()); | ||
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); | ||
} | ||
}; | ||
|
||
class PixelUnshuffleGradOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
DECLARE_INFER_SHAPE_FUNCTOR(pixel_unshuffle, PixelUnshuffleInferShapeFunctor, | ||
PD_INFER_META(phi::PixelUnshuffleInferMeta)); | ||
|
||
REGISTER_OPERATOR(pixel_unshuffle, ops::PixelUnshuffleOp, | ||
ops::PixelUnshuffleOpMaker, | ||
ops::PixelUnshuffleGradOpMaker<paddle::framework::OpDesc>, | ||
ops::PixelUnshuffleGradOpMaker<paddle::imperative::OpBase>, | ||
PixelUnshuffleInferShapeFunctor); | ||
|
||
DECLARE_INFER_SHAPE_FUNCTOR(pixel_unshuffle_grad, | ||
PixelUnshuffleGradInferShapeFunctor, | ||
PD_INFER_META(phi::PixelUnshuffleGradInferMeta)); | ||
|
||
REGISTER_OPERATOR(pixel_unshuffle_grad, ops::PixelUnshuffleGradOp, | ||
PixelUnshuffleGradInferShapeFunctor); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1423,6 +1423,66 @@ void PixelShuffleGradInferMeta(const MetaTensor& out_grad, | |
x_grad->set_dtype(out_grad.dtype()); | ||
} | ||
|
||
void PixelUnshuffleInferMeta(const MetaTensor& x, | ||
int downscale_factor, | ||
const std::string& data_format, | ||
MetaTensor* out) { | ||
auto input_dims = x.dims(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 增加对downscale_factor的检查、对输入format的检查 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 完成 |
||
PADDLE_ENFORCE_EQ(input_dims.size(), | ||
4, | ||
phi::errors::InvalidArgument( | ||
"Input should be a 4-D tensor of format [N, C, H, W] " | ||
"or [N, H, W, C], but got %u.", | ||
input_dims.size())); | ||
PADDLE_ENFORCE_GE(downscale_factor, | ||
1, | ||
phi::errors::InvalidArgument( | ||
"downscale_factor should be larger than 0.")); | ||
PADDLE_ENFORCE_EQ(data_format == "NCHW" || data_format == "NHWC", | ||
true, | ||
phi::errors::InvalidArgument( | ||
"data_format must be one of " | ||
"NCHW and NHWC. But recevied data_format: %s", | ||
data_format)); | ||
|
||
const bool channel_last = (data_format == "NHWC"); | ||
|
||
if (!channel_last) { | ||
PADDLE_ENFORCE_EQ( | ||
(input_dims[2] % downscale_factor) == 0 && | ||
(input_dims[3] % downscale_factor) == 0, | ||
true, | ||
phi::errors::InvalidArgument("Downscale factor[%u] should divide both " | ||
"height[%u] and width[%u]", | ||
downscale_factor, | ||
input_dims[2], | ||
input_dims[3])); | ||
} else { | ||
PADDLE_ENFORCE_EQ( | ||
(input_dims[1] % downscale_factor) == 0 && | ||
(input_dims[2] % downscale_factor) == 0, | ||
true, | ||
phi::errors::InvalidArgument("Downscale factor[%u] should divide both " | ||
"height[%u] and width[%u]", | ||
downscale_factor, | ||
input_dims[1], | ||
input_dims[2])); | ||
} | ||
auto output_dims = input_dims; | ||
output_dims[0] = input_dims[0]; | ||
if (!channel_last) { | ||
output_dims[1] = input_dims[1] * (downscale_factor * downscale_factor); | ||
output_dims[2] = input_dims[2] / downscale_factor; | ||
output_dims[3] = input_dims[3] / downscale_factor; | ||
} else { | ||
output_dims[1] = input_dims[1] / downscale_factor; | ||
output_dims[2] = input_dims[2] / downscale_factor; | ||
output_dims[3] = input_dims[3] * (downscale_factor * downscale_factor); | ||
} | ||
out->set_dtype(x.dtype()); | ||
out->set_dims(output_dims); | ||
} | ||
|
||
void PNormInferMeta(const MetaTensor& x, | ||
float porder, | ||
int axis, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" | ||
#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
PD_REGISTER_KERNEL(pixel_unshuffle_grad, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::PixelUnshuffleGradKernel, | ||
float, | ||
double) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" | ||
#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
PD_REGISTER_KERNEL(pixel_unshuffle, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::PixelUnshuffleKernel, | ||
float, | ||
double) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" | ||
#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
PD_REGISTER_KERNEL(pixel_unshuffle_grad, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::PixelUnshuffleGradKernel, | ||
float, | ||
double) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" | ||
#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
PD_REGISTER_KERNEL(pixel_unshuffle, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::PixelUnshuffleKernel, | ||
float, | ||
double) {} |
58 changes: 58 additions & 0 deletions
58
paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "paddle/phi/core/dense_tensor.h" | ||
#include "paddle/phi/kernels/funcs/math_function.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void PixelUnshuffleGradKernel(const Context& dev_ctx, | ||
const DenseTensor& out_grad, | ||
int downscale_factor, | ||
const std::string& data_format, | ||
DenseTensor* x_grad) { | ||
auto* dout = &out_grad; | ||
auto* dx = x_grad; | ||
dev_ctx.template Alloc<T>(dx); | ||
int factor = downscale_factor; | ||
bool channel_last = (data_format == "NHWC"); | ||
auto do_dims = dout->dims(); | ||
auto dx_dims = dx->dims(); | ||
|
||
DenseTensor t(*dout); | ||
if (!channel_last) { | ||
t.Resize({do_dims[0], dx_dims[1], factor, factor, do_dims[2], do_dims[3]}); | ||
} else { | ||
t.Resize({do_dims[0], do_dims[1], do_dims[2], dx_dims[3], factor, factor}); | ||
} | ||
std::vector<int> axis = {0, 1, 4, 2, 5, 3}; | ||
|
||
DenseTensor o(*dx); | ||
if (!channel_last) { | ||
o.Resize({do_dims[0], dx_dims[1], do_dims[2], factor, do_dims[3], factor}); | ||
} else { | ||
o.Resize({do_dims[0], do_dims[1], factor, do_dims[2], factor, dx_dims[3]}); | ||
} | ||
phi::funcs::Transpose<Context, T, 6> trans; | ||
trans(dev_ctx, t, &o, axis); | ||
dx->Resize(dx_dims); | ||
} | ||
|
||
} // namespace phi |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apply 函数增加 protected 关键字限制访问。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
完成