Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 2】24、为 Paddle 新增 nn.ChannelShuffle 组网 API #40743

Merged
merged 34 commits into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
d7ffa2c
Add infermeta for ChannelShuffle
BrilliantYuKaimin Mar 17, 2022
b1f6915
Create channel_shuffle_grad_kernel.h
BrilliantYuKaimin Mar 17, 2022
6ca6e5c
Create channel_shuffle_kernel.h
BrilliantYuKaimin Mar 17, 2022
9b5fa40
Create channel_shuffle_sig.cc
BrilliantYuKaimin Mar 17, 2022
7d77d46
Create channel_shuffle_op.cc
BrilliantYuKaimin Mar 17, 2022
4c34c10
Create channel_shuffle_kernel_impl.h
BrilliantYuKaimin Mar 17, 2022
be4df84
Create channel_shuffle_grad_kernel_impl.h
BrilliantYuKaimin Mar 17, 2022
51829fc
Add kernel register of channel shuffle and grad
BrilliantYuKaimin Mar 17, 2022
25feafb
add nn.functional.channel_shuffle
BrilliantYuKaimin Mar 17, 2022
2f1a958
add nn.ChannelShuffle
BrilliantYuKaimin Mar 17, 2022
d515b55
Create test_channel_shuffle.py
BrilliantYuKaimin Mar 17, 2022
fa164be
Update example of ChannelShuffle in vision.py
BrilliantYuKaimin Mar 17, 2022
1a29a72
Update test_channel_shuffle.py
BrilliantYuKaimin Mar 20, 2022
cc32215
修改channel_shuffle核函数的实现位置
BrilliantYuKaimin Mar 22, 2022
0c9dd64
修正代码格式
BrilliantYuKaimin Mar 23, 2022
cd6fe41
删除多余空格
BrilliantYuKaimin Mar 23, 2022
1cfdfd3
完善channel_shuffle的错误检查
BrilliantYuKaimin Mar 30, 2022
5cc340d
Update unary.cc
BrilliantYuKaimin Mar 30, 2022
88f6e2b
Update channel_shuffle_op.cc
BrilliantYuKaimin Mar 30, 2022
b848e45
Update test_channel_shuffle.py
BrilliantYuKaimin Mar 30, 2022
e927fc4
Update unary.cc
BrilliantYuKaimin Apr 1, 2022
c0d5651
add channel_shuffle
BrilliantYuKaimin Apr 1, 2022
3a7d322
Update test_channel_shuffle.py
BrilliantYuKaimin Apr 1, 2022
56f7951
Update vision.py
BrilliantYuKaimin Apr 1, 2022
048ef2b
调整代码格式
BrilliantYuKaimin Apr 8, 2022
76feb34
Merge branch 'PaddlePaddle:develop' into channel_shuffle
BrilliantYuKaimin Apr 18, 2022
11b3b03
Update channel_shuffle_sig.cc
BrilliantYuKaimin Apr 18, 2022
2362233
更新ChannelShuffle的文档
BrilliantYuKaimin Apr 19, 2022
dd9be7f
更新channel_shuffle的文档
BrilliantYuKaimin Apr 19, 2022
c4c7862
Merge branch 'PaddlePaddle:develop' into channel_shuffle
BrilliantYuKaimin Apr 21, 2022
d7ae774
remove ChannelShuffleOpArgumentMapping
BrilliantYuKaimin Apr 21, 2022
dbb8fd9
add ChannelShuffleGradInferMeta
BrilliantYuKaimin Apr 21, 2022
37d4a5e
Update channel_shuffle_op.cc
BrilliantYuKaimin Apr 21, 2022
e29fb30
调整channel_shuffle及其梯度的核函数的位置
BrilliantYuKaimin Apr 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 20 additions & 33 deletions paddle/fluid/operators/channel_shuffle_op.cc
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按照这个修改下吧~
image

//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
Expand Down Expand Up @@ -62,25 +63,6 @@ class ChannelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
class ChannelShuffleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound("Input(Out@Grad) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound("Output(X@Grad) should not be null"));

auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, "
"H, W] or [N, H, W, C], but got %u.",
do_dims.size()));

auto dx_dims = do_dims;
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
}
};

template <typename T>
Expand Down Expand Up @@ -110,4 +92,9 @@ REGISTER_OPERATOR(channel_shuffle, ops::ChannelShuffleOp,
ops::ChannelShuffleGradOpMaker<paddle::imperative::OpBase>,
ChannelShuffleInferShapeFunctor);

REGISTER_OPERATOR(channel_shuffle_grad, ops::ChannelShuffleGradOp);
DECLARE_INFER_SHAPE_FUNCTOR(channel_shuffle_grad,
ChannelShuffleGradInferShapeFunctor,
PD_INFER_META(phi::ChannelShuffleGradInferMeta));

REGISTER_OPERATOR(channel_shuffle_grad, ops::ChannelShuffleGradOp,
ChannelShuffleGradInferShapeFunctor);
16 changes: 16 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x,
}
}

void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
int groups,
const std::string& data_format,
MetaTensor* x_grad) {
auto do_dims = out_grad.dims();
PADDLE_ENFORCE_EQ(do_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
do_dims.size()));
auto dx_dims = do_dims;
x_grad->set_dims(dx_dims);
x_grad->set_dtype(out_grad.dtype());
}

void ConvTransposeGradInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const MetaTensor& dout,
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ void BilinearTensorProductGradInferMeta(const MetaTensor& x,
MetaTensor* dweight,
MetaTensor* dbias);

