Implement GPU kernel for cross entropy operator. #3501
paddle/operators/cross_entropy_op.cu
Outdated
using Tensor = framework::Tensor;

template <typename T>
struct clipping_log {
Why not just use a function?
template <typename T>
__host__ __device__ T clipping_log(const T x) {
  ...
  return xx;
}
I noticed that Majel and WarpCTC use functors to define many common functions. Functors have their own advantages; see: https://stackoverflow.com/questions/356950/c-functors-and-their-uses
But a functor is not necessary here, so I changed it to a plain function.
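For reference, a minimal sketch of what the plain-function form could look like. The `HOSTDEVICE` macro, the `kApproxInf` bound, and the clamping logic are assumptions made for illustration; the thread does not show the function body actually used in this PR. The block is written so that it compiles as ordinary C++ and is also intended to compile under nvcc.

```cpp
#include <cmath>

// Hypothetical helper macro (an assumption, not taken from the PR): expands to
// the CUDA qualifiers under nvcc and to nothing in a plain C++ build.
#ifdef __CUDACC__
#define HOSTDEVICE __host__ __device__
#else
#define HOSTDEVICE
#endif

// Plain-function variant of a clipped logarithm.
// Assumed behavior: clamp log(x) to [-kApproxInf, kApproxInf] so that
// log(0) = -inf does not propagate an infinity into the loss.
template <typename T>
HOSTDEVICE inline T clipping_log(const T x) {
  const T kApproxInf = static_cast<T>(1e20);  // assumed bound
  const T v = std::log(x);
  if (v > kApproxInf) return kApproxInf;
  if (v < -kApproxInf) return -kApproxInf;
  return v;  // NaN passes through unchanged
}
```

With a free function, the call site in the kernel would read `Y[i] = -clipping_log(X[i * D + label[i]]);`, with `T` deduced from the argument, instead of constructing a temporary functor with `clipping_log<T>()(...)`.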
paddle/operators/cross_entropy_op.cu
Outdated
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
     i += blockDim.x * gridDim.x) {
  PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
  Y[i] = -clipping_log<T>()(X[i * D + label[i]]);
I'm a bit confused: what is the formula for cross entropy? Isn't it cost = sum(y ln(x) + (1-y) ln(1-x))?
The label is effectively one-hot: [0, 0, ..., 1, ..., 0, 0], so here we have y[i] = -log(x[i * D + label[i]]). The summation is usually not done in cross_entropy_op; it is done by a following op, or when the loss is printed.
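To spell out the step from the general formula to the indexed lookup: assuming the standard multi-class cross entropy over D classes (the symbols below are chosen for illustration and do not appear in the PR), a one-hot target zeroes every term of the inner sum except the one at the label index.

```latex
% x_{ij}: predicted probability of class j for sample i (X[i * D + j] in the kernel)
% t_{ij}: one-hot target, t_{ij} = 1 if j = label[i], else 0
\begin{aligned}
Y_i &= -\sum_{j=0}^{D-1} t_{ij} \log x_{ij} \\
    &= -\log x_{i,\,\mathrm{label}[i]}
\end{aligned}
```

The sum over samples, $\sum_{i=0}^{N-1} Y_i$, is the part left to a later op. The two-term expression quoted in the question is the binary (two-class) form, which is a different case from the D-class lookup used here.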
Hello, do we need something like <<<grid, block, 0, stream_id >>> to execute it?
@zchen0211 We need to specify the stream, like <<<grid, block, 0, stream_id >>>. I think the stream_id can be got from
LGTM
Fix #3492