-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Change the input param of fusion op interface from pointer to tensor #36349
Conversation
Thanks for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
ASSERT_THROW(test.CheckBackward(1e-3), paddle::platform::EnforceNotMet); | ||
} else { | ||
ASSERT_NO_THROW(test.CheckForward(1e-3, true)); | ||
ASSERT_NO_THROW(test.CheckBackward(1e-3)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里其实没有必要加ASSERT_NO_THROW
吧
T *z_bias_ptr = nullptr) { | ||
void Forward(const platform::CUDADeviceContext &ctx, const Tensor &x, | ||
const Tensor &x_scale, const Tensor &x_bias, const Tensor &z, | ||
const Tensor &z_scale, const Tensor &z_bias, Tensor *out, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
z
可能并不存在,这种情况如何调用,也需要定义个空的Tensor?
PR types
Others
PR changes
OPs
Describe
将 cudnn_norm_conv、cudnn_bn_stats_finalize 与 cudnn_scale_bias_add_relu 三个融合op的接口参数,由指针改为tensor,相应的修改了对应的单测。