
Implement GPU kernel for cross entropy operator. #3501

Merged
qingqing01 merged 6 commits into PaddlePaddle:develop from cross_entropy on Aug 22, 2017

Conversation

qingqing01
Contributor

@qingqing01 qingqing01 commented Aug 15, 2017

Fix #3492

  • Implement GPU kernel for cross entropy operator.
  • Fix the origin CPU implementation.

using Tensor = framework::Tensor;

template <typename T>
struct clipping_log {
Member

Why not just use a plain function?

template <typename T>
__host__ __device__ T clipping_log(const T x) {
  ...
  return xx;
}

Contributor Author

@qingqing01 qingqing01 Aug 19, 2017

I notice that Majel and WarpCTC use functors to define many common functions. Functors have their own advantages, see: https://stackoverflow.com/questions/356950/c-functors-and-their-uses

But a functor is not necessary here, so I have changed it to a plain function.

for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
     i += blockDim.x * gridDim.x) {
  PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
  Y[i] = -clipping_log<T>()(X[i * D + label[i]]);
Member

I am a bit confused: what is the cross-entropy formula being computed here? Isn't it cost = sum(y * ln(x) + (1 - y) * ln(1 - x))?

Contributor Author

The label is actually one-hot: [0, 0, ..., 1, ..., 0, 0], so here y[i] = -log(x[i * D + label[i]]). The summation is usually not done inside cross_entropy_op; it is done by a following op, or when printing the loss.

@zchen0211
Contributor

Hello, do we need something like <<<grid, block, 0, stream_id>>> to launch it?

@qingqing01
Contributor Author

@zchen0211 Yes, we need to specify the stream, like <<<grid, block, 0, stream_id>>>. I think the stream_id can be obtained from CUDAContext, but there is no stream in CUDAContext yet. @QiJune said he started working on it last week.

Member

@QiJune QiJune left a comment

LGTM

@qingqing01 qingqing01 merged commit f931140 into PaddlePaddle:develop Aug 22, 2017
@qingqing01 qingqing01 deleted the cross_entropy branch March 7, 2018 12:03