diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc
index 191890865fb89..4029be65a00d6 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc
@@ -107,6 +107,7 @@ class ElementwiseDivDoubleGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetType("elementwise_div_grad_grad");
     op->SetInput("Y", this->Input("Y"));
     op->SetInput("Out", this->Input("Out"));
+    op->SetInput("Out@GRAD", this->Input(framework::GradVarName("Out")));
     op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
     op->SetInput("DDY", this->OutputGrad(framework::GradVarName("Y")));
     op->SetInput("DX", this->Output(framework::GradVarName("X")));
diff --git a/paddle/fluid/operators/ops_signature/elementwise_sig.cc b/paddle/fluid/operators/ops_signature/elementwise_sig.cc
index b1150268fbad1..82f891bb48a00 100644
--- a/paddle/fluid/operators/ops_signature/elementwise_sig.cc
+++ b/paddle/fluid/operators/ops_signature/elementwise_sig.cc
@@ -168,7 +168,7 @@ KernelSignature ElementwiseDivGradOpArgumentMapping(
 KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx UNUSED) {
   return KernelSignature("divide_double_grad",
-                         {"Y", "Out", "DX", "DDX", "DDY"},
+                         {"Y", "Out", "Out@GRAD", "DX", "DDX", "DDY"},
                          {"axis"},
                          {"Y@GRAD", "DOut", "DDOut"});
 }
diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
index 2c8996d6a53a5..101bae084c264 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -201,15 +201,15 @@
 - backward_op : divide_double_grad
   forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