void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
int groups,
const std::string& data_format,
MetaTensor* x_grad);

void ConvTransposeGradInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const MetaTensor& dout,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/channel_shuffle_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
namespace phi {

template <typename T, typename Context>
void ChannelShuffleGradKernel(const Context& ctx,
void ChannelShuffleGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int groups,
const std::string& data_format,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/channel_shuffle_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
namespace phi {

template <typename T, typename Context>
void ChannelShuffleKernel(const Context& ctx,
void ChannelShuffleKernel(const Context& dev_ctx,
const DenseTensor& x,
int groups,
const std::string& data_format,
Expand Down
26 changes: 26 additions & 0 deletions paddle/phi/kernels/cpu/channel_shuffle_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/channel_shuffle_grad_kernel.h"
#include "paddle/phi/kernels/impl/channel_shuffle_grad_kernel_impl.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

PD_REGISTER_KERNEL(channel_shuffle_grad,
CPU,
ALL_LAYOUT,
phi::ChannelShuffleGradKernel,
float,
double) {}
26 changes: 26 additions & 0 deletions paddle/phi/kernels/cpu/channel_shuffle_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/channel_shuffle_kernel.h"
#include "paddle/phi/kernels/impl/channel_shuffle_kernel_impl.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

PD_REGISTER_KERNEL(channel_shuffle,
CPU,
ALL_LAYOUT,
phi::ChannelShuffleKernel,
float,
double) {}
26 changes: 26 additions & 0 deletions paddle/phi/kernels/gpu/channel_shuffle_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/channel_shuffle_grad_kernel.h"
#include "paddle/phi/kernels/impl/channel_shuffle_grad_kernel_impl.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

PD_REGISTER_KERNEL(channel_shuffle_grad,
GPU,
ALL_LAYOUT,
phi::ChannelShuffleGradKernel,
float,
double) {}
26 changes: 26 additions & 0 deletions paddle/phi/kernels/gpu/channel_shuffle_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/channel_shuffle_kernel.h"
#include "paddle/phi/kernels/impl/channel_shuffle_kernel_impl.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

PD_REGISTER_KERNEL(channel_shuffle,
GPU,
ALL_LAYOUT,
phi::ChannelShuffleKernel,
float,
double) {}
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/channel_shuffle_grad_kernel.h"
#pragma once

#include <string>
#include <vector>
#include "paddle/phi/backends/all_context.h"

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename T, typename Context>
void ChannelShuffleGradKernel(const Context& ctx,
void ChannelShuffleGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int groups,
const std::string& data_format,
DenseTensor* x_grad) {
auto* dout = &out_grad;
auto* dx = x_grad;
ctx.template Alloc<T>(dx);
dev_ctx.template Alloc<T>(dx);
bool channel_last = (data_format == "NHWC");
auto do_dims = dout->dims();
auto dx_dims = dx->dims();
Expand All @@ -51,24 +51,8 @@ void ChannelShuffleGradKernel(const Context& ctx,
o.Resize({dx_dims[0], dx_dims[1], dx_dims[2], groups, dx_dims[3] / groups});
}
phi::funcs::Transpose<Context, T, 5> trans;
trans(ctx, t, &o, axis);
trans(dev_ctx, t, &o, axis);
dx->Resize(dx_dims);
}

} // namespace phi

PD_REGISTER_KERNEL(channel_shuffle_grad,
CPU,
ALL_LAYOUT,
phi::ChannelShuffleGradKernel,
float,
double) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(channel_shuffle_grad,
GPU,
ALL_LAYOUT,
phi::ChannelShuffleGradKernel,
float,
double) {}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/channel_shuffle_kernel.h"
#pragma once

#include <string>
#include <vector>
#include "paddle/phi/backends/all_context.h"

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename T, typename Context>
void ChannelShuffleKernel(const Context& ctx,
void ChannelShuffleKernel(const Context& dev_ctx,
const DenseTensor& x,
int groups,
const std::string& data_format,
DenseTensor* out) {
auto* in = &x;
ctx.template Alloc<T>(out);
dev_ctx.template Alloc<T>(out);
bool channel_last = (data_format == "NHWC");
auto in_dims = in->dims();
auto o_dims = out->dims();
Expand All @@ -50,24 +50,8 @@ void ChannelShuffleKernel(const Context& ctx,
o.Resize({in_dims[0], in_dims[1], in_dims[2], in_dims[3] / groups, groups});
}
phi::funcs::Transpose<Context, T, 5> trans;
trans(ctx, t, &o, axis);
trans(dev_ctx, t, &o, axis);
out->Resize(o_dims);
}

} // namespace phi

PD_REGISTER_KERNEL(channel_shuffle,
CPU,
ALL_LAYOUT,
phi::ChannelShuffleKernel,
float,
double) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(channel_shuffle,
GPU,
ALL_LAYOUT,
phi::ChannelShuffleKernel,
float,
double) {}
#endif
8 changes: 0 additions & 8 deletions paddle/phi/ops/compat/channel_shuffle_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@

namespace phi {

KernelSignature ChannelShuffleOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"channel_shuffle", {"X"}, {"groups", "data_format"}, {"Out"});
}

KernelSignature ChannelShuffleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("channel_shuffle_grad",
Expand All @@ -32,7 +26,5 @@ KernelSignature ChannelShuffleGradOpArgumentMapping(

} // namespace phi

PD_REGISTER_ARG_MAPPING_FN(channel_shuffle,
phi::ChannelShuffleOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(channel_shuffle_grad,
phi::ChannelShuffleGradOpArgumentMapping);