From 6ad340c4e2446f84e3146c110bbd298f11cecbaf Mon Sep 17 00:00:00 2001
From: sunli <466530738@qq.com>
Date: Mon, 9 Jan 2023 12:38:00 +0000
Subject: [PATCH 1/9] lerp support 0D Tensor

---
 paddle/phi/kernels/gpu/lerp_grad_kernel.cu    |  7 ++--
 .../phi/kernels/impl/lerp_grad_kernel_impl.h  | 33 +++++++++++++++++--
 paddle/phi/kernels/impl/lerp_kernel_impl.h    | 26 +++++++++++++--
 .../tests/unittests/test_zero_dim_tensor.py   |  1 +
 4 files changed, 60 insertions(+), 7 deletions(-)

diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
index b097e4ce4d07a..ed3faf244768d 100644
--- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
@@ -80,6 +80,9 @@ bool XYNeedReduce(const DenseTensor& x,
   auto x_dims = x.dims();
   auto y_dims = y.dims();
   auto out_dims = out.dims();
+  if (out_dims.size() == 0) {
+    return false;
+  }
   int x_rank = x_dims.size();
   int y_rank = y_dims.size();
   int out_rank = out_dims.size();
@@ -166,10 +169,10 @@ void LerpGradKernel(const Context& ctx,
   const int rank = out.dims().size();
   PADDLE_ENFORCE_GE(
       rank,
-      1,
+      0,
       phi::errors::InvalidArgument(
           "The number of dimensions for LerpGradOp must be "
-          "greater than or equal to 1, but the value received is %d.",
+          "greater than or equal to 0, but the value received is %d.",
           rank));
   PADDLE_ENFORCE_LE(
       rank,
diff --git a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
index b47acbda0da2d..9249a3a3ae6dd 100644
--- a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
@@ -91,6 +91,31 @@ static void LerpGradFunction(const Context& ctx,
   }
 }
 
+template <typename Context, typename T>
+static void LerpGradFunctionZero(const Context& ctx,
+                                 const DenseTensor& x,
+                                 const DenseTensor& y,
+                                 const DenseTensor& weight,
+                                 const DenseTensor& out,
+                                 const DenseTensor& out_grad,
+                                 DenseTensor* x_grad,
+                                 DenseTensor* y_grad) {
+  auto dim = make_ddim(std::vector<int64_t>(1, 1));
+  auto eigen_w = phi::EigenTensor<T, D>::From(weight, dim);
+  auto eigen_dout = phi::EigenTensor<T, D>::From(out_grad, dim);
+
+  if (x_grad) {
+    ctx.template Alloc<T>(x_grad);
+    auto eigen_dx = phi::EigenTensor<T, D>::From(*x_grad, dim);
+    eigen_dx.device(place) = (1 - eigen_w) * eigen_dout;
+  }
+  if (y_grad) {
+    ctx.template Alloc<T>(y_grad);
+    auto eigen_dy = phi::EigenTensor<T, D>::From(*y_grad, dim);
+    eigen_dy.device(place) = eigen_w * eigen_dout;
+  }
+}
+
 template <typename T, typename Context>
 void LerpGradKernel(const Context& ctx,
@@ -103,10 +128,10 @@ void LerpGradKernel(const Context& ctx,
   int rank = out.dims().size();
   PADDLE_ENFORCE_GE(
       rank,
-      1,
+      0,
       phi::errors::InvalidArgument(
           "The number of dimensions for LerpGradOp must be "
-          "greater than or equal to 1, but the value received is %d.",
+          "greater than or equal to 0, but the value received is %d.",
           rank));
   PADDLE_ENFORCE_LE(
       rank,
@@ -116,6 +141,10 @@ void LerpGradKernel(const Context& ctx,
           "less than or equal to 6, but the value received is %d.",
           rank));
   switch (rank) {
+    case 0:
+      LerpGradFunctionZero<Context, T>(
+          ctx, x, y, weight, out, out_grad, x_grad, y_grad);
+      break;
     case 1:
       LerpGradFunction<Context, T, 1>(
           ctx, x, y, weight, out, out_grad, x_grad, y_grad);
       break;
diff --git a/paddle/phi/kernels/impl/lerp_kernel_impl.h b/paddle/phi/kernels/impl/lerp_kernel_impl.h
index 72fa0672a5f48..668349e09b951 100644
--- a/paddle/phi/kernels/impl/lerp_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lerp_kernel_impl.h
@@ -27,7 +27,6 @@ static void LerpFunction(const Context& ctx,
                          const DenseTensor& weight,
                          DenseTensor* out) {
   ctx.template Alloc<T>(out);
-
   const auto& out_dims = out->dims();
   auto x_dims = phi::funcs::ExtendDims2Rank(x.dims(), D);
   auto y_dims = phi::funcs::ExtendDims2Rank(y.dims(), D);
@@ -51,6 +50,24 @@ static void LerpFunction(const Context& ctx,
       (eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
 }
 
+template <typename Context, typename T>
+static void LerpFunctionZero(const Context& ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& y,
+                             const DenseTensor& weight,
+                             DenseTensor* out) {
+  ctx.template Alloc<T>(out);
+
+  auto dim = make_ddim(std::vector<int64_t>(1, 1));
+  auto eigen_x = phi::EigenTensor<T, 1>::From(x, dim);
+  auto eigen_y = phi::EigenTensor<T, 1>::From(y, dim);
+  auto eigen_w = phi::EigenTensor<T, 1>::From(weight, dim);
+  auto eigen_out = phi::EigenTensor<T, 1>::From(*out, dim);
+
+  auto& place = *ctx.eigen_device();
+  eigen_out.device(place) = eigen_x + eigen_w * (eigen_y - eigen_x);
+}
+
 template <typename T, typename Context>
 void LerpKernel(const Context& ctx,
                 const DenseTensor& x,
@@ -60,10 +77,10 @@ void LerpKernel(const Context& ctx,
   int rank = out->dims().size();
   PADDLE_ENFORCE_GE(
       rank,
-      1,
+      0,
       phi::errors::InvalidArgument(
           "The number of dimensions for LerpOp must be "
-          "greater than or equal to 1, but the value received is %d.",
+          "greater than or equal to 0, but the value received is %d.",
           rank));
   PADDLE_ENFORCE_LE(
       rank,
@@ -73,6 +90,9 @@ void LerpKernel(const Context& ctx,
           "less than or equal to 6, but the value received is %d.",
           rank));
   switch (rank) {
+    case 0:
+      LerpFunctionZero<Context, T>(ctx, x, y, weight, out);
+      break;
     case 1:
       LerpFunction<Context, T, 1>(ctx, x, y, weight, out);
       break;
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index 5b18abdbc558d..765352cf6b9fc 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -283,6 +283,7 @@ def test_static(self):
     paddle.logical_and,
     paddle.logical_or,
     paddle.logical_xor,
+    paddle.lerp,
 ]
 
 binary_int_api_list = [

From 2df7f36fccd7dacae715176637db6a3e83cdeff9 Mon Sep 17 00:00:00 2001
From: sunli <466530738@qq.com>
Date: Mon, 9 Jan 2023 12:53:39 +0000
Subject: [PATCH 2/9] fix lerp grad

---
 paddle/phi/kernels/impl/lerp_grad_kernel_impl.h | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
index 9249a3a3ae6dd..6f9fae5daf6ed 100644
--- a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
@@ -101,17 +101,18 @@ static void LerpGradFunctionZero(const Context& ctx,
                                  DenseTensor* x_grad,
                                  DenseTensor* y_grad) {
   auto dim = make_ddim(std::vector<int64_t>(1, 1));
-  auto eigen_w = phi::EigenTensor<T, D>::From(weight, dim);
-  auto eigen_dout = phi::EigenTensor<T, D>::From(out_grad, dim);
+  auto eigen_w = phi::EigenTensor<T, 1>::From(weight, dim);
+  auto eigen_dout = phi::EigenTensor<T, 1>::From(out_grad, dim);
+  auto& place = *ctx.eigen_device();
 
   if (x_grad) {
     ctx.template Alloc<T>(x_grad);
-    auto eigen_dx = phi::EigenTensor<T, D>::From(*x_grad, dim);
+    auto eigen_dx = phi::EigenTensor<T, 1>::From(*x_grad, dim);
     eigen_dx.device(place) = (1 - eigen_w) * eigen_dout;
   }
   if (y_grad) {
     ctx.template Alloc<T>(y_grad);
-    auto eigen_dy = phi::EigenTensor<T, D>::From(*y_grad, dim);
+    auto eigen_dy = phi::EigenTensor<T, 1>::From(*y_grad, dim);
     eigen_dy.device(place) = eigen_w * eigen_dout;
   }
 }

From 349b1b49dabf619afb6f01bf4798175f1c39ae05 Mon Sep 17 00:00:00 2001
From: sunli <466530738@qq.com>
Date: Tue, 10 Jan 2023 02:37:01 +0000
Subject: [PATCH 3/9] fix lerp zero test

---
 .../tests/unittests/test_zero_dim_tensor.py   | 31 ++++++++++++++++++-
 1 file changed, 30 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index 765352cf6b9fc..6391d76191b02 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -283,7 +283,6 @@ def test_static(self):
     paddle.logical_and,
     paddle.logical_or,
     paddle.logical_xor,
-    paddle.lerp,
 ]
 
 binary_int_api_list = [
@@ -967,6 +966,20 @@ def test_argsort(self):
         self.assertEqual(x1.grad.numpy(), 0)
         self.assertEqual(x2.grad.numpy(), 0)
 
+    def test_lerp(self):
+        x = paddle.rand([])
+        y = paddle.rand([])
+        w = paddle.rand([])
+        x.stop_gradient = False
+        y.stop_gradient = False
+
+        out = paddle.lerp(x, y, w)
+        out.backward()
+
+        self.assertEqual(out.shape, [])
+        self.assertEqual(x.grad.shape, [])
+        self.assertEqual(y.grad.shape, [])
+
 
 class TestSundryAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -1381,6 +1394,22 @@ def test_argsort(self):
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, ())
 
+    @prog_scope()
+    def test_lerp(self):
+        x = paddle.rand([])
+        y = paddle.rand([])
+        w = paddle.rand([])
+        x.stop_gradient = False
+        y.stop_gradient = False
+
+        out = paddle.lerp(x, y, w)
+        paddle.static.append_backward(out)
+
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+
+        self.assertEqual(res[0].shape, ())
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):

From de44473e0ef72980ae0f7fdc0c718c9692823799 Mon Sep 17 00:00:00 2001
From: sunli <466530738@qq.com>
Date: Tue, 10 Jan 2023 07:47:26 +0000
Subject: [PATCH 4/9] fix 0D + ND/ND + 0D

---
 paddle/phi/kernels/gpu/lerp_grad_kernel.cu    | 16 +++-
 .../tests/unittests/test_zero_dim_tensor.py   | 84 +++++++++++++++----
 2 files changed, 78 insertions(+), 22 deletions(-)

diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
index ed3faf244768d..f42f316aae980 100644
--- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
@@ -77,8 +77,11 @@ __global__ void LerpGradScalarKernelImpl(const T* weight,
 bool XYNeedReduce(const DenseTensor& x,
                   const DenseTensor& y,
                   const DenseTensor& out) {
-  auto x_dims = x.dims();
-  auto y_dims = y.dims();
+  auto x_dims =
+      x.dims().size() ? x.dims() : make_ddim(std::vector<int64_t>(1, 1));
+  auto y_dims =
+      y.dims().size() ? y.dims() : make_ddim(std::vector<int64_t>(1, 1));
+
   auto out_dims = out.dims();
   if (out_dims.size() == 0) {
     return false;
@@ -234,9 +237,12 @@ void LerpGradKernel(const Context& ctx,
                     x_grad_data,
                     y_grad_data);
 
+  auto zero_dim = make_ddim(std::vector<int64_t>(1, 1));
   if (x_grad) {
     std::vector<int> reduce_axis_x =
-        funcs::GetReduceDim(x_grad->dims(), b_xgrad.dims(), -1);
+        funcs::GetReduceDim(x_grad->dims().size() ? x_grad->dims() : zero_dim,
+                            b_xgrad.dims(),
+                            -1);
     if (!reduce_axis_x.empty()) {
       phi::funcs::
           ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
@@ -248,7 +254,9 @@ void LerpGradKernel(const Context& ctx,
 
   if (y_grad) {
     std::vector<int> reduce_axis_y =
-        funcs::GetReduceDim(y_grad->dims(), b_ygrad.dims(), -1);
+        funcs::GetReduceDim(y_grad->dims().size() ? y_grad->dims() : zero_dim,
+                            b_ygrad.dims(),
+                            -1);
     if (!reduce_axis_y.empty()) {
       phi::funcs::
           ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index aee56a6dceeed..f3d7bd906df31 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -967,18 +967,47 @@ def test_argsort(self):
         self.assertEqual(x1.grad.numpy(), 0)
         self.assertEqual(x2.grad.numpy(), 0)
 
     def test_lerp(self):
-        x = paddle.rand([])
-        y = paddle.rand([])
-        w = paddle.rand([])
-        x.stop_gradient = False
-        y.stop_gradient = False
+        # 0D + 0D
+        x0 = paddle.rand([])
+        y0 = paddle.rand([])
+        w0 = paddle.rand([])
+        x0.stop_gradient = False
+        y0.stop_gradient = False
 
-        out = paddle.lerp(x, y, w)
-        out.backward()
+        out0 = paddle.lerp(x0, y0, w0)
+        out0.backward()
 
-        self.assertEqual(out.shape, [])
-        self.assertEqual(x.grad.shape, [])
-        self.assertEqual(y.grad.shape, [])
+        self.assertEqual(out0.shape, [])
+        self.assertEqual(x0.grad.shape, [])
+        self.assertEqual(y0.grad.shape, [])
+
+        # 0D + ND
+        x1 = paddle.rand([])
+        y1 = paddle.rand([64, 64])
+        w1 = paddle.rand([])
+        x1.stop_gradient = False
+        y1.stop_gradient = False
+
+        out1 = paddle.lerp(x1, y1, w1)
+        out1.backward()
+
+        self.assertEqual(out1.shape, [64, 64])
+        self.assertEqual(x1.grad.shape, [])
+        self.assertEqual(y1.grad.shape, [64, 64])
+
+        # ND + 0D
+        x2 = paddle.rand([64, 64])
+        y2 = paddle.rand([])
+        w2 = paddle.rand([])
+        x2.stop_gradient = False
+        y2.stop_gradient = False
+
+        out2 = paddle.lerp(x2, y2, w2)
+        out2.backward()
+
+        self.assertEqual(out2.shape, [64, 64])
+        self.assertEqual(x2.grad.shape, [64, 64])
+        self.assertEqual(y2.grad.shape, [])
 
     def test_repeat_interleave(self):
         places = ['cpu']
@@ -1424,17 +1453,36 @@ def test_argsort(self):
 
     @prog_scope()
     def test_lerp(self):
-        x = paddle.rand([])
-        y = paddle.rand([])
-        w = paddle.rand([])
-        x.stop_gradient = False
-        y.stop_gradient = False
-
-        out = paddle.lerp(x, y, w)
-        paddle.static.append_backward(out)
+        # 0D + 0D
+        x0 = paddle.rand([])
+        y0 = paddle.rand([])
+        w0 = paddle.rand([])
+        x0.stop_gradient = False
+        y0.stop_gradient = False
+        out0 = paddle.lerp(x0, y0, w0)
+        paddle.static.append_backward(out0)
+        # 0D + ND
+        x1 = paddle.rand([])
+        y1 = paddle.rand([64, 64])
+        w1 = paddle.rand([])
+        x1.stop_gradient = False
+        y1.stop_gradient = False
+        out1 = paddle.lerp(x1, y1, w1)
+        paddle.static.append_backward(out1)
+        # ND + 0D
+        x2 = paddle.rand([64, 64])
+        y2 = paddle.rand([])
+        w2 = paddle.rand([])
+        x2.stop_gradient = False
+        y2.stop_gradient = False
+        out2 = paddle.lerp(x2, y2, w2)
+        paddle.static.append_backward(out2)
 
         prog = paddle.static.default_main_program()
-        res = self.exe.run(prog, fetch_list=[out])
+        res = self.exe.run(prog, fetch_list=[out0, out1, out2])
 
         self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (64, 64))
+        self.assertEqual(res[2].shape, (64, 64))
 
     @prog_scope()
     def test_repeat_interleave(self):

From 96cf5587a0aa2a1aeca3d812d698746a65775ebc Mon Sep 17 00:00:00 2001
From: sunli <466530738@qq.com>
Date: Tue, 10 Jan 2023 09:24:14 +0000
Subject: [PATCH 5/9] fix check

---
 paddle/phi/kernels/gpu/lerp_grad_kernel.cu      | 7 +++++++
 paddle/phi/kernels/impl/lerp_grad_kernel_impl.h | 7 +++++++
 2 files changed, 14 insertions(+)

diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
index f42f316aae980..e6fe211e32aa2 100644
--- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
@@ -170,6 +170,13 @@ void LerpGradKernel(const Context& ctx,
                     DenseTensor* x_grad,
                     DenseTensor* y_grad) {
   const int rank = out.dims().size();
+  PADDLE_ENFORCE_EQ(out.dims().size(),
+                    out_grad.dims.size(),
+                    phi::errors::InvalidArgument(
+                        "The number of dimensions for LerpGradOp must be "
+                        " equal to LerpOut, but the value received is %d != %d",
+                        out_grad.dims.size(),
+                        out.dims().size()));
   PADDLE_ENFORCE_GE(
       rank,
       0,
diff --git a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
index 6f9fae5daf6ed..30b9894207162 100644
--- a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
@@ -127,6 +127,13 @@ void LerpGradKernel(const Context& ctx,
                     DenseTensor* x_grad,
                     DenseTensor* y_grad) {
   int rank = out.dims().size();
+  PADDLE_ENFORCE_EQ(out.dims().size(),
+                    out_grad.dims.size(),
+                    phi::errors::InvalidArgument(
+                        "The number of dimensions for LerpGradOp must be "
+                        " equal to LerpOut, but the value received is %d != %d",
+                        out_grad.dims.size(),
+                        out.dims().size()));
   PADDLE_ENFORCE_GE(
       rank,
       0,

From 4eca4aa49e13853fa14d3c0dd08bde87006db18d Mon Sep 17 00:00:00 2001
From: sunli <466530738@qq.com>
Date: Tue, 10 Jan 2023 09:25:37 +0000
Subject: [PATCH 6/9] update code

---
 paddle/phi/kernels/gpu/lerp_grad_kernel.cu      | 4 ++--
 paddle/phi/kernels/impl/lerp_grad_kernel_impl.h | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
index e6fe211e32aa2..07ead29fbf977 100644
--- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
@@ -171,11 +171,11 @@ void LerpGradKernel(const Context& ctx,
                     DenseTensor* y_grad) {
   const int rank = out.dims().size();
   PADDLE_ENFORCE_EQ(out.dims().size(),
-                    out_grad.dims.size(),
+                    out_grad.dims().size(),
                     phi::errors::InvalidArgument(
                         "The number of dimensions for LerpGradOp must be "
                         " equal to LerpOut, but the value received is %d != %d",
-                        out_grad.dims.size(),
+                        out_grad.dims().size(),
                         out.dims().size()));
   PADDLE_ENFORCE_GE(
       rank,
diff --git a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
index 30b9894207162..39c2ef62e0a91 100644
--- a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
@@ -128,11 +128,11 @@ void LerpGradKernel(const Context& ctx,
                     DenseTensor* y_grad) {
   int rank = out.dims().size();
   PADDLE_ENFORCE_EQ(out.dims().size(),
-                    out_grad.dims.size(),
+                    out_grad.dims().size(),
                     phi::errors::InvalidArgument(
                         "The number of dimensions for LerpGradOp must be "
                         " equal to LerpOut, but the value received is %d != %d",
-                        out_grad.dims.size(),
+                        out_grad.dims().size(),
                         out.dims().size()));
   PADDLE_ENFORCE_GE(
       rank,

From 142a540f40a4cdd4357bf77ed4a30f4b6e5f4935 Mon Sep 17 00:00:00 2001
From: sunli <466530738@qq.com>
Date: Tue, 10 Jan 2023 13:04:59 +0000
Subject: [PATCH 7/9] fix lerp infer shape

---
 paddle/phi/infermeta/ternary.cc                    |  4 +---
 paddle/phi/kernels/gpu/lerp_grad_kernel.cu         | 15 ++++++++-------
 paddle/phi/kernels/impl/lerp_grad_kernel_impl.h    | 15 ++++++++-------
 .../fluid/tests/unittests/test_zero_dim_tensor.py  |  5 ++---
 4 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index 7439195818918..d790d226b2d37 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -598,9 +598,7 @@ void LerpInferMeta(const MetaTensor& x,
   auto w_dims = weight.dims();
   DDim out_dims;
   out_dims = funcs::GetOutputDims(x_dims, y_dims);
-  if (w_dims.size() > 1 || w_dims[0] != 1) {
-    out_dims = funcs::GetOutputDims(out_dims, w_dims);
-  }
+  out_dims = funcs::GetOutputDims(out_dims, w_dims);
   out->set_dims(out_dims);
   out->set_dtype(x.dtype());
   out->share_lod(x);
diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
index 07ead29fbf977..f78161b39c761 100644
--- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
@@ -170,13 +170,14 @@ void LerpGradKernel(const Context& ctx,
                     DenseTensor* x_grad,
                     DenseTensor* y_grad) {
   const int rank = out.dims().size();
-  PADDLE_ENFORCE_EQ(out.dims().size(),
-                    out_grad.dims().size(),
-                    phi::errors::InvalidArgument(
-                        "The number of dimensions for LerpGradOp must be "
-                        " equal to LerpOut, but the value received is %d != %d",
-                        out_grad.dims().size(),
-                        out.dims().size()));
+  PADDLE_ENFORCE_EQ(
+      out.dims().size(),
+      out_grad.dims().size(),
+      phi::errors::InvalidArgument(
+          "The number of dimensions for LerpOut must be "
+          " equal to LerpGrad, but the value received is %d != %d",
+          out.dims().size(),
+          out_grad.dims().size()));
   PADDLE_ENFORCE_GE(
       rank,
       0,
diff --git a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
index 39c2ef62e0a91..45839105bfd41 100644
--- a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
@@ -127,13 +127,14 @@ void LerpGradKernel(const Context& ctx,
                     DenseTensor* x_grad,
                     DenseTensor* y_grad) {
   int rank = out.dims().size();
-  PADDLE_ENFORCE_EQ(out.dims().size(),
-                    out_grad.dims().size(),
-                    phi::errors::InvalidArgument(
-                        "The number of dimensions for LerpGradOp must be "
-                        " equal to LerpOut, but the value received is %d != %d",
-                        out_grad.dims().size(),
-                        out.dims().size()));
+  PADDLE_ENFORCE_EQ(
+      out.dims().size(),
+      out_grad.dims().size(),
+      phi::errors::InvalidArgument(
+          "The number of dimensions for LerpOut must be "
+          " equal to LerpGrad, but the value received is %d != %d",
+          out.dims().size(),
+          out_grad.dims().size()));
   PADDLE_ENFORCE_GE(
       rank,
       0,
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index f3d7bd906df31..1e6a2848fc3f1 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -1460,7 +1460,7 @@ def test_lerp(self):
         x0.stop_gradient = False
         y0.stop_gradient = False
         out0 = paddle.lerp(x0, y0, w0)
-        paddle.static.append_backward(out0)
+
         # 0D + ND
         x1 = paddle.rand([])
         y1 = paddle.rand([64, 64])
         w1 = paddle.rand([])
         x1.stop_gradient = False
         y1.stop_gradient = False
         out1 = paddle.lerp(x1, y1, w1)
-        paddle.static.append_backward(out1)
+
         # ND + 0D
         x2 = paddle.rand([64, 64])
         y2 = paddle.rand([])
         w2 = paddle.rand([])
         x2.stop_gradient = False
         y2.stop_gradient = False
         out2 = paddle.lerp(x2, y2, w2)
-        paddle.static.append_backward(out2)
 
         prog = paddle.static.default_main_program()
         res = self.exe.run(prog, fetch_list=[out0, out1, out2])

From 8df230763fd1ec43b2c3132685d19480b5f79312 Mon Sep 17 00:00:00 2001
From: sunli <466530738@qq.com>
Date: Wed, 11 Jan 2023 11:36:40 +0000
Subject: [PATCH 8/9] static backward test

---
 paddle/phi/kernels/gpu/lerp_grad_kernel.cu    |  8 -----
 .../phi/kernels/impl/lerp_grad_kernel_impl.h  | 29 +++++++++----------
 .../tests/unittests/test_zero_dim_tensor.py   |  3 ++
 3 files changed, 16 insertions(+), 24 deletions(-)

diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
index f78161b39c761..f42f316aae980 100644
--- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
@@ -170,14 +170,6 @@ void LerpGradKernel(const Context& ctx,
                     DenseTensor* x_grad,
                     DenseTensor* y_grad) {
   const int rank = out.dims().size();
-  PADDLE_ENFORCE_EQ(
-      out.dims().size(),
-      out_grad.dims().size(),
-      phi::errors::InvalidArgument(
-          "The number of dimensions for LerpOut must be "
-          " equal to LerpGrad, but the value received is %d != %d",
-          out.dims().size(),
-          out_grad.dims().size()));
   PADDLE_ENFORCE_GE(
       rank,
       0,
diff --git a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
index 45839105bfd41..541de0cc162cc 100644
--- a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
@@ -33,33 +33,36 @@ static void LerpGradFunction(const Context& ctx,
   auto* dx = x_grad;
   auto* dy = y_grad;
 
-  auto dout_dims = dout.dims();
+  auto& out_dims = out.dims();
   DDim dx_dims;
   DDim dy_dims;
   auto w_dims = phi::funcs::ExtendDims2Rank(w.dims(), D);
+  auto g_dims = phi::funcs::ExtendDims2Rank(out_grad.dims(), D);
   Eigen::DSizes<int, D> dx_bcast_dims;
   Eigen::DSizes<int, D> dy_bcast_dims;
   Eigen::DSizes<int, D> w_bcast_dims;
+  Eigen::DSizes<int, D> g_bcast_dims;
   if (dx) {
     dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D);
-    phi::funcs::GetBroadcastDims(dx_dims, dout_dims, &dx_bcast_dims);
+    phi::funcs::GetBroadcastDims(dx_dims, out_dims, &dx_bcast_dims);
   }
   if (dy) {
     dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D);
-    phi::funcs::GetBroadcastDims(dy_dims, dout_dims, &dy_bcast_dims);
+    phi::funcs::GetBroadcastDims(dy_dims, out_dims, &dy_bcast_dims);
   }
-  phi::funcs::GetBroadcastDims(w_dims, dout_dims, &w_bcast_dims);
+  phi::funcs::GetBroadcastDims(w_dims, out_dims, &w_bcast_dims);
+  phi::funcs::GetBroadcastDims(g_dims, out_dims, &g_bcast_dims);
 
   auto eigen_w = phi::EigenTensor<T, D>::From(w, w_dims);
-  auto eigen_dout = phi::EigenTensor<T, D>::From(dout);
+  auto eigen_dout = phi::EigenTensor<T, D>::From(dout, g_dims);
 
   Eigen::DSizes<int, D * 2> dx_reshape_dims;
   Eigen::DSizes<int, D * 2> dy_reshape_dims;
   Eigen::DSizes<int, D> reduce_dims;
-  for (int i = 0; i < dout_dims.size(); ++i) {
+  for (int i = 0; i < out_dims.size(); ++i) {
     if (dx) {
       dx_reshape_dims[2 * i] = dx_bcast_dims[i];
       dx_reshape_dims[2 * i + 1] = dx_dims[i];
@@ -76,7 +79,8 @@ static void LerpGradFunction(const Context& ctx,
   if (dx) {
     ctx.template Alloc<T>(dx);
     auto eigen_dx = phi::EigenTensor<T, D>::From(*dx, dx_dims);
-    auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout;
+    auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) *
+                      eigen_dout.broadcast(g_bcast_dims);
     eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims)
                                  .sum(reduce_dims)
                                  .reshape(eigen_dx.dimensions());
@@ -84,7 +88,8 @@ static void LerpGradFunction(const Context& ctx,
   if (dy) {
     ctx.template Alloc<T>(dy);
     auto eigen_dy = phi::EigenTensor<T, D>::From(*dy, dy_dims);
-    auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout;
+    auto eigen_expr =
+        eigen_w.broadcast(w_bcast_dims) * eigen_dout.broadcast(g_bcast_dims);
     eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims)
                                  .sum(reduce_dims)
                                  .reshape(eigen_dy.dimensions());
@@ -127,14 +132,6 @@ void LerpGradKernel(const Context& ctx,
                     DenseTensor* x_grad,
                     DenseTensor* y_grad) {
   int rank = out.dims().size();
-  PADDLE_ENFORCE_EQ(
-      out.dims().size(),
-      out_grad.dims().size(),
-      phi::errors::InvalidArgument(
-          "The number of dimensions for LerpOut must be "
-          " equal to LerpGrad, but the value received is %d != %d",
-          out.dims().size(),
-          out_grad.dims().size()));
   PADDLE_ENFORCE_GE(
       rank,
       0,
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index 1e6a2848fc3f1..6c32f26b7b1f3 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -1460,6 +1460,7 @@ def test_lerp(self):
         x0.stop_gradient = False
         y0.stop_gradient = False
         out0 = paddle.lerp(x0, y0, w0)
+        paddle.static.append_backward(out0)
 
         # 0D + ND
         x1 = paddle.rand([])
@@ -1468,6 +1469,7 @@ def test_lerp(self):
         x1.stop_gradient = False
         y1.stop_gradient = False
         out1 = paddle.lerp(x1, y1, w1)
+        paddle.static.append_backward(out1)
 
         # ND + 0D
         x2 = paddle.rand([64, 64])
@@ -1476,6 +1478,7 @@ def test_lerp(self):
         x2.stop_gradient = False
         y2.stop_gradient = False
         out2 = paddle.lerp(x2, y2, w2)
+        paddle.static.append_backward(out2)
 
         prog = paddle.static.default_main_program()
         res = self.exe.run(prog, fetch_list=[out0, out1, out2])

From 59324847dc61ff25b8e9327d94dfda4cb254f431 Mon Sep 17 00:00:00 2001
From: sunli <466530738@qq.com>
Date: Thu, 12 Jan 2023 07:20:48 +0000
Subject: [PATCH 9/9] update static graph test

---
 .../tests/unittests/test_zero_dim_tensor.py   | 56 +++++++++----------
 1 file changed, 25 insertions(+), 31 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index 6c32f26b7b1f3..5bb527533b46e 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -1453,38 +1453,32 @@ def test_argsort(self):
 
     @prog_scope()
     def test_lerp(self):
-        # 0D + 0D
-        x0 = paddle.rand([])
-        y0 = paddle.rand([])
-        w0 = paddle.rand([])
-        x0.stop_gradient = False
-        y0.stop_gradient = False
-        out0 = paddle.lerp(x0, y0, w0)
-        paddle.static.append_backward(out0)
-
-        # 0D + ND
-        x1 = paddle.rand([])
-        y1 = paddle.rand([64, 64])
-        w1 = paddle.rand([])
-        x1.stop_gradient = False
-        y1.stop_gradient = False
-        out1 = paddle.lerp(x1, y1, w1)
-        paddle.static.append_backward(out1)
-
-        # ND + 0D
-        x2 = paddle.rand([64, 64])
-        y2 = paddle.rand([])
-        w2 = paddle.rand([])
-        x2.stop_gradient = False
-        y2.stop_gradient = False
-        out2 = paddle.lerp(x2, y2, w2)
-        paddle.static.append_backward(out2)
+        shapes = [
+            [(), (), (), ()],
+            [(), (64, 64), (), (64, 64)],
+            [(64, 64), (), (), (64, 64)],
+        ]
+        for shape in shapes:
+            x = paddle.rand(shape[0])
+            y = paddle.rand(shape[1])
+            w = paddle.rand(shape[2])
 
-        prog = paddle.static.default_main_program()
-        res = self.exe.run(prog, fetch_list=[out0, out1, out2])
-        self.assertEqual(res[0].shape, ())
-        self.assertEqual(res[1].shape, (64, 64))
-        self.assertEqual(res[2].shape, (64, 64))
+            x.stop_gradient = False
+            y.stop_gradient = False
+            out = paddle.lerp(x, y, w)
+            paddle.static.append_backward(out.sum())
+
+            prog = paddle.static.default_main_program()
+            block = prog.global_block()
+            x_grad = block.var(fluid.framework.grad_var_name(x.name))
+            y_grad = block.var(fluid.framework.grad_var_name(y.name))
+            out_grad = block.var(fluid.framework.grad_var_name(out.name))
+
+            res = self.exe.run(prog, fetch_list=[out, out_grad, y_grad, x_grad])
+            self.assertEqual(res[0].shape, shape[3])
+            self.assertEqual(res[1].shape, shape[3])
+            self.assertEqual(res[2].shape, shape[1])
+            self.assertEqual(res[3].shape, shape[0])
 
     @prog_scope()
     def test_repeat_interleave(self):
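
A minimal dygraph usage sketch of the behavior this series enables (not itself part of the patches), assuming a Paddle build with all nine patches applied; it mirrors the 0D + 0D and 0D + ND cases exercised by test_lerp above, and the variable names are illustrative only:

    import paddle

    # 0D + 0D: all inputs are zero-dim tensors, so the output and both
    # gradients are zero-dim as well (out = x + w * (y - x)).
    x = paddle.rand([])
    y = paddle.rand([])
    w = paddle.rand([])
    x.stop_gradient = False
    y.stop_gradient = False
    out = paddle.lerp(x, y, w)
    out.backward()
    assert out.shape == [] and x.grad.shape == [] and y.grad.shape == []

    # 0D + ND: the zero-dim input broadcasts against the [64, 64] input,
    # and its gradient is reduced back to zero-dim.
    x = paddle.rand([])
    y = paddle.rand([64, 64])
    x.stop_gradient = False
    out = paddle.lerp(x, y, paddle.rand([]))
    out.backward()
    assert out.shape == [64, 64] and x.grad.shape == []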