Commit

optimize check_finite_and_unscale_op by fused kernel, test=develop (P…
thisjiang committed Apr 13, 2021
1 parent 5c7ad3b commit cbb58bb
Showing 1 changed file with 84 additions and 21 deletions.
paddle/fluid/operators/amp/check_finite_and_unscale_op.cu: 105 changes (84 additions, 21 deletions)
@@ -28,18 +28,48 @@ __global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
 }
 
 template <typename T, typename MT>
-__global__ void CheckFiniteAndUnscale(const T* in, const MT* scale, int num,
-                                      bool* found_inf, T* out) {
-  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
-
-  if (idx < num) {
-    MT val = static_cast<MT>(in[idx]) * (*scale);
+__global__ void CheckFiniteAndUnscale(const T** xs, const MT* scale,
+                                      int64_t size, int64_t* starts,
+                                      bool* found_inf, T** outs) {
+  const int64_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  // copy the starts array from global memory to shared memory
+  extern __shared__ int64_t s_starts[];
+  for (int i = threadIdx.x; i <= size; i += blockDim.x) {
+    s_starts[i] = starts[i];
+  }
+  __syncthreads();
+
+  const int64_t num = s_starts[size];
+  int pre_xs_index = 0;
+  bool t_found_inf = false;
+  const MT t_scale = *scale;
+  for (int64_t idx = tid; idx < num; idx += gridDim.x * blockDim.x) {
+    // find which tensor in xs this thread's element belongs to
+    int xs_index = pre_xs_index;
+    while (idx < s_starts[xs_index]) xs_index++;
+    // skip tensors whose numel is zero
+    while (idx >= s_starts[xs_index]) xs_index++;
+    pre_xs_index = xs_index - 1;
+
+    // get in data and out data
+    const T* in = xs[pre_xs_index];
+    T* out = outs[pre_xs_index];
+    int64_t in_idx = idx - s_starts[pre_xs_index];
+
+    // Unscale
+    MT val = static_cast<MT>(in[in_idx]) * t_scale;
     T narrow_val = static_cast<T>(val);
-    out[idx] = narrow_val;
+    out[in_idx] = narrow_val;
+
+    // CheckFinite
     if (!isfinite(narrow_val)) {
-      *found_inf = true;
+      t_found_inf = true;
     }
   }
+  if (t_found_inf) {
+    *found_inf = true;
+  }
 }
 
 template <typename T>
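Note on the fused kernel above: the host passes in a cumulative starts array (starts[i] is the sum of the numels of tensors 0..i-1, so starts[size] is the total element count), and each thread walks that array to translate its flat index into a tensor index plus an offset inside that tensor. The following is a minimal, CPU-only sketch of that mapping with made-up tensor sizes (numel 4, 0 and 3); it omits the pre_xs_index caching the kernel keeps across grid-stride iterations and is illustrative only, not part of the commit.

#include <cstdint>
#include <cstdio>

int main() {
  // hypothetical tensor sizes 4, 0 and 3; starts is their prefix sum
  const std::int64_t starts[] = {0, 4, 4, 7};  // starts[size] == total element count
  const int size = 3;                          // number of tensors

  for (std::int64_t idx = 0; idx < starts[size]; ++idx) {
    // advance past every boundary that is <= idx; an empty tensor contributes
    // two equal consecutive entries, so it is stepped over automatically
    int xs_index = 0;
    while (idx >= starts[xs_index]) xs_index++;
    const int tensor = xs_index - 1;
    const std::int64_t in_idx = idx - starts[tensor];
    std::printf("flat index %lld -> tensor %d, offset %lld\n",
                static_cast<long long>(idx), tensor,
                static_cast<long long>(in_idx));
  }
  return 0;
}

Because an empty tensor adds two equal consecutive boundaries to starts, the second while loop in the kernel steps over it, which is what the zero-numel comment refers to.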
@@ -65,20 +95,53 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
     InverseAndMemset<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
         scale_data, inverse_scale_v, found_inf_data);
 
-    for (size_t i = 0; i < xs.size(); ++i) {
-      const auto* x = xs[i];
-      auto* out = outs[i];
-      const T* x_data = x->data<T>();
-      T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
-
-      int num = x->numel();
-      int block = 1024;
-      int grid = (num + block - 1) / block;
-      VLOG(3) << "launch kernel";
-      CheckFiniteAndUnscale<T, MPDType><<<grid, block, 0, dev_ctx.stream()>>>(
-          x_data, inverse_scale_v, num, found_inf_data, out_data);
-      VLOG(3) << "finish kernel";
+    size_t xs_size = xs.size();
+    // calculate each tensor's start index and copy them to the device
+    auto h_starts_tensor =
+        memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));
+    int64_t* h_starts = reinterpret_cast<int64_t*>(h_starts_tensor->ptr());
+
+    auto d_starts_tensor =
+        memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
+    int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());
+
+    h_starts[0] = 0;
+    for (int i = 1; i <= xs_size; i++) {
+      // the start index of each tensor is
+      // the sum of the previous tensors' sizes
+      h_starts[i] = h_starts[i - 1] + xs[i - 1]->numel();
+    }
+    int64_t total_num = h_starts[xs_size];
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 d_starts, platform::CPUPlace(), h_starts,
+                 (xs_size + 1) * sizeof(int64_t), dev_ctx.stream());
+
+    // copy each tensor's data address to the device
+    auto h_mem = memory::Alloc(platform::CPUPlace(), 2 * xs_size * sizeof(T*));
+    const T** h_xs = reinterpret_cast<const T**>(h_mem->ptr());
+    T** h_outs = reinterpret_cast<T**>(h_mem->ptr()) + xs_size;
+
+    auto d_mem = memory::Alloc(dev_ctx, 2 * xs_size * sizeof(T*));
+    const T** d_xs = reinterpret_cast<const T**>(d_mem->ptr());
+    T** d_outs = reinterpret_cast<T**>(d_mem->ptr()) + xs_size;
+
+    for (size_t i = 0; i < xs_size; ++i) {
+      h_xs[i] = xs[i]->data<T>();
+      h_outs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
     }
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), d_xs,
+                 platform::CPUPlace(), h_xs, 2 * xs_size * sizeof(T*),
+                 dev_ctx.stream());
+
+    // Launch Kernel
+    int block = 1024;
+    int block_num = block * 20;  // each thread deals with 20 numbers
+    int grid = (total_num + block_num - 1) / block_num;
+    VLOG(3) << "launch kernel";
+    CheckFiniteAndUnscale<T, MPDType><<<
+        grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+        d_xs, inverse_scale_v, xs_size, d_starts, found_inf_data, d_outs);
+    VLOG(3) << "finish kernel";
   }
 };
 }  // namespace operators
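Note on the launch configuration at the end of the new Compute(): the kernel now runs once over all tensors, with a fixed block of 1024 threads and each thread covering roughly 20 elements through the grid-stride loop, so the grid size is the ceiling of total_num / (1024 * 20), and the dynamic shared-memory size is (xs_size + 1) * sizeof(int64_t) bytes for the starts array. A small sketch of that arithmetic with hypothetical total_num and xs_size values (not taken from the commit):

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const std::int64_t total_num = 1000000;  // hypothetical total element count
  const std::size_t xs_size = 32;          // hypothetical number of tensors

  const int block = 1024;
  const int block_num = block * 20;  // elements covered by one block
  const int grid = static_cast<int>((total_num + block_num - 1) / block_num);
  const std::size_t shared_bytes = (xs_size + 1) * sizeof(std::int64_t);

  // prints: grid=49 block=1024 shared=264 bytes
  std::printf("grid=%d block=%d shared=%zu bytes\n", grid, block, shared_bytes);
  return 0;
}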