Fix a critical bug in softmax_with_cross_entropy_op backward. (#9120)
* Fix a critical bug in softmax_with_cross_entropy_op that leads to wrong gradients.

* Enhance unit testing.
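
For reference, the backward of softmax_with_cross_entropy for hard (integer) labels computes softmax(logits) minus a one-hot encoding of the label, scaled row-wise by the incoming gradient of each sample's loss; that is the result both the buggy and the fixed CUDA kernels are meant to produce. Below is a minimal CPU sketch of that reference gradient (illustrative only: SoftmaxCrossEntropyGradRef and every name in it are made up for this example, this is not Paddle code):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Reference (CPU) gradient of softmax_with_cross_entropy w.r.t. the logits
// for hard integer labels. Illustrative only: not Paddle code.
std::vector<float> SoftmaxCrossEntropyGradRef(const std::vector<float>& logits,
                                              const std::vector<int64_t>& labels,
                                              const std::vector<float>& loss_grad,
                                              int batch_size, int class_num) {
  std::vector<float> grad(batch_size * class_num);
  for (int i = 0; i < batch_size; ++i) {
    const float* row = &logits[i * class_num];
    float max_v = *std::max_element(row, row + class_num);
    float sum = 0.f;
    for (int j = 0; j < class_num; ++j) sum += std::exp(row[j] - max_v);
    for (int j = 0; j < class_num; ++j) {
      // d(loss_i)/d(logit_ij) = softmax_ij - 1{j == label_i},
      // scaled by the upstream gradient of sample i's loss.
      float softmax = std::exp(row[j] - max_v) / sum;
      grad[i * class_num + j] =
          (softmax - (j == labels[i] ? 1.f : 0.f)) * loss_grad[i];
    }
  }
  return grad;
}

int main() {
  // Tiny example: 2 samples, 3 classes, upstream loss gradient of 1 per sample.
  std::vector<float> logits = {1.f, 2.f, 3.f, 0.5f, 0.5f, 0.5f};
  std::vector<int64_t> labels = {2, 0};
  std::vector<float> loss_grad = {1.f, 1.f};
  auto grad = SoftmaxCrossEntropyGradRef(logits, labels, loss_grad, 2, 3);
  for (float g : grad) std::printf("%f ", g);
  std::printf("\n");
  return 0;
}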
qingqing01 authored Mar 15, 2018
1 parent 1e4c504 commit b5a16dc
Showing 2 changed files with 26 additions and 26 deletions.
48 changes: 24 additions & 24 deletions paddle/fluid/operators/softmax_with_cross_entropy_op.cu
@@ -23,21 +23,21 @@ using Tensor = framework::Tensor;
 
 namespace {
 template <typename T>
-__global__ void CrossEntropyGrad(T* logit_grad, const T* loss_grad,
-                                 const int64_t* labels, const int batch_size,
-                                 const int class_num) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int sample_idx = tid / class_num;
-
-  if (tid < batch_size) {
-    PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
-    logit_grad[tid * class_num + labels[tid]] -= static_cast<T>(1.);
+__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
+                                 const int batch_size, const int class_num) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
+       i += blockDim.x * gridDim.x) {
+    int idx = i * class_num + labels[i];
+    logit_grad[idx] -= static_cast<T>(1.);
   }
+}
 
-  __syncthreads();
-
-  if (tid < batch_size * class_num) {
-    logit_grad[tid] *= loss_grad[sample_idx];
+template <typename T>
+__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
+                      const int class_num) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+       i += blockDim.x * gridDim.x) {
+    logit_grad[i] *= loss_grad[i / class_num];
   }
 }

@@ -94,22 +94,22 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     const int batch_size = logit_grad->dims()[0];
     const int class_num = logit_grad->dims()[1];
     int block = 512;
-    int grid = (batch_size * class_num + block - 1) / block;
+    auto stream = context.cuda_device_context().stream();
 
     if (context.Attr<bool>("soft_label")) {
+      int grid = (batch_size * class_num + block - 1) / block;
       const T* label_data = labels->data<T>();
-      SoftCrossEntropyGradientKernel<
-          T><<<grid, block, 0,
-               context.template device_context<platform::CUDADeviceContext>()
-                   .stream()>>>(logit_grad_data, loss_grad_data, label_data,
-                                batch_size, class_num);
+      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
+          logit_grad_data, loss_grad_data, label_data, batch_size, class_num);
     } else {
+      int grid = (batch_size + block - 1) / block;
      const int64_t* label_data = labels->data<int64_t>();
-      CrossEntropyGrad<
-          T><<<grid, block, 0,
-               context.template device_context<platform::CUDADeviceContext>()
-                   .stream()>>>(logit_grad_data, loss_grad_data, label_data,
-                                batch_size, class_num);
+      CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
+          logit_grad_data, label_data, batch_size, class_num);
+      int num = batch_size * class_num;
+      grid = (num + block - 1) / block;
+      Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
+                                           class_num);
     }
   }
 };
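
A note on why the single-kernel version was wrong, as far as the diff shows: the old CrossEntropyGrad subtracted 1 at each sample's label position and then scaled every element by loss_grad, separated only by __syncthreads(). That barrier synchronizes threads within one block, not across the grid, and with the grid sized for batch_size * class_num the subtraction for a sample and the scaling of that sample's label element generally land in different blocks, so the scaling can run before (or race with) the subtraction, yielding wrong gradients. With the old test sizes (batch_size = 2, class_num = 37) the whole problem fit inside a single 512-thread block, which is presumably why the bug went unnoticed; the enlarged batch_size = 41 in the updated tests below forces multiple blocks. The fix performs the two steps as separate kernel launches on the same stream, which CUDA orders for you. A minimal standalone sketch of that pattern (hypothetical Step1Subtract/Step2Scale kernels, not Paddle code):

#include <cstdio>
#include <cuda_runtime.h>

// Two hypothetical kernels standing in for CrossEntropyGrad and Scale.
// Launching them back-to-back on the same stream orders them: the second
// launch sees every global-memory write made by the first, across all blocks.
__global__ void Step1Subtract(float* data, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    data[i] -= 1.f;  // analogous to the label-position subtraction
  }
}

__global__ void Step2Scale(float* data, const float* scale, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    data[i] *= scale[0];  // analogous to scaling by loss_grad
  }
}

int main() {
  const int n = 1 << 10;
  float *data = nullptr, *scale = nullptr;
  cudaMalloc(&data, n * sizeof(float));
  cudaMalloc(&scale, sizeof(float));
  cudaMemset(data, 0, n * sizeof(float));
  const float h_scale = 2.f;
  cudaMemcpy(scale, &h_scale, sizeof(float), cudaMemcpyHostToDevice);

  int block = 512, grid = (n + block - 1) / block;
  Step1Subtract<<<grid, block>>>(data, n);      // all writes complete
  Step2Scale<<<grid, block>>>(data, scale, n);  // before this kernel starts

  float h0 = 0.f;
  cudaMemcpy(&h0, data, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("data[0] = %f\n", h0);  // (0 - 1) * 2 = -2
  cudaFree(data);
  cudaFree(scale);
  return 0;
}

Kernels issued to the same stream execute in issue order, so the second launch observes all global-memory writes of the first; that is the grid-wide ordering the original in-kernel __syncthreads() could not provide.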
4 changes: 2 additions & 2 deletions python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
@@ -26,7 +26,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
 
     def setUp(self):
         self.op_type = "softmax_with_cross_entropy"
-        batch_size = 2
+        batch_size = 41
         class_num = 37
 
         logits = np.random.uniform(0.1, 1.0,
@@ -59,7 +59,7 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
 
     def setUp(self):
         self.op_type = "softmax_with_cross_entropy"
-        batch_size = 2
+        batch_size = 41
         class_num = 37
 
         logits = np.random.uniform(0.1, 1.0,
