From 49aa2c042cbae87ada74e7e63590f7b43239c596 Mon Sep 17 00:00:00 2001
From: dangqingqing
Date: Tue, 15 Aug 2017 17:40:26 +0800
Subject: [PATCH 1/3] Implement GPU kernel for cross entropy operator.

---
 paddle/framework/pybind.cc                    |   2 +-
 paddle/operators/cross_entropy_op.cc          |  15 +--
 paddle/operators/cross_entropy_op.cu          | 108 +++++++++++++++++-
 paddle/operators/cross_entropy_op.h           |  11 +-
 .../framework/tests/test_cross_entropy_op.py  |   2 +-
 5 files changed, 120 insertions(+), 18 deletions(-)

diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc
index fe0c87bc57082..2b3e7fba411e9 100644
--- a/paddle/framework/pybind.cc
+++ b/paddle/framework/pybind.cc
@@ -31,7 +31,7 @@ limitations under the License. */
 namespace py = pybind11;

 USE_OP(add_two);
-USE_CPU_ONLY_OP(onehot_cross_entropy);
+USE_OP(onehot_cross_entropy);
 USE_OP(sgd);
 USE_OP(mul);
 USE_OP(mean);
diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index a623c551e1088..ab1e1c101a10e 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -39,11 +39,10 @@ class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {

  protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto X = ctx.Input<Tensor>("X");

-    // TODO(superjom) add enforce here after helper functions ready
-    X_grad->Resize(X->dims());
+    dX->Resize(X->dims());
  }
 };

@@ -70,9 +69,7 @@ namespace ops = paddle::operators;
 REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
             ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
             ops::OnehotCrossEntropyGradientOp);
-REGISTER_OP_CPU_KERNEL(
-    onehot_cross_entropy,
-    ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    onehot_cross_entropy_grad,
-    ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
+                       ops::OnehotCrossEntropyOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad,
+                       ops::OnehotCrossEntropyGradientOpKernel<float>);
diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index 4bbc8f093a794..2392c3d5ed98a 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -12,10 +12,108 @@ See the License for the specific language governing permissions and
 limitations under the License. */

-#define EIGEN_USE_GPU
-#include "paddle/operators/cross_entropy_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/assert.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
+                                   const int N, const int D) {
+  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
+  // CUDA_1D_KERNEL_LOOP(i, N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
+    Y[i] = -log(X[i * D + label[i]]);
+  }
+}
+
+template <typename T>
+__global__ void zero(T* X, const int N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    X[i] = 0.0;
+  }
+}
+
+template <typename T>
+__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
+                                           const int* label, const int N,
+                                           const int D) {
+  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
+  // CUDA_1D_KERNEL_LOOP(i, N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    int idx = i * D + label[i];
+    dX[idx] = -dY[i] / X[idx];
+  }
+}
+
+template <typename T>
+class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "It must use GPUPlace.");
+
+    auto X = ctx.Input<Tensor>("X");
+    const T* Xdata = X->data<T>();
+    const int* label_data = ctx.Input<Tensor>("label")->data<int>();
+    auto Y = ctx.Output<Tensor>("Y");
+    Y->mutable_data<T>(ctx.GetPlace());
+    T* Ydata = Y->data<T>();
+
+    int N = X->dims()[0];
+    int D = X->dims()[1];
+    int block = 512;
+    int grid = (N + block - 1) / block;
+    // TODO(qingqing) launch kernel on specified stream
+    // based on ExecutionContext.
+    CrossEntropyKernel<T><<<grid, block>>>(Ydata, Xdata, label_data, N, D);
+  }
+};
+
+template <typename T>
+class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "It must use GPUPlace.");
+
+    auto X = ctx.Input<Tensor>("X");
+    auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto label = ctx.Input<Tensor>("label");
+
+    auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
+    auto* dYdata = dY->template data<T>();
+    auto* Xdata = X->template data<T>();
+    auto* label_data = label->data<int>();
+
+    int N = X->dims()[0];
+    int D = X->dims()[1];
+    int block = 512;
+    int grid = (N * D + block - 1) / block;
+    // TODO(qingqing): make zero a common function.
+    zero<T><<<grid, block>>>(dXdata, N * D);
+
+    grid = (N + block - 1) / block;
+    // TODO(qingqing): launch kernel on specified stream
+    // based on ExecutionContext.
+    CrossEntropyGradientKernel<T><<<grid, block>>>(dXdata, dYdata, Xdata,
+                                                   label_data, N, D);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(
-    onehot_cross_entropy,
-    ops::OnehotCrossEntropyOpKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
+                       ops::OnehotCrossEntropyOpCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad,
+                       ops::OnehotCrossEntropyGradientOpCUDAKernel<float>);
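Note: both launches above size the grid with (work + block - 1) / block, the usual ceiling division, and every kernel walks its range with a grid-stride loop, so correctness does not depend on the grid covering N (or N * D) exactly. A minimal standalone sketch of the same pattern, outside Paddle (the kernel name and test data here are illustrative, not part of the patch):

    // Standalone CUDA sketch of the grid-stride loop and ceil-div launch
    // used by the kernels in this patch. NegLogPick computes
    // y[i] = -log(x[i][label[i]]) for a row-major n x d matrix.
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void NegLogPick(float* y, const float* x, const int* label,
                               const int n, const int d) {
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += blockDim.x * gridDim.x) {
        y[i] = -logf(x[i * d + label[i]]);
      }
    }

    int main() {
      const int n = 4, d = 3;
      float h_x[n * d] = {0.2f, 0.3f, 0.5f, 0.1f, 0.8f, 0.1f,
                          0.6f, 0.2f, 0.2f, 0.3f, 0.3f, 0.4f};
      int h_label[n] = {2, 1, 0, 2};
      float *d_x, *d_y;
      int* d_label;
      cudaMalloc(&d_x, sizeof(h_x));
      cudaMalloc(&d_y, n * sizeof(float));
      cudaMalloc(&d_label, sizeof(h_label));
      cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
      cudaMemcpy(d_label, h_label, sizeof(h_label), cudaMemcpyHostToDevice);
      int block = 512;
      int grid = (n + block - 1) / block;  // ceiling division, >= 1 for n > 0
      NegLogPick<<<grid, block>>>(d_y, d_x, d_label, n, d);
      float h_y[n];
      cudaMemcpy(h_y, d_y, sizeof(h_y), cudaMemcpyDeviceToHost);
      for (int i = 0; i < n; ++i) printf("y[%d] = %f\n", i, h_y[i]);
      cudaFree(d_x); cudaFree(d_y); cudaFree(d_label);
      return 0;
    }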
diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index b7df92c9a98eb..261cbe2d423ab 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -39,10 +39,13 @@ T tolerable_value(T x) {
   return x;
 }

-template <typename Place, typename T>
+template <typename T>
 class OnehotCrossEntropyOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
     auto X = ctx.Input<Tensor>("X");
     const T* Xdata = X->data<T>();
     const int* label_data = ctx.Input<Tensor>("label")->data<int>();
@@ -62,10 +65,13 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel {
   }
 };

-template <typename Place, typename T>
+template <typename T>
 class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
     auto X = ctx.Input<Tensor>("X");
     auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
@@ -79,6 +85,7 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {

     const int batch_size = X->dims()[0];
     const int class_num = X->dims()[1];
+    memset(dXdata, 0, sizeof(T) * batch_size * class_num);
     for (int i = 0; i < batch_size; ++i) {
       int index = i * class_num + label_data[i];
       dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]);
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index 4815192e255c6..5557e0d35820d 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -22,7 +22,7 @@ def setUp(self):


 class CrossEntropyGradOpTest(GradientChecker):
-    def test_softmax_grad(self):
+    def test_check_grad(self):
         op = create_op("onehot_cross_entropy")
         batch_size = 100
         class_num = 10
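Taken together, patch 1 implements Y[i] = -log(X[i][label[i]]) in the forward pass and dX[i][label[i]] = -dY[i] / X[i][label[i]], with all other entries zero, in the backward pass, on both devices. A plain C++ reference of that math, handy for checking either kernel by hand (illustrative only, not part of the patch):

    // Plain C++ reference for the math in this patch. N = batch size,
    // D = class count; X and dX are row-major N x D buffers.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    void Forward(const std::vector<float>& X, const std::vector<int>& label,
                 std::vector<float>& Y, int N, int D) {
      for (int i = 0; i < N; ++i) Y[i] = -std::log(X[i * D + label[i]]);
    }

    void Backward(const std::vector<float>& X, const std::vector<float>& dY,
                  const std::vector<int>& label, std::vector<float>& dX,
                  int N, int D) {
      std::fill(dX.begin(), dX.end(), 0.0f);  // same role as the zero kernel
      for (int i = 0; i < N; ++i) {
        int idx = i * D + label[i];
        dX[idx] = -dY[i] / X[idx];  // matches CrossEntropyGradientKernel
      }
    }

    int main() {
      const int N = 2, D = 3;
      std::vector<float> X = {0.2f, 0.3f, 0.5f, 0.1f, 0.8f, 0.1f};
      std::vector<int> label = {2, 1};
      std::vector<float> Y(N), dY(N, 1.0f), dX(N * D);
      Forward(X, label, Y, N, D);
      Backward(X, dY, label, dX, N, D);
      printf("Y = {%f, %f}\n", Y[0], Y[1]);  // -log(0.5), -log(0.8)
      printf("dX[0][2] = %f, dX[1][1] = %f\n", dX[2], dX[4]);  // -2.0, -1.25
      return 0;
    }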
From 26475cd9ba4539a74cd2d36e8697fba4fbc52ddb Mon Sep 17 00:00:00 2001
From: dangqingqing
Date: Tue, 15 Aug 2017 19:25:16 +0800
Subject: [PATCH 2/3] Use clipping log in CUDA kernel, making it the same as
 the CPU kernel.

---
 paddle/operators/cross_entropy_op.cu          | 19 +++++++++++++++++--
 paddle/operators/cross_entropy_op.h           |  3 ++-
 .../paddle/v2/framework/tests/op_test_util.py |  3 ++-
 .../framework/tests/test_cross_entropy_op.py  |  5 ++---
 4 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index 2392c3d5ed98a..5f5d2692670b8 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -20,6 +20,21 @@ namespace operators {

 using Tensor = framework::Tensor;

+template <typename T>
+struct clipping_log {
+  __host__ __device__ T operator()(const T x) {
+    PADDLE_ASSERT(std::is_floating_point<T>::value);
+    const T kApproInf = 1e20;
+    if (x == INFINITY) {
+      return kApproInf;
+    }
+    if (x == -INFINITY) {
+      return -kApproInf;
+    }
+    return x;
+  }
+};
+
 template <typename T>
 __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                    const int N, const int D) {
@@ -28,10 +43,11 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -log(X[i * D + label[i]]);
+    Y[i] = -clipping_log<T>()(X[i * D + label[i]]);
   }
 }

+// TODO(qingqing): make zero setting a common function.
 template <typename T>
 __global__ void zero(T* X, const int N) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
@@ -98,7 +114,6 @@ class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
     int D = X->dims()[1];
     int block = 512;
     int grid = (N * D + block - 1) / block;
-    // TODO(qingqing): make zero a common function.
     zero<T><<<grid, block>>>(dXdata, N * D);

     grid = (N + block - 1) / block;
diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index 261cbe2d423ab..e95f5e11678e2 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -21,7 +21,7 @@ namespace operators {
 using Tensor = framework::Tensor;

 template <typename T>
-T tolerable_value(T x) {
+T tolerable_value(const T x) {
   static_assert(std::is_floating_point<T>::value,
                 "tolerable_value works only on float, "
                 "double and double double.");
@@ -85,6 +85,7 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {

     const int batch_size = X->dims()[0];
     const int class_num = X->dims()[1];
+    // TODO(qingqing): make zero setting a common function.
     memset(dXdata, 0, sizeof(T) * batch_size * class_num);
     for (int i = 0; i < batch_size; ++i) {
       int index = i * class_num + label_data[i];
       dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]);
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index dd65e0f2dc23d..ae23108dfa446 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -64,7 +64,8 @@ def test_all(self):
             actual = numpy.array(scope.find_var(out_name).get_tensor())
             expect = self.outputs[out_name]
             self.assertTrue(
-                numpy.allclose(actual, expect),
+                numpy.allclose(
+                    actual, expect, atol=1e-04),
                 "output name: " + out_name + "has diff")

         obj.test_all = test_all
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index 5557e0d35820d..d4277f2a42ce2 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -8,9 +8,8 @@ class TestCrossEntropy(unittest.TestCase):
     __metaclass__ = OpTestMeta

     def setUp(self):
-        # TODO this unit test is not passed
         self.type = "onehot_cross_entropy"
-        batch_size = 100
+        batch_size = 30
         class_num = 10
         X = numpy.random.random((batch_size, class_num)).astype("float32")
         label = 5 * numpy.ones(batch_size).astype("int32")
@@ -24,7 +23,7 @@ def setUp(self):

 class CrossEntropyGradOpTest(GradientChecker):
     def test_check_grad(self):
         op = create_op("onehot_cross_entropy")
-        batch_size = 100
+        batch_size = 30
         class_num = 10
         inputs = {
             "X": numpy.random.uniform(
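Note that as committed in patch 2, clipping_log clips the probability x itself and the forward kernel no longer takes a log at all, so the kernel evaluates -X[i][label[i]] rather than -log(X[i][label[i]]); patch 3 below fixes this by computing the log inside the function and clipping its result, which is what actually matches the CPU path's tolerable_value. A host-only sketch of the intended clipping behaviour (the name clipped_log is hypothetical, mirroring what clipping_log becomes in patch 3):

    // Host-only sketch of the clipping idea: log(0) is -inf, which the
    // clip maps to a large finite value so the loss stays finite.
    #include <cmath>
    #include <cstdio>

    template <typename T>
    T clipped_log(const T x) {
      const T kApproInf = 1e20;
      T v = std::log(x);
      if (v == INFINITY) return kApproInf;
      if (v == -INFINITY) return -kApproInf;
      return v;
    }

    int main() {
      printf("clipped_log(0.5f) = %g\n", clipped_log(0.5f));  // ~ -0.693147
      printf("clipped_log(0.0f) = %g\n", clipped_log(0.0f));  // -1e+20, not -inf
      return 0;
    }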
From 8f6c8780a52b3e0a6df85f6d9e3e98366a381692 Mon Sep 17 00:00:00 2001
From: dangqingqing
Date: Sat, 19 Aug 2017 17:08:04 +0800
Subject: [PATCH 3/3] Replace functor with function.

---
 paddle/operators/cross_entropy_op.cu          | 25 +++++++++----------
 paddle/operators/cross_entropy_op.h           |  2 +-
 .../paddle/v2/framework/tests/op_test_util.py |  2 +-
 3 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index 5f5d2692670b8..d999bfce58c8a 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -21,19 +21,18 @@ namespace operators {
 using Tensor = framework::Tensor;

 template <typename T>
-struct clipping_log {
-  __host__ __device__ T operator()(const T x) {
-    PADDLE_ASSERT(std::is_floating_point<T>::value);
-    const T kApproInf = 1e20;
-    if (x == INFINITY) {
-      return kApproInf;
-    }
-    if (x == -INFINITY) {
-      return -kApproInf;
-    }
-    return x;
+__host__ __device__ T clipping_log(const T x) {
+  PADDLE_ASSERT(std::is_floating_point<T>::value);
+  const T kApproInf = 1e20;
+  T v = log(x);
+  if (v == INFINITY) {
+    return kApproInf;
   }
-};
+  if (v == -INFINITY) {
+    return -kApproInf;
+  }
+  return v;
+}

 template <typename T>
 __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
@@ -43,7 +42,7 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -clipping_log<T>()(X[i * D + label[i]]);
+    Y[i] = -clipping_log(X[i * D + label[i]]);
   }
 }

diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index e95f5e11678e2..eb4d1348de1d9 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -21,7 +21,7 @@ namespace operators {
 using Tensor = framework::Tensor;

 template <typename T>
-T tolerable_value(const T x) {
+inline T tolerable_value(const T x) {
   static_assert(std::is_floating_point<T>::value,
                 "tolerable_value works only on float, "
                 "double and double double.");
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index ae23108dfa446..3bc05a0feccbb 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -65,7 +65,7 @@ def test_all(self):
             expect = self.outputs[out_name]
             self.assertTrue(
                 numpy.allclose(
-                    actual, expect, atol=1e-04),
+                    actual, expect, atol=1e-05),
                 "output name: " + out_name + "has diff")

         obj.test_all = test_all
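With patch 3 applied, the CPU and GPU paths clip the same quantity (the log itself), which is what allows the test tolerance to tighten from 1e-04 back to 1e-05. A sketch of the comparison op_test_util.py performs, reduced to plain C++ with a pure absolute tolerance (numpy.allclose also adds a relative term; values here are illustrative):

    // Element-wise |a - b| <= atol check, mirroring the test's
    // numpy.allclose(actual, expect, atol=1e-05) in simplified form.
    #include <cmath>
    #include <cstdio>

    bool AllClose(const float* a, const float* b, int n, float atol) {
      for (int i = 0; i < n; ++i) {
        if (std::fabs(a[i] - b[i]) > atol) return false;
      }
      return true;
    }

    int main() {
      float cpu[3] = {0.693147f, 0.223144f, 1.609438f};  // -log of 0.5, 0.8, 0.2
      float gpu[3] = {0.693148f, 0.223143f, 1.609437f};  // differs only in round-off
      printf("allclose(atol=1e-05): %s\n",
             AllClose(cpu, gpu, 3, 1e-5f) ? "true" : "false");
      return 0;
    }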