Skip to content

Commit

Permalink
Add conj pixel shuffle yaml (#41499)
Browse files Browse the repository at this point in the history
* ad conj flip yaml

* add flip conj pixel shuffle
  • Loading branch information
phlrain authored Apr 8, 2022
1 parent 9844aaf commit bc88fbb
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 44 deletions.
42 changes: 5 additions & 37 deletions paddle/fluid/operators/pixel_shuffle_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,42 +82,6 @@ class PixelShuffleGradMaker : public framework::SingleGradOpMaker<T> {
class PixelShuffleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound("Input(Out@Grad) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound("Output(X@Grad) should not be null"));

auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
do_dims.size()));

auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");

const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC");

auto dx_dims = do_dims;
dx_dims[0] = do_dims[0];

if (!channel_last) {
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] / upscale_factor;
} else {
dx_dims[1] = do_dims[1] / upscale_factor;
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] * (upscale_factor * upscale_factor);
}
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
}
};

} // namespace operators
Expand All @@ -132,7 +96,11 @@ REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker,
ops::PixelShuffleGradMaker<paddle::imperative::OpBase>,
PixelShuffleInferShapeFunctor);

REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp);
DECLARE_INFER_SHAPE_FUNCTOR(pixel_shuffle_grad,
PixelShuffleGradInferShapeFunctor,
PD_INFER_META(phi::PixelShuffleGradInferMeta));
REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp,
PixelShuffleGradInferShapeFunctor);

REGISTER_OP_VERSION(pixel_shuffle)
.AddCheckpoint(
Expand Down
30 changes: 30 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,36 @@ void PixelShuffleInferMeta(const MetaTensor& x,
out->set_dims(output_dims);
}

void PixelShuffleGradInferMeta(const MetaTensor& out_grad,
int upscale_factor,
const std::string& data_format,
MetaTensor* x_grad) {
auto do_dims = out_grad.dims();
PADDLE_ENFORCE_EQ(do_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
do_dims.size()));

const bool channel_last = (data_format == "NHWC");

auto dx_dims = do_dims;
dx_dims[0] = do_dims[0];

if (!channel_last) {
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] / upscale_factor;
} else {
dx_dims[1] = do_dims[1] / upscale_factor;
dx_dims[2] = do_dims[2] / upscale_factor;
dx_dims[3] = do_dims[3] * (upscale_factor * upscale_factor);
}
x_grad->set_dims(dx_dims);
x_grad->set_dtype(out_grad.dtype());
}

void PNormInferMeta(const MetaTensor& x,
float porder,
int axis,
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ void PixelShuffleInferMeta(const MetaTensor& x,
const std::string& data_format,
MetaTensor* out);

void PixelShuffleGradInferMeta(const MetaTensor& out_grad,
int upscale_factor,
const std::string& data_format,
MetaTensor* x_grad);

void PNormInferMeta(const MetaTensor& x,
float porder,
int axis,
Expand Down
6 changes: 4 additions & 2 deletions python/paddle/fluid/tests/unittests/test_conj_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
class TestConjOp(OpTest):
def setUp(self):
self.op_type = "conj"
self.python_api = paddle.tensor.conj
self.init_dtype_type()
self.init_input_output()
self.init_grad_input_output()
Expand All @@ -53,14 +54,15 @@ def init_grad_input_output(self):
self.grad_in = np.conj(self.grad_out)

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_check_grad_normal(self):
self.check_grad(
['X'],
'Out',
user_defined_grads=[self.grad_in],
user_defined_grad_outputs=[self.grad_out])
user_defined_grad_outputs=[self.grad_out],
check_eager=True)


class TestComplexConjOp(unittest.TestCase):
Expand Down
6 changes: 4 additions & 2 deletions python/paddle/fluid/tests/unittests/test_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_dygraph(self):
class TestFlipOp(OpTest):
def setUp(self):
self.op_type = 'flip'
self.python_api = paddle.tensor.flip
self.init_test_case()
self.inputs = {'X': np.random.random(self.in_shape).astype('float64')}
self.init_attrs()
Expand All @@ -76,10 +77,10 @@ def init_attrs(self):
self.attrs = {"axis": self.axis}

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(["X"], "Out")
self.check_grad(["X"], "Out", check_eager=True)

def init_test_case(self):
self.in_shape = (6, 4, 2, 3)
Expand Down Expand Up @@ -131,4 +132,5 @@ def init_test_case(self):


if __name__ == "__main__":
paddle.enable_static()
unittest.main()
6 changes: 4 additions & 2 deletions python/paddle/fluid/tests/unittests/test_pixel_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def pixel_shuffle_np(x, up_factor, data_format="NCHW"):
class TestPixelShuffleOp(OpTest):
def setUp(self):
self.op_type = "pixel_shuffle"
self.python_api = paddle.nn.functional.pixel_shuffle
self.init_data_format()
n, c, h, w = 2, 9, 4, 4

Expand All @@ -73,10 +74,10 @@ def init_data_format(self):
self.format = "NCHW"

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)


