Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Api2.0] add pixel shuffle class #26071

Merged
merged 19 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 65 additions & 30 deletions paddle/fluid/operators/pixel_shuffle_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,40 +28,59 @@ class PixelShuffleOp : public framework::OperatorWithKernel {
"Output(Out) of PixelShuffleOp should not be null."));

auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
input_dims.size()));
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));

auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");

PADDLE_ENFORCE_EQ(input_dims[1] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
input_dims[1], upscale_factor * upscale_factor));

const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC");

if (!channel_last) {
PADDLE_ENFORCE_EQ(
input_dims[1] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
input_dims[1], upscale_factor * upscale_factor));
} else {
PADDLE_ENFORCE_EQ(
input_dims[3] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
input_dims[3], upscale_factor * upscale_factor));
}
auto output_dims = input_dims;
output_dims[0] = input_dims[0];
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] * upscale_factor;
if (!channel_last) {
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] * upscale_factor;
} else {
output_dims[1] = input_dims[1] * upscale_factor;
output_dims[2] = input_dims[2] * upscale_factor;
output_dims[3] = input_dims[3] / (upscale_factor * upscale_factor);
}
ctx->SetOutputDim("Out", output_dims);
}
};

class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor, default Tensor<float>), "
"the input feature data of PixelShuffleOp, the layout is [N C H W].");
AddOutput(
"Out",
"(Tensor, default Tensor<float>), the output of "
"PixelShuffleOp. The layout is [N,C/factor^2,H*factor,W*factor].");
AddInput("X",
"(Tensor, default Tensor<float>), "
"the input feature data of PixelShuffleOp, the layout is [N, C, "
"H, W] or [N, H, W, C].");
AddOutput("Out",
"(Tensor, default Tensor<float>), the output of "
"PixelShuffleOp. The layout is [N, C/factor^2, H*factor, "
"W*factor] or [N, H*factor, W*factor, C/factor^2].");
AddAttr<int>("upscale_factor",
"the factor to increase spatial resolution by.")
.SetDefault(1)
Expand All @@ -70,6 +89,11 @@ class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
platform::errors::InvalidArgument(
"upscale_factor should be larger than 0."));
});
AddAttr<std::string>(
"data_format",
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\", Specify the data format of the input data.")
.SetDefault("NCHW");

AddComment(R"DOC(
Pixel Shuffle operator
Expand Down Expand Up @@ -114,19 +138,30 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
platform::errors::NotFound("Output(X@Grad) should not be null"));

auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(
do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
do_dims.size()));
PADDLE_ENFORCE_EQ(do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
do_dims.size()));

auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");

const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC");

auto dx_dims = do_dims;
dx_dims[0] = do_dims[0];
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] / upscale_factor;

if (!channel_last) {
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] / upscale_factor;
} else {
dx_dims[1] = do_dims[1] / upscale_factor;
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] * (upscale_factor * upscale_factor);
}
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
}
};
Expand Down
40 changes: 32 additions & 8 deletions paddle/fluid/operators/pixel_shuffle_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ limitations under the License. */

#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
Expand All @@ -24,23 +25,33 @@ class PixelShuffleOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");

out->mutable_data<T>(ctx.GetPlace());

int factor = ctx.Attr<int>("upscale_factor");

std::string data_format = ctx.Attr<std::string>("data_format");
bool channel_last = (data_format == "NHWC");

auto in_dims = in->dims();
auto o_dims = out->dims();

framework::Tensor t;
t.ShareDataWith(*in);
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});

if (!channel_last) {
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});
} else {
t.Resize({in_dims[0], in_dims[1], in_dims[2], o_dims[3], factor, factor});
}
std::vector<int> axis = {0, 1, 4, 2, 5, 3};

framework::Tensor o;
o.ShareDataWith(*out);
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});

