improve the performance of divide_double_grad #62533

Merged
2 changes: 1 addition & 1 deletion paddle/fluid/operators/ops_signature/elementwise_sig.cc
@@ -168,7 +168,7 @@ KernelSignature ElementwiseDivGradOpArgumentMapping(
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("divide_double_grad",
{"Y", "Out", "DX", "DDX", "DDY"},
{"Y", "Out", "Out@GRAD", "DX", "DDX", "DDY"},
{"axis"},
{"Y@GRAD", "DOut", "DDOut"});
}
6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -190,15 +190,15 @@

- backward_op : divide_double_grad
forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
args : (Tensor y, Tensor out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
args : (Tensor y, Tensor out, Tensor grad_out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
output : Tensor(y_grad), Tensor(out_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [y, grad_x, grad_x]
param : [y, out, out]
kernel :
func : divide_double_grad
data_type : out
optional : grad_x_grad, grad_y_grad
optional : grad_x, grad_x_grad, grad_y_grad
inplace : (grad_x_grad -> grad_out_grad)

- backward_op : divide_grad
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -175,15 +175,15 @@

- backward_op : divide_double_grad
forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
args : (Tensor y, Tensor out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
args : (Tensor y, Tensor out, Tensor grad_out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
output : Tensor(y_grad), Tensor(out_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [y, grad_x, grad_x]
param : [y, out, out]
kernel :
func : divide_double_grad
data_type : out
optional : grad_x_grad, grad_y_grad
optional : grad_x, grad_x_grad, grad_y_grad
inplace : (grad_x_grad -> grad_out_grad)

- backward_op : divide_grad
3 changes: 2 additions & 1 deletion paddle/phi/kernels/elementwise_divide_grad_kernel.h
@@ -33,7 +33,8 @@ template <typename T, typename Context>
void DivideDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dx,
const DenseTensor& grad_out,
const paddle::optional<DenseTensor>& dx,
const paddle::optional<DenseTensor>& ddx,
const paddle::optional<DenseTensor>& ddy,
int axis,
197 changes: 128 additions & 69 deletions paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
@@ -166,33 +166,28 @@ template <typename T, typename Context>
void DivideDoubleGradKernel(const Context& dev_ctx,
Contributor: Inside the DivDoubleDY functor, dout can be factored out to reduce the number of multiplications (see the sketch just after this signature).

const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dx,
const DenseTensor& grad_out,
const paddle::optional<DenseTensor>& dx,
const paddle::optional<DenseTensor>& ddx,
const paddle::optional<DenseTensor>& ddy,
int axis,
DenseTensor* dy,
DenseTensor* dout,
DenseTensor* ddout) {
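
A minimal sketch of the reviewer's suggestion above, assuming the current DivDoubleDY functor computes y * out * dout - x * dout (the functor body itself is not shown in this diff); factoring dout out saves one multiplication per element:

template <typename T>
struct DivDoubleDY {
  // x = ddX, y = ddY, out = Out, dout = dX / Y (the pre-computed quotient).
  // dY = (ddY * Out - ddX) * (dX / Y): one multiplication by dout instead of two.
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
    return (y * out - x) * dout;
  }
};
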
if (dy) {
dy->Resize(y.dims());
dev_ctx.template Alloc<T>(dy);
}
if (dout) {
dout->Resize(out.dims());
dev_ctx.template Alloc<T>(dout);
}
if (ddout) {
ddout->Resize(out.dims());
dev_ctx.template Alloc<T>(ddout);
auto* ddx_tensor = ddx.get_ptr();
auto* ddy_tensor = ddy.get_ptr();
auto* dx_tensor = dx.get_ptr();
DenseTensor dz_div_y;
dz_div_y.Resize(out.dims());
if (!dx_tensor || dx_tensor->dims() != out.dims()) {
dev_ctx.template Alloc<T>(&dz_div_y);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::DivideFunctor<T>,
funcs::InverseDivideFunctor<T>>(
dev_ctx, grad_out, y, &dz_div_y, axis);
dx_tensor = &dz_div_y;
}
Contributor: Since dx (dz_div_y) is only used when computing dy and dout, I think we can:

1. Change the if condition at line 182 to: if ((dy || dout) && (!dx_tensor || dx_tensor->dims() != out.dims()))
2. Keep the definition of dz_div_y where it is, but move dz_div_y.Resize(out.dims()); inside the if, because the intermediate dz_div_y is only needed when dx_tensor has to be computed.

That way we avoid as much unnecessary computation as possible.

Contributor Author: Done
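
A sketch of that suggestion, using the names from the diff above; only the condition and the placement of Resize/Alloc change, the division itself is the call already shown:

auto* dx_tensor = dx.get_ptr();
DenseTensor dz_div_y;
// Materialize dz_div_y (= grad_out / y) only when it will actually be used,
// i.e. when dy or dout is requested and dx is absent or mis-shaped.
if ((dy || dout) && (!dx_tensor || dx_tensor->dims() != out.dims())) {
  dz_div_y.Resize(out.dims());
  dev_ctx.template Alloc<T>(&dz_div_y);
  funcs::DefaultElementwiseOperator<Context,
                                    T,
                                    funcs::DivideFunctor<T>,
                                    funcs::InverseDivideFunctor<T>>(
      dev_ctx, grad_out, y, &dz_div_y, axis);
  dx_tensor = &dz_div_y;
}
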

// ddX_safe == null ? 0 : ddX
// ddY_safe == null ? 0 : ddY
DenseTensor ddX_safe, ddY_safe;
phi::funcs::GetDoubleGradSafeTensor<Context, T>(
dev_ctx, dx, ddx.get_ptr(), &ddX_safe);
phi::funcs::GetDoubleGradSafeTensor<Context, T>(
dev_ctx, y, ddy.get_ptr(), &ddY_safe);

// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
// dY = Out * dX * ddY / Y - dX * ddX / Y
// dOut = - dX * ddY
@@ -206,63 +201,127 @@ void DivideDoubleGradKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(&tmp);
}
if (dy) {
// dX_div_Y = dX / Y;
DenseTensor dX_div_Y = tmp;
funcs::DefaultElementwiseOperator<Context,
T,
funcs::DivideFunctor<T>,
funcs::InverseDivideFunctor<T>>(
dev_ctx, dx, y, &dX_div_Y, axis);

// NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the
// first output tensor is nullptr, the branch to calculate first
// output tensor will not be activated, DivGradDx function will not
// be called and can be ignored, the first branch has little effect
// on running speed.
dy->Resize(y.dims());
dev_ctx.template Alloc<T>(dy);
if (!ddx_tensor && !ddy_tensor) {
FullLikeKernel<T, Context>(dev_ctx, y, Scalar(0.0), y.dtype(), dy);
Contributor: Could the third argument Scalar(0.0) be changed to Scalar(static_cast<T>(0.0))? Otherwise there may be problems when T is a complex type.

Contributor Author: Done
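
The suggested one-line change, sketched for the dy branch (the other FullLikeKernel calls below would be adjusted the same way):

FullLikeKernel<T, Context>(
    dev_ctx, y, Scalar(static_cast<T>(0.0)), y.dtype(), dy);
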

} else {
DenseTensor tmp_dy = tmp;
Contributor: Line 209 can probably be deleted; tmp_dy is no clearer semantically than tmp, so it would be better to use tmp directly and rename every tmp_dy below to tmp.

Contributor Author: Done

// dX / Y
Contributor (HydrogenSulfate, Mar 13, 2024): // dX / Y ==> // pre-compute 'dX / Y' into 'tmp' for 'ddout' and/or 'dy'

Contributor Author: Done

funcs::DefaultElementwiseOperator<Context,
T,
funcs::DivideFunctor<T>,
funcs::InverseDivideFunctor<T>>(
dev_ctx, *dx_tensor, y, &tmp_dy, axis);
if (ddx_tensor && !ddy_tensor) {
// dy = -dX * ddX / Y
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, *ddx_tensor, tmp_dy, dy, axis);
auto& place = *dev_ctx.eigen_device();
auto dy_result = phi::EigenVector<T>::Flatten(*dy);
dy_result.device(place) = static_cast<T>(-1) * dy_result;
} else if (!ddx_tensor && ddy_tensor) {
// dY = Out * dX * ddY / Y
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, *ddy_tensor, tmp_dy, &tmp_dy, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, out, tmp_dy, dy, axis);
} else {
// dY = Out * dX * ddY / Y - dX * ddX / Y

// dY = Out * dX * ddY / Y - dX * ddX / Y
phi::funcs::ElemwiseGradCompute<Context, T, DivGradDX<T>, DivDoubleDY<T>>(
dev_ctx,
ddX_safe,
ddY_safe,
out,
dX_div_Y,
axis,
nullptr,
dy,
DivGradDX<T>(),
DivDoubleDY<T>());
// NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the
// first output tensor is nullptr, the branch to calculate first
// output tensor will not be activated, DivGradDx function will not
// be called and can be ignored, the first branch has little effect
// on running speed.
phi::funcs::
ElemwiseGradCompute<Context, T, DivGradDX<T>, DivDoubleDY<T>>(
dev_ctx,
*ddx_tensor,
*ddy_tensor,
out,
tmp_dy,
axis,
nullptr,
dy,
DivGradDX<T>(),
DivDoubleDY<T>());
}
}
Contributor: The logic here first computes the common term dx / y up front and then takes different branches depending on the conditions, and some branches call DefaultElementwiseOperator more than once. Please check whether all of this can be unified to use ElemwiseGradCompute, so that each if-else branch finishes in a single call. That only requires writing a different dy_op for each condition (e.g. DivDoubleDY_Only_DDX, DivDoubleDY_Only_DDY) and passing it to ElemwiseGradCompute. The parameters these dy_ops actually use differ; for the unused ones, any DenseTensor of the same shape that can be safely accessed works as a placeholder. (A sketch of such functors follows this if (dy) block.)

Contributor Author: Done

}
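
A sketch of the single-dispatch structure described in the review comment above. The functor names come from that comment and are illustrative rather than taken from the final PR; the (x, y, out, dout) slots mirror the existing DivDoubleDY call, with dout carrying the pre-computed dX / Y and any unused slot fed a same-shaped placeholder tensor:

// dY = -ddX * dX / Y (only ddX present; the y slot is a placeholder)
template <typename T>
struct DivDoubleDY_Only_DDX {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -x * dout; }
};

// dY = Out * ddY * dX / Y (only ddY present; the x slot is a placeholder)
template <typename T>
struct DivDoubleDY_Only_DDY {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return y * out * dout; }
};

// Each branch then reduces to one call, e.g. for the ddX-only case:
// phi::funcs::ElemwiseGradCompute<Context, T, DivGradDX<T>, DivDoubleDY_Only_DDX<T>>(
//     dev_ctx, *ddx_tensor, placeholder, out, tmp, axis,
//     nullptr, dy, DivGradDX<T>(), DivDoubleDY_Only_DDX<T>());
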

if (ddout) {
ddout->Resize(out.dims());
dev_ctx.template Alloc<T>(ddout);
// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, out, ddY_safe, &tmp, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::SubtractFunctor<T>,
funcs::InverseSubtractFunctor<T>>(
dev_ctx, ddX_safe, tmp, &tmp, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::DivideFunctor<T>,
funcs::InverseDivideFunctor<T>>(
dev_ctx, tmp, y, ddout, axis);
if (!ddx_tensor && !ddy_tensor) {
FullLikeKernel<T, Context>(dev_ctx, out, Scalar(0.0), out.dtype(), ddout);
} else if (ddx_tensor != nullptr && ddy_tensor == nullptr) {
// ddOut = ddX / Y
funcs::DefaultElementwiseOperator<Context,
T,
funcs::DivideFunctor<T>,
funcs::InverseDivideFunctor<T>>(
dev_ctx, *ddx_tensor, y, ddout, axis);
} else if (!ddx_tensor && ddy_tensor) {
// ddOut = - Out * ddY / Y
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, out, *ddy_tensor, &tmp, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::DivideFunctor<T>,
funcs::InverseDivideFunctor<T>>(
dev_ctx, tmp, y, ddout, axis);
auto& place = *dev_ctx.eigen_device();
auto ddout_result = phi::EigenVector<T>::Flatten(*ddout);
ddout_result.device(place) = static_cast<T>(-1) * ddout_result;
} else {
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, out, *ddy_tensor, &tmp, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::SubtractFunctor<T>,
funcs::InverseSubtractFunctor<T>>(
dev_ctx, *ddx_tensor, tmp, &tmp, axis);
funcs::DefaultElementwiseOperator<Context,
T,
funcs::DivideFunctor<T>,
funcs::InverseDivideFunctor<T>>(
dev_ctx, tmp, y, ddout, axis);
Contributor (on lines +606 to +620): These three calls could be merged into a single one.

Contributor Author: Done

}
Contributor: Likewise, could the multiple DefaultElementwiseOperator calls here be fused into a single call? (A sketch of the fused form follows this if (ddout) block.)

}
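
On the two fusion comments above: the arithmetic of the three-call branch collapses to ddOut = (ddX - Out * ddY) / Y, so it can in principle be computed in one element-wise pass with a single functor in the style of the ones already in this file. The name DivDoubleDDOut is hypothetical, and routing it through one broadcast launch (for example the same ElemwiseGradCompute machinery used for dY) is an assumption, not something shown in this diff:

// ddOut = (ddX - Out * ddY) / Y in a single pass.
// Slot mapping for this sketch only: x = ddX, y = ddY, out = Out, dout = Y.
template <typename T>
struct DivDoubleDDOut {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
    return (x - out * y) / dout;
  }
};
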

if (dout) {
// dOut = - dX * ddY
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, dx, ddY_safe, dout, axis);
auto& place = *dev_ctx.eigen_device();
auto dout_result = phi::EigenVector<T>::Flatten(*dout);
dout_result.device(place) = static_cast<T>(-1) * dout_result;
dout->Resize(out.dims());
dev_ctx.template Alloc<T>(dout);
if (!ddy_tensor) {
FullLikeKernel<T, Context>(dev_ctx, out, Scalar(0.0), out.dtype(), dout);
} else {
// dOut = - dX * ddY
funcs::DefaultElementwiseOperator<Context,
T,
funcs::MultiplyFunctor<T>,
funcs::InverseMultiplyFunctor<T>>(
dev_ctx, *dx_tensor, *ddy_tensor, dout, axis);
auto& place = *dev_ctx.eigen_device();
auto dout_result = phi::EigenVector<T>::Flatten(*dout);
dout_result.device(place) = static_cast<T>(-1) * dout_result;
}
}
}
Contributor: Add a blank line between lines 326 and 327.

template <typename T, typename Context>
12 changes: 7 additions & 5 deletions test/cpp/fluid/elementwise/test_elementwise_div_grad_grad.cc
@@ -46,11 +46,12 @@ class TestElementwiseDivGradGradWithoutDout
public:
TestElementwiseDivGradGradWithoutDout(const platform::Place &place,
const framework::DDim &dims)
: TestElementwiseOpGradGrad<T>("elementwise_div_grad_grad",
place,
dims,
{"Y", "Out", "DDX", "DDY", "DX"},
{"Y@GRAD", "DDOut"}) {}
: TestElementwiseOpGradGrad<T>(
"elementwise_div_grad_grad",
place,
dims,
{"Y", "Out", "Out@GRAD", "DDX", "DDY", "DX"},
{"Y@GRAD", "DDOut"}) {}

using TestElementwiseOpGradGrad<T>::feed_datas_;
using TestElementwiseOpGradGrad<T>::expected_outs_;
@@ -78,6 +79,7 @@ class TestElementwiseDivGradGradWithoutDout
this->op_type_,
{{"Y", {"Y"}},
{"Out", {"Out"}},
{"Out@GRAD", {"Out@GRAD"}},
{"DDX", {"DDX"}},
{"DDY", {"DDY"}},
{"DX", {"DX"}}},