class TestChannelLast(TestPixelShuffleOp):
Expand Down Expand Up @@ -220,4 +221,5 @@ def error_data_format_layer():


if __name__ == '__main__':
paddle.enable_static()
unittest.main()
4 changes: 4 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,10 @@ def flip(x, axis, name=None):
"""
if isinstance(axis, int):
axis = [axis]

if in_dygraph_mode():
return _C_ops.final_state_flip(x, axis)

if paddle.in_dynamic_mode():
return _C_ops.flip(x, "axis", axis)

Expand Down
3 changes: 3 additions & 0 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3349,6 +3349,9 @@ def conj(x, name=None):
# [(4-4j), (5-5j), (6-6j)]])
"""
if in_dygraph_mode():
return _C_ops.final_state_conj(x)

if paddle.in_dynamic_mode():
return _C_ops.conj(x)

Expand Down
4 changes: 3 additions & 1 deletion python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@
func : UnchangedInferMeta
kernel :
func : conj
backward : conj_grad

- api : conv2d
args : (Tensor input, Tensor filter, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
Expand Down Expand Up @@ -659,6 +660,7 @@
func : FlipInferMeta
kernel :
func : flip
backward : flip_grad

- api : floor
args : (Tensor x)
Expand Down Expand Up @@ -1430,7 +1432,7 @@
func : PixelShuffleInferMeta
kernel :
func : pixel_shuffle
# backward : pixel_shuffle_grad
backward : pixel_shuffle_grad

# poisson // no need grad
- api : poisson
Expand Down
29 changes: 29 additions & 0 deletions python/paddle/utils/code_gen/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,16 @@
output : Tensor[](x_grad)
invoke : concat_grad_impl(x, out_grad, axis)

- backward_api : conj_grad
forward : conj (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [out_grad]
kernel :
func : conj

- backward_api : conv2d_grad
forward : conv2d (Tensor input, Tensor filter, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search) -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
Expand Down Expand Up @@ -456,6 +466,16 @@
backend: out_grad
layout: out_grad

- backward_api : flip_grad
forward : flip (Tensor x, int[] axis) -> Tensor(out)
args : (Tensor out_grad, int[] axis)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [out_grad]
kernel :
func : flip

- backward_api : floor_grad
forward : floor(Tensor x) -> Tensor(out)
args : (Tensor out_grad)
Expand Down Expand Up @@ -1010,6 +1030,15 @@
kernel :
func : pad3d_grad

- backward_api : pixel_shuffle_grad
forward : pixel_shuffle (Tensor x, int upscale_factor, str data_format) -> Tensor(out)
args : (Tensor out_grad, int upscale_factor, str data_format)
output : Tensor(x_grad)
infer_meta :
func : PixelShuffleGradInferMeta
kernel :
func : pixel_shuffle_grad

- backward_api : pool2d_grad
forward : pool2d(Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
Expand Down

0 comments on commit bc88fbb

Please sign in to comment.