Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 2】24、为 Paddle 新增 nn.ChannelShuffle 组网 API #40743

Merged
merged 34 commits into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
d7ffa2c
Add infermeta for ChannelShuffle
BrilliantYuKaimin Mar 17, 2022
b1f6915
Create channel_shuffle_grad_kernel.h
BrilliantYuKaimin Mar 17, 2022
6ca6e5c
Create channel_shuffle_kernel.h
BrilliantYuKaimin Mar 17, 2022
9b5fa40
Create channel_shuffle_sig.cc
BrilliantYuKaimin Mar 17, 2022
7d77d46
Create channel_shuffle_op.cc
BrilliantYuKaimin Mar 17, 2022
4c34c10
Create channel_shuffle_kernel_impl.h
BrilliantYuKaimin Mar 17, 2022
be4df84
Create channel_shuffle_grad_kernel_impl.h
BrilliantYuKaimin Mar 17, 2022
51829fc
Add kernel register of channel shuffle and grad
BrilliantYuKaimin Mar 17, 2022
25feafb
add nn.functional.channel_shuffle
BrilliantYuKaimin Mar 17, 2022
2f1a958
add nn.ChannelShuffle
BrilliantYuKaimin Mar 17, 2022
d515b55
Create test_channel_shuffle.py
BrilliantYuKaimin Mar 17, 2022
fa164be
Update example of ChannelShuffle in vision.py
BrilliantYuKaimin Mar 17, 2022
1a29a72
Update test_channel_shuffle.py
BrilliantYuKaimin Mar 20, 2022
cc32215
修改channel_shuffle核函数的实现位置
BrilliantYuKaimin Mar 22, 2022
0c9dd64
修正代码格式
BrilliantYuKaimin Mar 23, 2022
cd6fe41
删除多余空格
BrilliantYuKaimin Mar 23, 2022
1cfdfd3
完善channel_shuffle的错误检查
BrilliantYuKaimin Mar 30, 2022
5cc340d
Update unary.cc
BrilliantYuKaimin Mar 30, 2022
88f6e2b
Update channel_shuffle_op.cc
BrilliantYuKaimin Mar 30, 2022
b848e45
Update test_channel_shuffle.py
BrilliantYuKaimin Mar 30, 2022
e927fc4
Update unary.cc
BrilliantYuKaimin Apr 1, 2022
c0d5651
add channel_shuffle
BrilliantYuKaimin Apr 1, 2022
3a7d322
Update test_channel_shuffle.py
BrilliantYuKaimin Apr 1, 2022
56f7951
Update vision.py
BrilliantYuKaimin Apr 1, 2022
048ef2b
调整代码格式
BrilliantYuKaimin Apr 8, 2022
76feb34
Merge branch 'PaddlePaddle:develop' into channel_shuffle
BrilliantYuKaimin Apr 18, 2022
11b3b03
Update channel_shuffle_sig.cc
BrilliantYuKaimin Apr 18, 2022
2362233
更新ChannelShuffle的文档
BrilliantYuKaimin Apr 19, 2022
dd9be7f
更新channel_shuffle的文档
BrilliantYuKaimin Apr 19, 2022
c4c7862
Merge branch 'PaddlePaddle:develop' into channel_shuffle
BrilliantYuKaimin Apr 21, 2022
d7ae774
remove ChannelShuffleOpArgumentMapping
BrilliantYuKaimin Apr 21, 2022
dbb8fd9
add ChannelShuffleGradInferMeta
BrilliantYuKaimin Apr 21, 2022
37d4a5e
Update channel_shuffle_op.cc
BrilliantYuKaimin Apr 21, 2022
e29fb30
调整channel_shuffle及其梯度的核函数的位置
BrilliantYuKaimin Apr 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions paddle/fluid/operators/channel_shuffle_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

class ChannelShuffleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};

class ChannelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), "
"the input feature data of ChannelShuffleOp, the layout is "
"[N, C, H, W] or [N, H, W, C].");
AddOutput("Out",
"(Tensor, default Tensor<float>), the output of "
"ChannelShuffleOp. The layout is also [N, C, "
"H, W] or [N, H, W, C].");
AddAttr<int>("groups", "number of groups to divide channels in.");
AddAttr<std::string>(
"data_format",
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\", Specify the data format of the input data.")
.SetDefault("NCHW");

AddComment(R"DOC(
Channel Shuffle operator
This operator divides channels in a tensor of shape :math:`(*, C, H, W)`
into :math:`g` groups and rearranges them as :math:`(*, C/g, g, H, W)`
while keeping the original tensor shape.

Please refer to the paper:
`ShuffleNet: An Extremely Efficient Convolutional Neural Network for
Mobile Devices <https://arxiv.org/abs/1707.01083>`_
by Zhang et. al (2017) for more details.

)DOC");
}
};

class ChannelShuffleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound("Input(Out@Grad) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound("Output(X@Grad) should not be null"));

auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, "
"H, W] or [N, H, W, C], but got %u.",
do_dims.size()));

auto dx_dims = do_dims;
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

放入infermeta

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我看其他算子也都是只把前向的形状推断放在infermeta中,而反向的形状推断放在xxx_op.cc中。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以放到infermeta里,参考下这个PR的review
#40728

};

template <typename T>
class ChannelShuffleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("channel_shuffle_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(channel_shuffle, ChannelShuffleInferShapeFunctor,
PD_INFER_META(phi::ChannelShuffleInferMeta));

REGISTER_OPERATOR(channel_shuffle, ops::ChannelShuffleOp,
ops::ChannelShuffleOpMaker,
ops::ChannelShuffleGradOpMaker<paddle::framework::OpDesc>,
ops::ChannelShuffleGradOpMaker<paddle::imperative::OpBase>,
ChannelShuffleInferShapeFunctor);

REGISTER_OPERATOR(channel_shuffle_grad, ops::ChannelShuffleGradOp);
46 changes: 46 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3000,6 +3000,52 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
out->set_dtype(DataType::INT64);
}

void ChannelShuffleInferMeta(const MetaTensor& x,
int groups,
const std::string& data_format,
MetaTensor* out) {
auto input_dims = x.dims();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加对groups,dataformat的检查

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));
PADDLE_ENFORCE_GE(
groups,
1,
phi::errors::InvalidArgument("groups should be larger than 0."));
PADDLE_ENFORCE_EQ(data_format == "NCHW" || data_format == "NHWC",
true,
phi::errors::InvalidArgument(
"data_format must be one of "
"NCHW and NHWC. But recevied data_format: %s",
data_format));

const bool channel_last = (data_format == "NHWC");

if (!channel_last) {
PADDLE_ENFORCE_EQ(input_dims[1] % groups,
0,
phi::errors::InvalidArgument(
"The number of groups to divide channels in [%u] "
"should divide the number of channel [%u]",
groups,
input_dims[1]));
} else {
PADDLE_ENFORCE_EQ(input_dims[3] % groups,
0,
phi::errors::InvalidArgument(
"The number of groups to divide channels in [%u] "
"should divide the number of channel [%u]",
groups,
input_dims[3]));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

报错信息可以增加空格方便报错阅读

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是指把in[%u]channel[%u]改成in [%u]channel [%u]吗?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成。

}
auto output_dims = input_dims;
out->set_dtype(x.dtype());
out->set_dims(output_dims);
}

} // namespace phi

PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,4 +440,9 @@ void OneHotInferMeta(const MetaTensor& x, const Scalar& depth, MetaTensor* out);

void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out);

