Fix the FP16 precision problem of add_n. #50129
Conversation
Your PR was submitted successfully. Thank you for contributing to this open source project!
The branch was force-pushed from c309274 to 94ea8a5.
      }
    }
-   out[id] = total;
+   out[id] = static_cast<T>(total);
    id += blockDim.x * gridDim.x;
  }
}
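For context, a minimal sketch of the pattern this hunk is part of (the kernel name, signature, and MPType default below are assumptions for illustration; in the actual PR the change lands in Paddle's multi-input sum kernel, where the accumulator type is selected via a trait such as MPTypeTrait): the per-thread sum is kept in a wider type such as float even when T is __half, and is narrowed back to T only on the final store, so the intermediate additions do not round at fp16 precision.

#include <cstdint>
#include <cuda_fp16.h>

// Sketch: sum n_arrays input arrays element-wise into out, accumulating in MPType.
template <typename T, typename MPType = float>
__global__ void SumArraysSketch(const T* const* ins,
                                T* out,
                                int64_t n,
                                int n_arrays) {
  int64_t id = blockIdx.x * blockDim.x + threadIdx.x;
  while (id < n) {
    MPType total = static_cast<MPType>(0);           // wide accumulator
    for (int i = 0; i < n_arrays; ++i) {
      if (ins[i] != nullptr) {
        total += static_cast<MPType>(ins[i][id]);    // widen each operand before adding
      }
    }
    out[id] = static_cast<T>(total);                 // single rounding back to T, as in the diff
    id += blockDim.x * gridDim.x;
  }
}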
Changing only this place is probably not enough. Some branches in AddNKernel dispatch to other implementations; for example, when there are exactly 2 tensors, the Eigen-based implementation is called, and the addition there also needs to be promoted to fp32.
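To illustrate the point for the two-tensor branch (in Paddle that branch goes through an Eigen-based implementation; the hand-written kernel below, with assumed names, is only a sketch of the same promotion idea): both operands are widened to float before the add, and the result is narrowed back to T once.

#include <cstdint>
#include <cuda_fp16.h>   // needed when T = __half

// Sketch: element-wise out = x + y with the addition performed in float.
template <typename T>
__global__ void AddTwoSketch(const T* x, const T* y, T* out, int64_t n) {
  for (int64_t id = blockIdx.x * blockDim.x + threadIdx.x; id < n;
       id += blockDim.x * gridDim.x) {
    float sum = static_cast<float>(x[id]) + static_cast<float>(y[id]);
    out[id] = static_cast<T>(sum);
  }
}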
Resolved.
LGTM
* Fix scale kernel for low precision, cherry pick #50998.
* Fix the FP16 precision problem of add_n. (#50129)
* Change squared_l2_norm to reuse ReduceKernel, and register fp16 and bf16 kernel, which is cherry pick #48315.
* Cherry-pick the fix of MPTypeTrait in KP, which is implemented in #50993.
* Cherry-pick the multi-precision support of AdamW for bf16, #48041.
* Fix compiling error.
* Cherry-pick the fix of CubTensorReduceImpl for bfloat16 in #50993.
* Fix unittest.

Co-authored-by: liuruyan <[email protected]>
PR types
Performance optimization
PR changes
Others
Describe
Fix the FP16 precision problem of add_n.
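As a standalone illustration of the precision problem being fixed (this demo is not part of the PR and assumes a GPU with native fp16 arithmetic, sm_53 or newer): summing many fp16 values with an fp16 accumulator stalls once the running total's spacing exceeds the addends, while accumulating in fp32 and casting back once at the end stays exact for this input.

#include <cstdio>
#include <cuda_fp16.h>

__global__ void SumHalfNaive(const __half* in, int n, float* out) {
  __half total = __float2half(0.f);  // fp16 accumulator: rounds after every add
  for (int i = 0; i < n; ++i) total = __hadd(total, in[i]);
  *out = __half2float(total);
}

__global__ void SumHalfPromoted(const __half* in, int n, float* out) {
  float total = 0.f;                 // fp32 accumulator: only the final cast rounds to fp16
  for (int i = 0; i < n; ++i) total += __half2float(in[i]);
  *out = __half2float(__float2half(total));
}

int main() {
  const int n = 4096;  // exact sum of n ones is 4096
  __half* in = nullptr;
  float* out = nullptr;
  cudaMallocManaged(&in, n * sizeof(__half));
  cudaMallocManaged(&out, 2 * sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = __float2half(1.0f);

  SumHalfNaive<<<1, 1>>>(in, n, &out[0]);
  SumHalfPromoted<<<1, 1>>>(in, n, &out[1]);
  cudaDeviceSynchronize();

  // Typical output: "fp16 accumulator: 2048, fp32 accumulator: 4096".
  // The fp16 sum stops growing at 2048 because 2048 + 1 rounds back to 2048
  // in half precision, which is the kind of error add_n hit before this fix.
  printf("fp16 accumulator: %g, fp32 accumulator: %g\n", out[0], out[1]);

  cudaFree(in);
  cudaFree(out);
  return 0;
}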