-  args : (Tensor y, Tensor out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
+  args : (Tensor y, Tensor out, Tensor grad_out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
   output : Tensor(y_grad), Tensor(out_grad), Tensor(grad_out_grad)
   infer_meta :
     func : GeneralTernaryGradInferMeta
-    param : [y, grad_x, grad_x]
+    param : [y, out, out]
   kernel :
     func : divide_double_grad
     data_type : out
-  optional : grad_x_grad, grad_y_grad
+  optional : grad_x, grad_x_grad, grad_y_grad
   inplace : (grad_x_grad -> grad_out_grad)
 
 - backward_op : divide_grad
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 2ca26f1efbdd5..c1318cb277004 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -175,15 +175,15 @@
 - backward_op : divide_double_grad
   forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
-  args : (Tensor y, Tensor out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
+  args : (Tensor y, Tensor out, Tensor grad_out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
   output : Tensor(y_grad), Tensor(out_grad), Tensor(grad_out_grad)
   infer_meta :
     func : GeneralTernaryGradInferMeta
-    param : [y, grad_x, grad_x]
+    param : [y, out, out]
   kernel :
     func : divide_double_grad
     data_type : out
-  optional : grad_x_grad, grad_y_grad
+  optional : grad_x, grad_x_grad, grad_y_grad
   inplace : (grad_x_grad -> grad_out_grad)
 
 - backward_op : divide_grad
diff --git a/paddle/phi/kernels/elementwise_divide_grad_kernel.h b/paddle/phi/kernels/elementwise_divide_grad_kernel.h
index c764f05c3983f..15b1e65a9cfdf 100644
--- a/paddle/phi/kernels/elementwise_divide_grad_kernel.h
+++ b/paddle/phi/kernels/elementwise_divide_grad_kernel.h
@@ -33,7 +33,8 @@ template <typename T, typename Context>
 void DivideDoubleGradKernel(const Context& dev_ctx,
                             const DenseTensor& y,
                             const DenseTensor& out,
-                            const DenseTensor& dx,
+                            const DenseTensor& grad_out,
+                            const paddle::optional<DenseTensor>& dx,
                             const paddle::optional<DenseTensor>& ddx,
                             const paddle::optional<DenseTensor>& ddy,
                             int axis,
diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h
index 19f2fa1f2fac4..45a1024339ba3 100644
--- a/paddle/phi/kernels/funcs/common_shape.h
+++ b/paddle/phi/kernels/funcs/common_shape.h
@@ -52,7 +52,6 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
         "Axis should be less than or equal to %d, but received axis is %d.",
         max_dim,
         axis));
-
   if (x_dims.size() > y_dims.size()) {
     std::fill(y_dims_array, y_dims_array + axis, 1);
     if (axis + y_dims.size() < max_dim) {
@@ -68,7 +67,6 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
     std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array + axis);
     std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
   }
-
   for (int i = 0; i < max_dim; ++i) {
     PADDLE_ENFORCE_EQ(
         x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
index db6858bc9d7d7..4bd0ede6dc827 100644
--- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
@@ -25,6 +25,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
+#include "paddle/phi/kernels/funcs/elementwise_utils.h"
 
 namespace phi {
 
@@ -157,42 +158,325 @@ struct DivGradDY<phi::dtype::complex<T>> {
 
 template <typename T>
 struct DivDoubleDY {
-  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-    return y * out * dout - x * dout;
+  HOSTDEVICE T operator()(const T& x,
+                          const T& y,
+                          const T& out,
+                          const T& dout) const {
+    return (y * out - x) * dout;
   }
 };
+
+template <typename T>
+struct DivDoubleDY_Only_DDY {
+  HOSTDEVICE T operator()(const T& x,
+                          const T& y,
+                          const T& out,
+                          const T& dout) const {
+    return y * out * dout;
+  }
+};
+
+template <typename T>
+struct DivDoubleDY_Only_DDX {
+  HOSTDEVICE T operator()(const T& x,
+                          const T& y,
+                          const T& out,
+                          const T& dout) const {
+    return -x * dout;
+  }
+};
+
+// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
+template <typename T>
+struct DivDoubleDDOut {
+  HOSTDEVICE T operator()(const T& ddx,
+                          const T& ddy,
+                          const T& y,
+                          const T& out) const {
+    return (ddx - out * ddy) / y;
+  }
+};
+
+template <typename T>
+struct DivDoubleDDOut_Only_DDY {
+  HOSTDEVICE T operator()(const T& ddx,
+                          const T& ddy,
+                          const T& y,
+                          const T& out) const {
+    return -out * ddy / y;
+  }
+};
+
+template <typename T, typename DDout_OP>
+void ComputeDDoutWithoutBroadcast(const CPUContext& dev_ctx UNUSED,
+                                  const phi::DenseTensor& ddx,
+                                  const phi::DenseTensor& ddy,
+                                  const phi::DenseTensor& y,
+                                  const phi::DenseTensor& out,
+                                  phi::DenseTensor* ddout,
+                                  DDout_OP dout_op) {
+  auto out_numel = out.numel();
+  auto* ddx_data = ddx.data<T>();
+  auto* ddy_data = ddy.data<T>();
+  auto* y_data = y.data<T>();
+  auto* out_data = out.data<T>();
+  auto* ddout_data = ddout->data<T>();
+  for (int i = 0; i < out_numel; i++) {
+    ddout_data[i] = dout_op(ddx_data[i], ddy_data[i], y_data[i], out_data[i]);
+  }
+}
+
+template <typename T, typename DDout_OP>
+void ComputeDDoutWithBroadcast(const CPUContext& dev_ctx UNUSED,
+                               const phi::DenseTensor& ddx,
+                               const phi::DenseTensor& ddy,
+                               const phi::DenseTensor& y,
+                               const phi::DenseTensor& out,
+                               phi::DenseTensor* ddout,
+                               const int* x_dims_array,
+                               const int* y_dims_array,
+                               const int* out_dims_array,
+                               const int max_dim,
+                               DDout_OP dout_op) {
+  auto out_numel = out.numel();
+  auto* ddx_data = ddx.data<T>();
+  auto* ddy_data = ddy.data<T>();
+  auto* y_data = y.data<T>();
+  auto* out_data = out.data<T>();
+  auto* ddout_data = ddout->data<T>();
+  std::vector<int> index_array(max_dim, 0);
+  for (int i = 0; i < out_numel; i++) {
+    int x_index = phi::funcs::GetElementwiseIndex(
+        x_dims_array, max_dim, index_array.data());
+    int y_index = phi::funcs::GetElementwiseIndex(
+        y_dims_array, max_dim, index_array.data());
+    ddout_data[i] = dout_op(
+        ddx_data[x_index], ddy_data[y_index], y_data[y_index], out_data[i]);
+    phi::funcs::UpdateElementwiseIndexArray(
+        out_dims_array, max_dim, index_array.data());
+  }
+}
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+
+template <typename T, typename DDout_OP>
+__global__ void ComputeDDoutWithoutBroadcastGPUKernel(const T* ddx_data,
+                                                      const T* ddy_data,
+                                                      const T* y_data,
+                                                      const T* out_data,
+                                                      T* ddout_data,
+                                                      int numel,
+                                                      DDout_OP dout_op) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid >= numel) return;
+  ddout_data[tid] =
+      dout_op(ddx_data[tid], ddy_data[tid], y_data[tid], out_data[tid]);
+}
+
+template <typename T, typename DDout_OP>
+void ComputeDDoutWithoutBroadcast(const GPUContext& dev_ctx UNUSED,
+                                  const phi::DenseTensor& ddx,
+                                  const phi::DenseTensor& ddy,
+                                  const phi::DenseTensor& y,
+                                  const phi::DenseTensor& out,
+                                  phi::DenseTensor* ddout,
+                                  DDout_OP dout_op) {
+  auto out_numel = out.numel();
+  auto* ddx_data = ddx.data<T>();
+  auto* ddy_data = ddy.data<T>();
+  auto* y_data = y.data<T>();
+  auto* out_data = out.data<T>();
+  auto* ddout_data = ddout->data<T>();
+  int block = 512;
+  int64_t grid = (out_numel + block - 1) / block;
+  auto stream = reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
+  ComputeDDoutWithoutBroadcastGPUKernel<T, DDout_OP>
+      <<<grid, block, 0, stream>>>(
+          ddx_data, ddy_data, y_data, out_data, ddout_data, out_numel, dout_op);
+}
+
+template <typename T, typename DDout_OP>
+__global__ void ComputeDDoutWithBroadcastGPUKernel(const T* ddx_data,
+                                                   const T* ddy_data,
+                                                   const T* y_data,
+                                                   const T* out_data,
+                                                   T* ddout_data,
+                                                   int numel,
+                                                   const int* x_dims_array,
+                                                   const int* y_dims_array,
+                                                   const int* out_dims_array,
+                                                   const int max_dim,
+                                                   DDout_OP dout_op) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid >= numel) return;
+  int x_index = 0, y_index = 0, x_index_prod = 1, y_index_prod = 1,
+      out_index = tid, dim_index;
+  for (int64_t i = max_dim - 1; i >= 0; i--) {
+    if (out_index == 0) break;
+    dim_index = out_index % out_dims_array[i];
+    out_index = out_index / out_dims_array[i];
+    if (x_dims_array[i] > 1) {
+      x_index += dim_index * x_index_prod;
+      x_index_prod *= x_dims_array[i];
+    }
+    if (y_dims_array[i] > 1) {
+      y_index += dim_index * y_index_prod;
+      y_index_prod *= y_dims_array[i];
+    }
+  }
+  ddout_data[tid] = dout_op(
+      ddx_data[x_index], ddy_data[y_index], y_data[y_index], out_data[tid]);
+}
+
+template <typename T, typename DDout_OP>
+void ComputeDDoutWithBroadcast(const GPUContext& dev_ctx UNUSED,
+                               const phi::DenseTensor& ddx,
+                               const phi::DenseTensor& ddy,
+                               const phi::DenseTensor& y,
+                               const phi::DenseTensor& out,
+                               phi::DenseTensor* ddout,
+                               const int* x_dims_array,
+                               const int* y_dims_array,
+                               const int* out_dims_array,
+                               const int max_dim,
+                               DDout_OP dout_op) {
+  auto out_numel = out.numel();
+  auto* ddx_data = ddx.data<T>();
+  auto* ddy_data = ddy.data<T>();
+  auto* y_data = y.data<T>();
+  auto* out_data = out.data<T>();
+  auto* ddout_data = ddout->data<T>();
+  DenseTensor x_dims_array_gpu;
+  x_dims_array_gpu.Resize({max_dim});
+  int* x_dims_array_gpu_data = dev_ctx.template Alloc<int>(&x_dims_array_gpu);
+#if defined(__NVCC__)
+  cudaMemcpy(x_dims_array_gpu_data,
+             x_dims_array,
+             sizeof(int) * max_dim,
+             cudaMemcpyHostToDevice);
+#else
+  hipMemcpy(x_dims_array_gpu_data,
+            x_dims_array,
+            sizeof(int) * max_dim,
+            hipMemcpyHostToDevice);
+#endif
+  DenseTensor y_dims_array_gpu;
+  y_dims_array_gpu.Resize({max_dim});
+  int* y_dims_array_gpu_data = dev_ctx.template Alloc<int>(&y_dims_array_gpu);
+#if defined(__NVCC__)
+  cudaMemcpy(y_dims_array_gpu_data,
+             y_dims_array,
+             sizeof(int) * max_dim,
+             cudaMemcpyHostToDevice);
+#else
+  hipMemcpy(y_dims_array_gpu_data,
+            y_dims_array,
+            sizeof(int) * max_dim,
+            hipMemcpyHostToDevice);
+#endif
+  DenseTensor out_dims_array_gpu;
+  out_dims_array_gpu.Resize({max_dim});
+  int* out_dims_array_gpu_data =
+      dev_ctx.template Alloc<int>(&out_dims_array_gpu);
+#if defined(__NVCC__)
+  cudaMemcpy(out_dims_array_gpu_data,
+             out_dims_array,
+             sizeof(int) * max_dim,
+             cudaMemcpyHostToDevice);
+#else
+  hipMemcpy(out_dims_array_gpu_data,
+            out_dims_array,
+            sizeof(int) * max_dim,
+            hipMemcpyHostToDevice);
+#endif
+  int block = 512;
+  int64_t grid = (out_numel + block - 1) / block;
+  auto stream = reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
+  ComputeDDoutWithBroadcastGPUKernel<T, DDout_OP>
+      <<<grid, block, 0, stream>>>(ddx_data,
+                                   ddy_data,
+                                   y_data,
+                                   out_data,
+                                   ddout_data,
+                                   out_numel,
+                                   x_dims_array_gpu_data,
+                                   y_dims_array_gpu_data,
+                                   out_dims_array_gpu_data,
+                                   max_dim,
+                                   dout_op);
+}
+
+#endif
+
+template <typename DeviceContext, typename DDout_OP, typename T>
+void DivDoubleDDoutCompute(const DeviceContext& dev_ctx,
+                           const phi::DenseTensor& ddx,
+                           const phi::DenseTensor& ddy,
+                           const phi::DenseTensor& y,
+                           const phi::DenseTensor& out,
+                           int axis,
+                           phi::DenseTensor* ddout,
+                           DDout_OP dout_op) {
+  auto x_dims = ddx.dims();
+  auto y_dims = ddy.dims();
+  if (x_dims == y_dims) {
+    ComputeDDoutWithoutBroadcast<T, DDout_OP>(
+        dev_ctx, ddx, ddy, y, out, ddout, dout_op);
+  } else {
+    int max_dim = std::max(x_dims.size(), y_dims.size());
+    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
+    std::vector<int> x_dims_array(max_dim, 0);
+    std::vector<int> y_dims_array(max_dim, 0);
+    std::vector<int> out_dims_array(max_dim, 0);
+    phi::funcs::GetBroadcastDimsArrays(x_dims,
+                                       y_dims,
+                                       x_dims_array.data(),
+                                       y_dims_array.data(),
+                                       out_dims_array.data(),
+                                       max_dim,
+                                       axis);
+    ComputeDDoutWithBroadcast<T, DDout_OP>(dev_ctx,
+                                           ddx,
+                                           ddy,
+                                           y,
+                                           out,
+                                           ddout,
+                                           x_dims_array.data(),
+                                           y_dims_array.data(),
+                                           out_dims_array.data(),
+                                           max_dim,
+                                           dout_op);
+  }
+}
+
 template <typename T, typename Context>
 void DivideDoubleGradKernel(const Context& dev_ctx,
                             const DenseTensor& y,
                             const DenseTensor& out,
-                            const DenseTensor& dx,
+                            const DenseTensor& grad_out,
+                            const paddle::optional<DenseTensor>& dx,
                             const paddle::optional<DenseTensor>& ddx,
                             const paddle::optional<DenseTensor>& ddy,
                             int axis,
                             DenseTensor* dy,
                             DenseTensor* dout,
                             DenseTensor* ddout) {
-  if (dy) {
-    dy->Resize(y.dims());
-    dev_ctx.template Alloc<T>(dy);
-  }
-  if (dout) {
-    dout->Resize(out.dims());
-    dev_ctx.template Alloc<T>(dout);
-  }
-  if (ddout) {
-    ddout->Resize(out.dims());
-    dev_ctx.template Alloc<T>(ddout);
+  auto* ddx_tensor = ddx.get_ptr();
+  auto* ddy_tensor = ddy.get_ptr();
+  auto* dx_tensor = dx.get_ptr();
+  DenseTensor dz_div_y;
+  if ((dy || dout) && (!dx_tensor || dx_tensor->dims() != out.dims())) {
+    dz_div_y.Resize(out.dims());
+    dev_ctx.template Alloc<T>(&dz_div_y);
+    funcs::DefaultElementwiseOperator<Context,
+                                      T,
+                                      funcs::DivideFunctor<T>,
+                                      funcs::InverseDivideFunctor<T>>(
+        dev_ctx, grad_out, y, &dz_div_y, axis);
+    dx_tensor = &dz_div_y;
   }
-  // ddX_safe == null ? 0 : ddX
-  // ddY_safe == null ? 0 : ddY
-  DenseTensor ddX_safe, ddY_safe;
-  phi::funcs::GetDoubleGradSafeTensor<Context, T>(
-      dev_ctx, dx, ddx.get_ptr(), &ddX_safe);
-  phi::funcs::GetDoubleGradSafeTensor<Context, T>(
-      dev_ctx, y, ddy.get_ptr(), &ddY_safe);
-
   // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
   // dY = Out * dX * ddY / Y - dX * ddX / Y
   // dOut = - dX * ddY
@@ -200,69 +484,169 @@ void DivideDoubleGradKernel(const Context& dev_ctx,
   // inplace ddx
   DenseTensor tmp;
   if (dout) {
+    dout->Resize(out.dims());
+    dev_ctx.template Alloc<T>(dout);
     tmp = *dout;
   } else {
     tmp.Resize(out.dims());
     dev_ctx.template Alloc<T>(&tmp);
   }
   if (dy) {
-    // dX_div_Y = dX / Y;
-    DenseTensor dX_div_Y = tmp;
-    funcs::DefaultElementwiseOperator<Context,
-                                      T,
-                                      funcs::DivideFunctor<T>,
-                                      funcs::InverseDivideFunctor<T>>(
-        dev_ctx, dx, y, &dX_div_Y, axis);
-
-    // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the
-    // first output tensor is nullptr, the branch to calculate first
-    // output tensor will not be activated, DivGradDx function will not
-    // be called and can be ignored, the first branch has little effect
-    // on running speed.
+    dy->Resize(y.dims());
+    dev_ctx.template Alloc<T>(dy);
+    if (!ddx_tensor && !ddy_tensor) {
+      FullLikeKernel<T, Context>(
+          dev_ctx, y, Scalar(static_cast<T>(0.0)), y.dtype(), dy);
+    } else {
+      // pre-compute 'dX / Y' into 'tmp' for 'ddout' and/or 'dy'
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::DivideFunctor<T>,
+                                        funcs::InverseDivideFunctor<T>>(
+          dev_ctx, *dx_tensor, y, &tmp, axis);
+      if (ddx_tensor && !ddy_tensor) {
+        // dY = -dX * ddX / Y
+        phi::funcs::ElemwiseGradCompute<Context,
+                                        T,
+                                        DivGradDX<T>,
+                                        DivDoubleDY_Only_DDX<T>>(
+            dev_ctx,
+            *ddx_tensor,  // ddx
+            y,
+            out,  // out
+            tmp,  // dX / Y
+            axis,
+            nullptr,
+            dy,
+            DivGradDX<T>(),
+            DivDoubleDY_Only_DDX<T>());
+      } else if (!ddx_tensor && ddy_tensor) {
+        // dY = Out * dX * ddY / Y
+        phi::funcs::ElemwiseGradCompute<Context,
+                                        T,
+                                        DivGradDX<T>,
+                                        DivDoubleDY_Only_DDY<T>>(
+            dev_ctx,
+            *dx_tensor,
+            *ddy_tensor,  // ddy
+            out,          // out
+            tmp,          // dX / Y
+            axis,
+            nullptr,
+            dy,
+            DivGradDX<T>(),
+            DivDoubleDY_Only_DDY<T>());
+      } else {
+        // dY = Out * dX * ddY / Y - dX * ddX / Y
-    // dY = Out * dX * ddY / Y - dX * ddX / Y
-    phi::funcs::ElemwiseGradCompute<Context, T, DivGradDX<T>, DivDoubleDY<T>>(
-        dev_ctx,
-        ddX_safe,
-        ddY_safe,
-        out,
-        dX_div_Y,
-        axis,
-        nullptr,
-        dy,
-        DivGradDX<T>(),
-        DivDoubleDY<T>());
+        // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the
+        // first output tensor is nullptr, the branch to calculate first
+        // output tensor will not be activated, DivGradDx function will not
+        // be called and can be ignored, the first branch has little effect
+        // on running speed.
+        phi::funcs::
+            ElemwiseGradCompute<Context, T, DivGradDX<T>, DivDoubleDY<T>>(
+                dev_ctx,
+                *ddx_tensor,  // ddx
+                *ddy_tensor,  // ddy
+                out,          // out
+                tmp,          // dX / Y
+                axis,
+                nullptr,
+                dy,
+                DivGradDX<T>(),
+                DivDoubleDY<T>());
+      }
+    }
   }
   if (ddout) {
+    ddout->Resize(out.dims());
+    dev_ctx.template Alloc<T>(ddout);
     // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
-    funcs::DefaultElementwiseOperator<Context,
-                                      T,
-                                      funcs::MultiplyFunctor<T>,
-                                      funcs::InverseMultiplyFunctor<T>>(
-        dev_ctx, out, ddY_safe, &tmp, axis);
-    funcs::DefaultElementwiseOperator<Context,
-                                      T,
-                                      funcs::SubtractFunctor<T>,
-                                      funcs::InverseSubtractFunctor<T>>(
-        dev_ctx, ddX_safe, tmp, &tmp, axis);
-    funcs::DefaultElementwiseOperator<Context,
-                                      T,
-                                      funcs::DivideFunctor<T>,
-                                      funcs::InverseDivideFunctor<T>>(
-        dev_ctx, tmp, y, ddout, axis);
+    if (!ddx_tensor && !ddy_tensor) {
+      FullLikeKernel<T, Context>(
+          dev_ctx, out, Scalar(static_cast<T>(0.0)), out.dtype(), ddout);
+    } else if (ddx_tensor != nullptr && ddy_tensor == nullptr) {
+      // ddOut = ddX / Y
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::DivideFunctor<T>,
+                                        funcs::InverseDivideFunctor<T>>(
+          dev_ctx, *ddx_tensor, y, ddout, axis);
+    } else if (!ddx_tensor && ddy_tensor) {
+// ddOut = - Out * ddY / Y
+#if defined(__xpu__)
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::MultiplyFunctor<T>,
+                                        funcs::InverseMultiplyFunctor<T>>(
+          dev_ctx, out, *ddy_tensor, &tmp, axis);
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::DivideFunctor<T>,
+                                        funcs::InverseDivideFunctor<T>>(
+          dev_ctx, tmp, y, ddout, axis);
+      auto& place = *dev_ctx.eigen_device();
+      auto ddout_result = phi::EigenVector<T>::Flatten(*ddout);
+      ddout_result.device(place) = static_cast<T>(-1) * ddout_result;
+#else
+      DivDoubleDDoutCompute<Context, DivDoubleDDOut_Only_DDY<T>, T>(
+          dev_ctx,
+          *dx_tensor,
+          *ddy_tensor,
+          y,
+          out,
+          axis,
+          ddout,
+          DivDoubleDDOut_Only_DDY<T>());
+#endif
+    } else {
+#if defined(__xpu__)
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::MultiplyFunctor<T>,
+                                        funcs::InverseMultiplyFunctor<T>>(
+          dev_ctx, out, *ddy_tensor, &tmp, axis);
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::SubtractFunctor<T>,
+                                        funcs::InverseSubtractFunctor<T>>(
+          dev_ctx, *ddx_tensor, tmp, &tmp, axis);
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::DivideFunctor<T>,
+                                        funcs::InverseDivideFunctor<T>>(
+          dev_ctx, tmp, y, ddout, axis);
+#else
+      DivDoubleDDoutCompute<Context, DivDoubleDDOut<T>, T>(
+          dev_ctx,
+          *ddx_tensor,
+          *ddy_tensor,
+          y,
+          out,
+          axis,
+          ddout,
+          DivDoubleDDOut<T>());
+#endif
+    }
   }
   if (dout) {
-    // dOut = - dX * ddY
-    funcs::DefaultElementwiseOperator<Context,
-                                      T,
-                                      funcs::MultiplyFunctor<T>,
-                                      funcs::InverseMultiplyFunctor<T>>(
-        dev_ctx, dx, ddY_safe, dout, axis);
-    auto& place = *dev_ctx.eigen_device();
-    auto dout_result = phi::EigenVector<T>::Flatten(*dout);
-    dout_result.device(place) = static_cast<T>(-1) * dout_result;
+    if (!ddy_tensor) {
+      FullLikeKernel<T, Context>(
+          dev_ctx, out, Scalar(static_cast<T>(0.0)), out.dtype(), dout);
+    } else {
+      // dOut = - dX * ddY
+      funcs::DefaultElementwiseOperator<Context,
+                                        T,
+                                        funcs::MultiplyFunctor<T>,
+                                        funcs::InverseMultiplyFunctor<T>>(
+          dev_ctx, *dx_tensor, *ddy_tensor, dout, axis);
+      auto& place = *dev_ctx.eigen_device();
+      auto dout_result = phi::EigenVector<T>::Flatten(*dout);
+      dout_result.device(place) = static_cast<T>(-1) * dout_result;
+    }
   }
 }
 
 template <typename T, typename Context>
diff --git a/test/cpp/fluid/elementwise/test_elementwise_div_grad_grad.cc b/test/cpp/fluid/elementwise/test_elementwise_div_grad_grad.cc
index ddf1229cd0367..a29cc2ea43f7c 100644
--- a/test/cpp/fluid/elementwise/test_elementwise_div_grad_grad.cc
+++ b/test/cpp/fluid/elementwise/test_elementwise_div_grad_grad.cc
@@ -41,16 +41,16 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-class TestElementwiseDivGradGradWithoutDout
-    : public TestElementwiseOpGradGrad<T> {
+class TestElementwiseDivGradGradWithDout : public TestElementwiseOpGradGrad<T> {
  public:
-  TestElementwiseDivGradGradWithoutDout(const platform::Place &place,
-                                        const framework::DDim &dims)
-      : TestElementwiseOpGradGrad<T>("elementwise_div_grad_grad",
-                                     place,
-                                     dims,
"Out", "DDX", "DDY", "DX"}, - {"Y@GRAD", "DDOut"}) {} + TestElementwiseDivGradGradWithDout(const platform::Place &place, + const framework::DDim &dims) + : TestElementwiseOpGradGrad( + "elementwise_div_grad_grad", + place, + dims, + {"Y", "Out", "Out@GRAD", "DDX", "DDY", "DX"}, + {"Y@GRAD", "DDOut", "DOut"}) {} using TestElementwiseOpGradGrad::feed_datas_; using TestElementwiseOpGradGrad::expected_outs_; @@ -59,6 +59,7 @@ class TestElementwiseDivGradGradWithoutDout size_t numel = static_cast(common::product(dims_)); std::vector dy(numel); std::vector ddout(numel); + std::vector dout(numel); for (size_t i = 0; i < numel; ++i) { // dY(Y@GRAD) = Out * dX * ddY / Y - dX * ddX / Y dy[i] = (feed_datas_["DX"][i] / feed_datas_["Y"][i]) * @@ -68,9 +69,12 @@ class TestElementwiseDivGradGradWithoutDout ddout[i] = (feed_datas_["DDX"][i] - feed_datas_["Out"][i] * feed_datas_["DDY"][i]) / (feed_datas_["Y"][i]); + // dOut = - DX * DDy + dout[i] = -feed_datas_["DX"][i] * feed_datas_["DDY"][i]; } expected_outs_["Y@GRAD"] = dy; expected_outs_["DDOut"] = ddout; + expected_outs_["DOut"] = dout; } std::unique_ptr CreateTestOp() override { @@ -78,27 +82,28 @@ class TestElementwiseDivGradGradWithoutDout this->op_type_, {{"Y", {"Y"}}, {"Out", {"Out"}}, + {"Out@GRAD", {"Out@GRAD"}}, {"DDX", {"DDX"}}, {"DDY", {"DDY"}}, {"DX", {"DX"}}}, - {{"Y@GRAD", {"Y@GRAD"}}, {"DDOut", {"DDOut"}}}, + {{"Y@GRAD", {"Y@GRAD"}}, {"DDOut", {"DDOut"}}, {"DOut", {"DOut"}}}, {{"use_mkldnn", false}, {"axis", 0}}); return op; } }; -TEST(test_elementwise_div_grad_grad_without_dout, cpu_place) { +TEST(test_elementwise_div_grad_grad, cpu_place) { framework::DDim dims({32, 64}); platform::CPUPlace p; - TestElementwiseDivGradGradWithoutDout test(p, dims); + TestElementwiseDivGradGradWithDout test(p, dims); ASSERT_TRUE(test.Check()); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -TEST(test_elementwise_div_grad_grad_without_dout, gpu_place) { +TEST(test_elementwise_div_grad_grad, gpu_place) { framework::DDim dims({32, 64}); platform::CUDAPlace p(0); - TestElementwiseDivGradGradWithoutDout test(p, dims); + TestElementwiseDivGradGradWithDout test(p, dims); ASSERT_TRUE(test.Check()); } #endif diff --git a/test/cpp/fluid/elementwise/test_elementwise_op_grad_grad.h b/test/cpp/fluid/elementwise/test_elementwise_op_grad_grad.h index ab67c559532d9..3e772aa632e52 100644 --- a/test/cpp/fluid/elementwise/test_elementwise_op_grad_grad.h +++ b/test/cpp/fluid/elementwise/test_elementwise_op_grad_grad.h @@ -135,8 +135,18 @@ class TestElementwiseOpGradGrad { expected_outs_[out_name].data(), [](const float &l, const float &r) { return fabs(l - r) < 1e-8; }); #else - auto is_equal = - std::equal(out_ptr, out_ptr + numel, expected_outs_[out_name].data()); + bool is_equal; + if (op_type_ == "elementwise_div_grad_grad") { + is_equal = std::equal(out_ptr, + out_ptr + numel, + expected_outs_[out_name].data(), + [](const float &l, const float &r) { + return fabs(l - r) < 0.0005; + }); + } else { + is_equal = std::equal( + out_ptr, out_ptr + numel, expected_outs_[out_name].data()); + } #endif if (!is_equal) { all_equal = false;