void ChannelShuffleInferMeta(const MetaTensor& x,
int groups,
const std::string& data_format,
MetaTensor* out);

} // namespace phi
74 changes: 74 additions & 0 deletions paddle/phi/kernels/channel_shuffle_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考:#40728 (comment) 进行同理修改。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/channel_shuffle_grad_kernel.h"
#include <string>
#include <vector>
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename T, typename Context>
void ChannelShuffleGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int groups,
const std::string& data_format,
DenseTensor* x_grad) {
auto* dout = &out_grad;
auto* dx = x_grad;
ctx.template Alloc<T>(dx);
bool channel_last = (data_format == "NHWC");
auto do_dims = dout->dims();
auto dx_dims = dx->dims();

DenseTensor t(*dout);
if (!channel_last) {
t.Resize({do_dims[0], do_dims[1] / groups, groups, do_dims[2], do_dims[3]});
} else {
t.Resize({do_dims[0], do_dims[1], do_dims[2], do_dims[3] / groups, groups});
}
auto axis = !channel_last ? std::vector<int>{0, 2, 1, 3, 4}
: std::vector<int>{0, 1, 2, 4, 3};

DenseTensor o(*dx);
if (!channel_last) {
o.Resize({dx_dims[0], groups, dx_dims[1] / groups, dx_dims[2], dx_dims[3]});
} else {
o.Resize({dx_dims[0], dx_dims[1], dx_dims[2], groups, dx_dims[3] / groups});
}
phi::funcs::Transpose<Context, T, 5> trans;
trans(ctx, t, &o, axis);
dx->Resize(dx_dims);
}

} // namespace phi

PD_REGISTER_KERNEL(channel_shuffle_grad,
CPU,
ALL_LAYOUT,
phi::ChannelShuffleGradKernel,
float,
double) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(channel_shuffle_grad,
GPU,
ALL_LAYOUT,
phi::ChannelShuffleGradKernel,
float,
double) {}
#endif
29 changes: 29 additions & 0 deletions paddle/phi/kernels/channel_shuffle_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void ChannelShuffleGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int groups,
const std::string& data_format,
DenseTensor* x_grad);

} // namespace phi
73 changes: 73 additions & 0 deletions paddle/phi/kernels/channel_shuffle_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考:#40728 (comment) 进行同理修改。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完成

//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/channel_shuffle_kernel.h"
#include <string>
#include <vector>
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename T, typename Context>
void ChannelShuffleKernel(const Context& ctx,
const DenseTensor& x,
int groups,
const std::string& data_format,
DenseTensor* out) {
auto* in = &x;
ctx.template Alloc<T>(out);
bool channel_last = (data_format == "NHWC");
auto in_dims = in->dims();
auto o_dims = out->dims();

DenseTensor t(*in);
if (!channel_last) {
t.Resize({in_dims[0], groups, in_dims[1] / groups, in_dims[2], in_dims[3]});
} else {
t.Resize({in_dims[0], in_dims[1], in_dims[2], groups, in_dims[3] / groups});
}
auto axis = !channel_last ? std::vector<int>{0, 2, 1, 3, 4}
: std::vector<int>{0, 1, 2, 4, 3};

DenseTensor o(*out);
if (!channel_last) {
o.Resize({in_dims[0], in_dims[1] / groups, groups, in_dims[2], in_dims[3]});
} else {
o.Resize({in_dims[0], in_dims[1], in_dims[2], in_dims[3] / groups, groups});
}
phi::funcs::Transpose<Context, T, 5> trans;
trans(ctx, t, &o, axis);
out->Resize(o_dims);
}

} // namespace phi

PD_REGISTER_KERNEL(channel_shuffle,
CPU,
ALL_LAYOUT,
phi::ChannelShuffleKernel,
float,
double) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(channel_shuffle,
GPU,
ALL_LAYOUT,
phi::ChannelShuffleKernel,
float,
double) {}
#endif
29 changes: 29 additions & 0 deletions paddle/phi/kernels/channel_shuffle_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void ChannelShuffleKernel(const Context& ctx,
const DenseTensor& x,
int groups,
const std::string& data_format,
DenseTensor* out);

} // namespace phi
Loading