-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
improve the performance of divide_double_grad #62533
Changes from 1 commit
467d580
5563d5e
2683315
f28f02b
4cd0aa2
fd22057
ef8f29d
96c6cc6
a74a2fa
d22d848
6c40ba6
87350dc
0f8b0f4
99e5281
8016cf5
a718d8e
156198a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -177,6 +177,17 @@ void DivideDoubleGradKernel(const Context& dev_ctx, | |
auto* ddx_tensor = ddx.get_ptr(); | ||
auto* ddy_tensor = ddy.get_ptr(); | ||
auto* dx_tensor = dx.get_ptr(); | ||
DenseTensor dz_div_y; | ||
dz_div_y.Resize(out.dims()); | ||
if (dx_tensor == nullptr || dx_tensor->dims() != out.dims()) { | ||
dev_ctx.template Alloc<T>(&dz_div_y); | ||
funcs::DefaultElementwiseOperator<Context, | ||
T, | ||
funcs::DivideFunctor<T>, | ||
funcs::InverseDivideFunctor<T>>( | ||
dev_ctx, grad_out, y, &dz_div_y, axis); | ||
dx_tensor = &dz_div_y; | ||
} | ||
// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y | ||
// dY = Out * dX * ddY / Y - dX * ddX / Y | ||
// dOut = - dX * ddY | ||
|
@@ -195,17 +206,6 @@ void DivideDoubleGradKernel(const Context& dev_ctx, | |
if (ddx_tensor == nullptr && ddy_tensor == nullptr) { | ||
dy = nullptr; | ||
} else { | ||
if (dx_tensor == nullptr || dx_tensor->dims() != out.dims()) { | ||
DenseTensor dz_div_y; | ||
dz_div_y.Resize(out.dims()); | ||
dev_ctx.template Alloc<T>(&dz_div_y); | ||
funcs::DefaultElementwiseOperator<Context, | ||
T, | ||
funcs::DivideFunctor<T>, | ||
funcs::InverseDivideFunctor<T>>( | ||
dev_ctx, grad_out, y, &dz_div_y, axis); | ||
dx_tensor = &dz_div_y; | ||
} | ||
DenseTensor tmp_dy = tmp; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 209行感觉可以删掉吧,反正tmp_dy也是语义不明,还不如直接用tmp,删掉之后下面的tmp_dy全部改成tmp There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done |
||
// dX / Y | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done |
||
funcs::DefaultElementwiseOperator<Context, | ||
|
@@ -312,17 +312,6 @@ void DivideDoubleGradKernel(const Context& dev_ctx, | |
if (ddy_tensor == nullptr) { | ||
dout = nullptr; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done |
||
} else { | ||
if (dx_tensor == nullptr || dx_tensor->dims() != out.dims()) { | ||
DenseTensor dz_div_y; | ||
dz_div_y.Resize(out.dims()); | ||
dev_ctx.template Alloc<T>(&dz_div_y); | ||
funcs::DefaultElementwiseOperator<Context, | ||
T, | ||
funcs::DivideFunctor<T>, | ||
funcs::InverseDivideFunctor<T>>( | ||
dev_ctx, grad_out, y, &dz_div_y, axis); | ||
dx_tensor = &dz_div_y; | ||
} | ||
// dOut = - dX * ddY | ||
funcs::DefaultElementwiseOperator<Context, | ||
T, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FullLikeKernel<T, Context>(dev_ctx, y, Scalar(0.0), y.dtype(), dy);
,否则大量代码需要修改。其他地方问题类似
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done