if (!channel_last) {
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});
} else {
o.Resize({in_dims[0], in_dims[1], factor, in_dims[2], factor, o_dims[3]});
}
math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
Expand All @@ -58,19 +69,32 @@ class PixelShuffleGradOpKernel : public framework::OpKernel<T> {

int factor = ctx.Attr<int>("upscale_factor");

std::string data_format = ctx.Attr<std::string>("data_format");
bool channel_last = (data_format == "NHWC");

auto do_dims = dout->dims();
auto dx_dims = dx->dims();

framework::Tensor t;
t.ShareDataWith(*dout);
t.Resize({do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});

if (!channel_last) {
t.Resize(
{do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});
} else {
t.Resize(
{do_dims[0], dx_dims[1], factor, dx_dims[2], factor, do_dims[3]});
}
std::vector<int> axis = {0, 1, 3, 5, 2, 4};

framework::Tensor o;
o.ShareDataWith(*dx);
o.Resize({do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});

if (!channel_last) {
o.Resize(
{do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});
} else {
o.Resize(
{do_dims[0], dx_dims[1], dx_dims[2], do_dims[3], factor, factor});
}
math::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
Expand Down
78 changes: 70 additions & 8 deletions python/paddle/fluid/tests/unittests/test_pixel_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@
import unittest
import numpy as np
from op_test import OpTest
import paddle


class TestPixelShuffle(OpTest):
def setUp(self):
self.op_type = "pixel_shuffle"
n, c, h, w = 2, 9, 4, 4
up_factor = 3
shape = [n, c, h, w]
x = np.random.random(shape).astype("float64")
def pixel_shuffle_np(x, up_factor, data_format="NCHW"):
if data_format == "NCHW":
n, c, h, w = x.shape
new_shape = (n, c // (up_factor * up_factor), up_factor, up_factor, h,
w)
# reshape to (num,output_channel,upscale_factor,upscale_factor,h,w)
Expand All @@ -34,10 +31,42 @@ def setUp(self):
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
oshape = [n, c // (up_factor * up_factor), h * up_factor, w * up_factor]
npresult = np.reshape(npresult, oshape)
return npresult
else:
n, h, w, c = x.shape
new_shape = (n, h, w, c // (up_factor * up_factor), up_factor,
up_factor)
# reshape to (num,h,w,output_channel,upscale_factor,upscale_factor)
npresult = np.reshape(x, new_shape)
# transpose to (num,h,upscale_factor,w,upscale_factor,output_channel)
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
oshape = [n, h * up_factor, w * up_factor, c // (up_factor * up_factor)]
npresult = np.reshape(npresult, oshape)
return npresult


class TestPixelShuffle(OpTest):
def setUp(self):
self.op_type = "pixel_shuffle"
self.init_data_format()
n, c, h, w = 2, 9, 4, 4

if self.format == "NCHW":
shape = [n, c, h, w]
if self.format == "NHWC":
shape = [n, h, w, c]

up_factor = 3

x = np.random.random(shape).astype("float64")
npresult = pixel_shuffle_np(x, up_factor, self.format)

self.inputs = {'X': x}
self.outputs = {'Out': npresult}
self.attrs = {'upscale_factor': up_factor}
self.attrs = {'upscale_factor': up_factor, "data_format": self.format}

def init_data_format(self):
self.format = "NCHW"

def test_check_output(self):
self.check_output()
Expand All @@ -46,5 +75,38 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestChannelLast(TestPixelShuffle):
def init_data_format(self):
self.format = "NHWC"


class TestPixelShuffleDygraph(unittest.TestCase):
def run_pixel_shuffle(self, up_factor, data_format):

n, c, h, w = 2, 9, 4, 4

if data_format == "NCHW":
shape = [n, c, h, w]
if data_format == "NHWC":
shape = [n, h, w, c]

x = np.random.random(shape).astype("float64")

npresult = pixel_shuffle_np(x, up_factor, data_format)

paddle.disable_static()
pixel_shuffle = paddle.nn.PixelShuffle(
up_factor, data_format=data_format)
result = pixel_shuffle(paddle.to_tensor(x))

self.assertTrue(np.allclose(result.numpy(), npresult))

def test_pixel_shuffle(self):
self.run_pixel_shuffle(3, "NCHW")

def test_channel_last(self):
self.run_pixel_shuffle(3, "NHWC")


if __name__ == '__main__':
unittest.main()
3 changes: 3 additions & 0 deletions python/paddle/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@
# from .layer.rnn import LSTMCell #DEFINE_ALIAS
from .layer.distance import PairwiseDistance #DEFINE_ALIAS

from .layer.vision import PixelShuffle

from .layer import loss #DEFINE_ALIAS
from .layer import conv #DEFINE_ALIAS
from .layer import vision #DEFINE_ALIAS
from ..fluid.dygraph.layers import Layer #DEFINE_ALIAS
from ..fluid.dygraph.container import LayerList, ParameterList, Sequential #DEFINE_ALIAS
4 changes: 2 additions & 2 deletions python/paddle/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@
from .vision import box_coder #DEFINE_ALIAS
from .vision import box_decoder_and_assign #DEFINE_ALIAS
from .vision import collect_fpn_proposals #DEFINE_ALIAS
# from .vision import deformable_conv #DEFINE_ALIAS
# from .vision import deformable_conv #DEFINE_ALIAS
from .vision import deformable_roi_pooling #DEFINE_ALIAS
from .vision import density_prior_box #DEFINE_ALIAS
from .vision import detection_output #DEFINE_ALIAS
Expand All @@ -179,7 +179,7 @@
from .vision import grid_sampler #DEFINE_ALIAS
from .vision import image_resize #DEFINE_ALIAS
from .vision import image_resize_short #DEFINE_ALIAS
# from .vision import multi_box_head #DEFINE_ALIAS
# from .vision import multi_box_head #DEFINE_ALIAS
from .vision import pixel_shuffle #DEFINE_ALIAS
from .vision import prior_box #DEFINE_ALIAS
from .vision import prroi_pool #DEFINE_ALIAS
Expand Down
Loading