From cbb58bbdd5eea6e40aecc434de8e4d1dd0fe3894 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Tue, 13 Apr 2021 10:33:03 +0800 Subject: [PATCH] optimize check_finite_and_unscale_op by fused kernel, test=develop (#31954) --- .../amp/check_finite_and_unscale_op.cu | 105 ++++++++++++++---- 1 file changed, 84 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu index e28a3c1b6da81b..c0b9b592ec5534 100644 --- a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu @@ -28,18 +28,48 @@ __global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) { } template -__global__ void CheckFiniteAndUnscale(const T* in, const MT* scale, int num, - bool* found_inf, T* out) { - const int idx = threadIdx.x + blockIdx.x * blockDim.x; - - if (idx < num) { - MT val = static_cast(in[idx]) * (*scale); +__global__ void CheckFiniteAndUnscale(const T** xs, const MT* scale, + int64_t size, int64_t* starts, + bool* found_inf, T** outs) { + const int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; + + // copy starts array from global memory to shared memory + extern __shared__ int64_t s_starts[]; + for (int i = threadIdx.x; i <= size; i += blockDim.x) { + s_starts[i] = starts[i]; + } + __syncthreads(); + + const int64_t num = s_starts[size]; + int pre_xs_index = 0; + bool t_found_inf = false; + const MT t_scale = *scale; + for (int64_t idx = tid; idx < num; idx += gridDim.x * blockDim.x) { + // get the xs's index of thread + int xs_index = pre_xs_index; + while (idx < s_starts[xs_index]) xs_index++; + // avoid some tensor's numel is zero + while (idx >= s_starts[xs_index]) xs_index++; + pre_xs_index = xs_index - 1; + + // get in data and out data + const T* in = xs[pre_xs_index]; + T* out = outs[pre_xs_index]; + int64_t in_idx = idx - s_starts[pre_xs_index]; + + // Unscale + MT val = static_cast(in[in_idx]) * t_scale; T narrow_val = static_cast(val); - out[idx] = narrow_val; + out[in_idx] = narrow_val; + + // CheckFinite if (!isfinite(narrow_val)) { - *found_inf = true; + t_found_inf = true; } } + if (t_found_inf) { + *found_inf = true; + } } template @@ -65,20 +95,53 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel { InverseAndMemset<<<1, 1, 0, dev_ctx.stream()>>>( scale_data, inverse_scale_v, found_inf_data); - for (size_t i = 0; i < xs.size(); ++i) { - const auto* x = xs[i]; - auto* out = outs[i]; - const T* x_data = x->data(); - T* out_data = out->mutable_data(dev_ctx.GetPlace()); - - int num = x->numel(); - int block = 1024; - int grid = (num + block - 1) / block; - VLOG(3) << "launch kernel"; - CheckFiniteAndUnscale<<>>( - x_data, inverse_scale_v, num, found_inf_data, out_data); - VLOG(3) << "finish kernel"; + size_t xs_size = xs.size(); + // calculate each tensor's start index and copy to device + auto h_starts_tensor = + memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t)); + int64_t* h_starts = reinterpret_cast(h_starts_tensor->ptr()); + + auto d_starts_tensor = + memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t)); + int64_t* d_starts = reinterpret_cast(d_starts_tensor->ptr()); + + h_starts[0] = 0; + for (int i = 1; i <= xs_size; i++) { + // the start index value of each tensor is + // the sum of previous tensor's size + h_starts[i] = h_starts[i - 1] + xs[i - 1]->numel(); + } + int64_t total_num = h_starts[xs_size]; + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + d_starts, platform::CPUPlace(), h_starts, + (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()); + + // copy each tensor's data address to device + auto h_mem = memory::Alloc(platform::CPUPlace(), 2 * xs_size * sizeof(T*)); + const T** h_xs = reinterpret_cast(h_mem->ptr()); + T** h_outs = reinterpret_cast(h_mem->ptr()) + xs_size; + + auto d_mem = memory::Alloc(dev_ctx, 2 * xs_size * sizeof(T*)); + const T** d_xs = reinterpret_cast(d_mem->ptr()); + T** d_outs = reinterpret_cast(d_mem->ptr()) + xs_size; + + for (size_t i = 0; i < xs_size; ++i) { + h_xs[i] = xs[i]->data(); + h_outs[i] = outs[i]->mutable_data(dev_ctx.GetPlace()); } + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), d_xs, + platform::CPUPlace(), h_xs, 2 * xs_size * sizeof(T*), + dev_ctx.stream()); + + // Launch Kernel + int block = 1024; + int block_num = block * 20; // each thread deal with 20 number + int grid = (total_num + block_num - 1) / block_num; + VLOG(3) << "launch kernel"; + CheckFiniteAndUnscale<<< + grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>( + d_xs, inverse_scale_v, xs_size, d_starts, found_inf_data, d_outs); + VLOG(3) << "finish kernel"; } }; } // namespace operators