improve the performance of divide_double_grad #62533
@@ -166,33 +166,28 @@ template <typename T, typename Context>
void DivideDoubleGradKernel(const Context& dev_ctx,
                            const DenseTensor& y,
                            const DenseTensor& out,
                            const DenseTensor& dx,
                            const DenseTensor& grad_out,
                            const paddle::optional<DenseTensor>& dx,
                            const paddle::optional<DenseTensor>& ddx,
                            const paddle::optional<DenseTensor>& ddy,
                            int axis,
                            DenseTensor* dy,
                            DenseTensor* dout,
                            DenseTensor* ddout) {
  if (dy) {
    dy->Resize(y.dims());
    dev_ctx.template Alloc<T>(dy);
  }
  if (dout) {
    dout->Resize(out.dims());
    dev_ctx.template Alloc<T>(dout);
  }
  if (ddout) {
    ddout->Resize(out.dims());
    dev_ctx.template Alloc<T>(ddout);
  auto* ddx_tensor = ddx.get_ptr();
  auto* ddy_tensor = ddy.get_ptr();
  auto* dx_tensor = dx.get_ptr();
  DenseTensor dz_div_y;
  dz_div_y.Resize(out.dims());
  if (!dx_tensor || dx_tensor->dims() != out.dims()) {
    dev_ctx.template Alloc<T>(&dz_div_y);
    funcs::DefaultElementwiseOperator<Context,
                                      T,
                                      funcs::DivideFunctor<T>,
                                      funcs::InverseDivideFunctor<T>>(
        dev_ctx, grad_out, y, &dz_div_y, axis);
    dx_tensor = &dz_div_y;
  }
Review comment: Since …, this keeps unnecessary computation to a minimum.
Reply: Done
  // ddX_safe == null ? 0 : ddX
  // ddY_safe == null ? 0 : ddY
  DenseTensor ddX_safe, ddY_safe;
  phi::funcs::GetDoubleGradSafeTensor<Context, T>(
      dev_ctx, dx, ddx.get_ptr(), &ddX_safe);
  phi::funcs::GetDoubleGradSafeTensor<Context, T>(
      dev_ctx, y, ddy.get_ptr(), &ddY_safe);
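For context, the old code above materializes zero-filled tensors whenever ddX or ddY is absent and then runs full elementwise arithmetic on them. A minimal standalone sketch of that "safe tensor" fallback, illustrative only: GetSafeOrZeros is a hypothetical name, and real DenseTensors are simplified to std::vector.

#include <cstddef>
#include <vector>

// Hypothetical stand-in for the safe-tensor idea: return the grad if it
// exists, otherwise a zero-filled buffer of the same length.
template <typename T>
std::vector<T> GetSafeOrZeros(const std::vector<T>* maybe_grad, std::size_t n) {
  if (maybe_grad != nullptr) {
    return *maybe_grad;
  }
  return std::vector<T>(n, static_cast<T>(0));  // null -> zeros
}

The new code below avoids this by branching on whether ddX / ddY are present, so no zero tensors are allocated and no elementwise work is spent multiplying by zero.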
  // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
  // dY = Out * dX * ddY / Y - dX * ddX / Y
  // dOut = - dX * ddY
@@ -206,63 +201,127 @@ void DivideDoubleGradKernel(const Context& dev_ctx,
    dev_ctx.template Alloc<T>(&tmp);
  }
  if (dy) {
    // dX_div_Y = dX / Y;
    DenseTensor dX_div_Y = tmp;
    funcs::DefaultElementwiseOperator<Context,
                                      T,
                                      funcs::DivideFunctor<T>,
                                      funcs::InverseDivideFunctor<T>>(
        dev_ctx, dx, y, &dX_div_Y, axis);

    // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the
    // first output tensor is nullptr, the branch to calculate first
    // output tensor will not be activated, DivGradDx function will not
    // be called and can be ignored, the first branch has little effect
    // on running speed.
    dy->Resize(y.dims());
    dev_ctx.template Alloc<T>(dy);
    if (!ddx_tensor && !ddy_tensor) {
      FullLikeKernel<T, Context>(dev_ctx, y, Scalar(0.0), y.dtype(), dy);
Review comment: The third argument of the constructor …
Reply: Done
    } else {
      DenseTensor tmp_dy = tmp;
Review comment: Line 209 can probably be removed; tmp_dy is not a meaningful name anyway, so it would be better to use tmp directly. After removing it, change every tmp_dy below to tmp.
Reply: Done
      // dX / Y
      funcs::DefaultElementwiseOperator<Context,
                                        T,
                                        funcs::DivideFunctor<T>,
                                        funcs::InverseDivideFunctor<T>>(
          dev_ctx, *dx_tensor, y, &tmp_dy, axis);
      if (ddx_tensor && !ddy_tensor) {
        // dy = -dX * ddX / Y
        funcs::DefaultElementwiseOperator<Context,
                                          T,
                                          funcs::MultiplyFunctor<T>,
                                          funcs::InverseMultiplyFunctor<T>>(
            dev_ctx, *ddx_tensor, tmp_dy, dy, axis);
        auto& place = *dev_ctx.eigen_device();
        auto dy_result = phi::EigenVector<T>::Flatten(*dy);
        dy_result.device(place) = static_cast<T>(-1) * dy_result;
      } else if (!ddx_tensor && ddy_tensor) {
        // dY = Out * dX * ddY / Y
        funcs::DefaultElementwiseOperator<Context,
                                          T,
                                          funcs::MultiplyFunctor<T>,
                                          funcs::InverseMultiplyFunctor<T>>(
            dev_ctx, *ddy_tensor, tmp_dy, &tmp_dy, axis);
        funcs::DefaultElementwiseOperator<Context,
                                          T,
                                          funcs::MultiplyFunctor<T>,
                                          funcs::InverseMultiplyFunctor<T>>(
            dev_ctx, out, tmp_dy, dy, axis);
      } else {
        // dY = Out * dX * ddY / Y - dX * ddX / Y
        // dY = Out * dX * ddY / Y - dX * ddX / Y
        phi::funcs::ElemwiseGradCompute<Context, T, DivGradDX<T>, DivDoubleDY<T>>(
            dev_ctx,
            ddX_safe,
            ddY_safe,
            out,
            dX_div_Y,
            axis,
            nullptr,
            dy,
            DivGradDX<T>(),
            DivDoubleDY<T>());
        // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the
        // first output tensor is nullptr, the branch to calculate first
        // output tensor will not be activated, DivGradDx function will not
        // be called and can be ignored, the first branch has little effect
        // on running speed.
        phi::funcs::
            ElemwiseGradCompute<Context, T, DivGradDX<T>, DivDoubleDY<T>>(
                dev_ctx,
                *ddx_tensor,
                *ddy_tensor,
                out,
                tmp_dy,
                axis,
                nullptr,
                dy,
                DivGradDX<T>(),
                DivDoubleDY<T>());
      }
    }
Review comment: The logic here looks like it first computes the shared term at the top: …
Reply: Done
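For reference, the branches above all reuse the shared term t = dX / Y. Elementwise, they amount to the following; this is an illustrative helper, not the kernel's real code, and it ignores broadcasting.

// Per-element value of dY, with t = dX / Y computed once and reused.
// ddx / ddy are passed as pointers so "absent" can be modeled as nullptr.
template <typename T>
T DivDoubleGradDYElement(const T* ddx, const T* ddy, T out, T t) {
  if (!ddx && !ddy) return static_cast<T>(0);  // dY = 0
  if (ddx && !ddy) return -(*ddx) * t;         // dY = -dX * ddX / Y
  if (!ddx && ddy) return out * (*ddy) * t;    // dY = Out * dX * ddY / Y
  return (out * (*ddy) - (*ddx)) * t;          // both terms
}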
  }

  if (ddout) {
    ddout->Resize(out.dims());
    dev_ctx.template Alloc<T>(ddout);
    // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
    funcs::DefaultElementwiseOperator<Context,
                                      T,
                                      funcs::MultiplyFunctor<T>,
                                      funcs::InverseMultiplyFunctor<T>>(
        dev_ctx, out, ddY_safe, &tmp, axis);
    funcs::DefaultElementwiseOperator<Context,
                                      T,
                                      funcs::SubtractFunctor<T>,
                                      funcs::InverseSubtractFunctor<T>>(
        dev_ctx, ddX_safe, tmp, &tmp, axis);
    funcs::DefaultElementwiseOperator<Context,
                                      T,
                                      funcs::DivideFunctor<T>,
                                      funcs::InverseDivideFunctor<T>>(
        dev_ctx, tmp, y, ddout, axis);
    if (!ddx_tensor && !ddy_tensor) {
      FullLikeKernel<T, Context>(dev_ctx, out, Scalar(0.0), out.dtype(), ddout);
    } else if (ddx_tensor != nullptr && ddy_tensor == nullptr) {
      // ddOut = ddX / Y
      funcs::DefaultElementwiseOperator<Context,
                                        T,
                                        funcs::DivideFunctor<T>,
                                        funcs::InverseDivideFunctor<T>>(
          dev_ctx, *ddx_tensor, y, ddout, axis);
    } else if (!ddx_tensor && ddy_tensor) {
      // ddOut = - Out * ddY / Y
      funcs::DefaultElementwiseOperator<Context,
                                        T,
                                        funcs::MultiplyFunctor<T>,
                                        funcs::InverseMultiplyFunctor<T>>(
          dev_ctx, out, *ddy_tensor, &tmp, axis);
      funcs::DefaultElementwiseOperator<Context,
                                        T,
                                        funcs::DivideFunctor<T>,
                                        funcs::InverseDivideFunctor<T>>(
          dev_ctx, tmp, y, ddout, axis);
      auto& place = *dev_ctx.eigen_device();
      auto ddout_result = phi::EigenVector<T>::Flatten(*ddout);
      ddout_result.device(place) = static_cast<T>(-1) * ddout_result;
    } else {
      funcs::DefaultElementwiseOperator<Context,
                                        T,
                                        funcs::MultiplyFunctor<T>,
                                        funcs::InverseMultiplyFunctor<T>>(
          dev_ctx, out, *ddy_tensor, &tmp, axis);
      funcs::DefaultElementwiseOperator<Context,
                                        T,
                                        funcs::SubtractFunctor<T>,
                                        funcs::InverseSubtractFunctor<T>>(
          dev_ctx, *ddx_tensor, tmp, &tmp, axis);
      funcs::DefaultElementwiseOperator<Context,
                                        T,
                                        funcs::DivideFunctor<T>,
                                        funcs::InverseDivideFunctor<T>>(
          dev_ctx, tmp, y, ddout, axis);
Review comment on lines +606 to +620: These three calls can be merged into one.
Reply: Done
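As an illustration of that suggestion, the three elementwise passes above (multiply, subtract, divide) compute ddOut = (ddX - Out * ddY) / Y and could in principle run as a single pass. A minimal standalone sketch, not the PR's actual implementation, assuming equal shapes and no broadcasting; the function name is hypothetical.

#include <cstddef>
#include <vector>

// One fused read/write pass instead of three separate elementwise kernels:
// tmp = Out * ddY; tmp = ddX - tmp; ddOut = tmp / Y.
template <typename T>
void FusedDivDoubleGradDDOut(const std::vector<T>& ddx,
                             const std::vector<T>& ddy,
                             const std::vector<T>& out,
                             const std::vector<T>& y,
                             std::vector<T>* ddout) {
  ddout->resize(y.size());
  for (std::size_t i = 0; i < y.size(); ++i) {
    (*ddout)[i] = (ddx[i] - out[i] * ddy[i]) / y[i];
  }
}

The same idea would apply to the other branches, for example the "- Out * ddY / Y" case, where the separate Eigen negation pass could also be folded into the same loop.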
    }
Review comment: Same as above, the multiple …
  }
  if (dout) {
    // dOut = - dX * ddY
    funcs::DefaultElementwiseOperator<Context,
                                      T,
                                      funcs::MultiplyFunctor<T>,
                                      funcs::InverseMultiplyFunctor<T>>(
        dev_ctx, dx, ddY_safe, dout, axis);
    auto& place = *dev_ctx.eigen_device();
    auto dout_result = phi::EigenVector<T>::Flatten(*dout);
    dout_result.device(place) = static_cast<T>(-1) * dout_result;
    dout->Resize(out.dims());
    dev_ctx.template Alloc<T>(dout);
    if (!ddy_tensor) {
      FullLikeKernel<T, Context>(dev_ctx, out, Scalar(0.0), out.dtype(), dout);
    } else {
      // dOut = - dX * ddY
      funcs::DefaultElementwiseOperator<Context,
                                        T,
                                        funcs::MultiplyFunctor<T>,
                                        funcs::InverseMultiplyFunctor<T>>(
          dev_ctx, *dx_tensor, *ddy_tensor, dout, axis);
      auto& place = *dev_ctx.eigen_device();
      auto dout_result = phi::EigenVector<T>::Flatten(*dout);
      dout_result.device(place) = static_cast<T>(-1) * dout_result;
    }
  }
}
Review comment: Add a blank line between lines 326 and 327.
template <typename T, typename Context>
Review comment: Inside the DivDoubleDY functor, dout can be factored out to reduce the number of multiplications.
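A hedged sketch of that refactor, assuming the existing functor computes something of the form y * out * dout - x * dout (with x = ddX, y = ddY, and dout holding dX / Y); the name DivDoubleDY_Factored and the HOSTDEVICE fallback are illustrative, not the PR's final code.

#ifndef HOSTDEVICE
#define HOSTDEVICE  // no-op fallback so the sketch compiles outside Paddle
#endif

template <typename T>
struct DivDoubleDY_Factored {
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
    // (y * out - x) * dout == y * out * dout - x * dout,
    // saving one multiplication per element.
    return (y * out - x) * dout;
  }
};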