Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize the forward of log_softmax for the case when axis is not the last dimension. #32396

Merged
merged 20 commits into from
Jul 6, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
8f532b0
Merge pull request #1 from PaddlePaddle/develop
AshburnLee Sep 8, 2020
5b5804d
Merge pull request #2 from PaddlePaddle/develop
AshburnLee Sep 17, 2020
cee2470
Merge pull request #3 from PaddlePaddle/develop
AshburnLee Sep 30, 2020
5be3a45
Merge pull request #4 from PaddlePaddle/develop
AshburnLee Oct 13, 2020
a1d92b7
Merge pull request #5 from PaddlePaddle/develop
AshburnLee Oct 20, 2020
e674a5d
Merge pull request #6 from PaddlePaddle/develop
AshburnLee Nov 15, 2020
855d00b
Merge pull request #7 from PaddlePaddle/develop
AshburnLee Nov 18, 2020
7cb2c97
Merge pull request #8 from PaddlePaddle/develop
AshburnLee Mar 31, 2021
db9fc91
Merge pull request #9 from PaddlePaddle/develop
AshburnLee Apr 7, 2021
d521199
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Apr 20, 2021
c260064
Log_softmax backward case: axis!=-1
AshburnLee Apr 20, 2021
91dfddf
Made some minor changes
AshburnLee Apr 20, 2021
9061786
Reply review comments
AshburnLee Apr 29, 2021
47c8c6d
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Apr 29, 2021
42fd6f9
Fix a bug
AshburnLee Apr 29, 2021
2976137
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Jun 3, 2021
2755c26
Reply review comments
AshburnLee Jun 4, 2021
41914f6
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Jun 4, 2021
dee66ac
rearrange comments in the source file
AshburnLee Jun 5, 2021
89c6128
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Jun 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 178 additions & 12 deletions paddle/fluid/operators/log_softmax_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <limits>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/cuda_device_function.h"

namespace paddle {
Expand Down Expand Up @@ -142,6 +143,170 @@ void LaunchSoftmaxForwardForLastAxis(T *dst, const T *src, int dim_size,
}
}

// Max-reduction along threadIdx.x only (NOT a full block-wide reduction):
// each threadIdx.y row owns a private slice of the shared-memory buffer, so
// rows reduce independently while sharing the block-wide barriers.
template <typename T>
__forceinline__ __device__ T BlockDimxReduceMax(T *shared, T val) {
  // Offset into this row's private region of blockDim.x elements.
  T *row = shared + threadIdx.y * blockDim.x;
  // Barrier before writing: the buffer may still be read by a previous
  // reduction call.
  __syncthreads();
  row[threadIdx.x] = val;

  // Tree reduction over the x dimension (blockDim.x is a power of two,
  // guaranteed by GetBlockSize).
  math::MaxFunctor<T> max_op;
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.x < stride) {
      row[threadIdx.x] = max_op(row[threadIdx.x], row[threadIdx.x + stride]);
    }
  }
  __syncthreads();
  return row[0];
}

// Sum-reduction along threadIdx.x only; same per-row shared-memory layout
// and barrier discipline as BlockDimxReduceMax above.
template <typename T>
__forceinline__ __device__ T BlockDimxReduceAdd(T *shared, T val) {
  // This row's private region of blockDim.x elements.
  T *row = shared + threadIdx.y * blockDim.x;
  // Barrier before writing: the buffer may still be read by a previous
  // reduction call.
  __syncthreads();
  row[threadIdx.x] = val;

  math::AddFunctor<T> add_op;
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.x < stride) {
      row[threadIdx.x] = add_op(row[threadIdx.x], row[threadIdx.x + stride]);
    }
  }
  __syncthreads();
  return row[0];
}

// Forward kernel for log_softmax when the softmax axis is NOT the last
// dimension. The input is viewed as [outer_size, dim_size, inner_size]:
// blockIdx.x strides over the outer dimension, (blockIdx.y, threadIdx.y)
// stride over inner_size, and threads along x cooperate over dim_size.
// Dynamic shared memory holds blockDim.x * blockDim.y AccT values and is
// only touched when blockDim.x > 1.
template <typename T, typename AccT>
__global__ void LogSoftmaxForwardCUDAKernelNotLastAxis(
    T *output, const T *input, int outer_size, int dim_size, int inner_size) {
  extern __shared__ unsigned char smem[];
  auto *shared_buf = reinterpret_cast<AccT *>(smem);

  const int outer_stride = inner_size * dim_size;
  const int dim_stride = inner_size;

  for (int outer_id = blockIdx.x; outer_id < outer_size;
       outer_id += gridDim.x) {
    for (int inner_id = blockIdx.y * blockDim.y + threadIdx.y;
         inner_id < inner_size; inner_id += blockDim.y * gridDim.y) {
      const int base = outer_id * outer_stride + inner_id;
      // When blockDim.x == 1 the loops below degenerate to serial loops in a
      // single thread (threadIdx.x is always 0) and no cross-thread
      // reduction is needed. The three passes compute Max, Sum and finally
      // input[i] - Max - log(Sum) along the axis.
      // 1. reduce max
      AccT axis_max = -std::numeric_limits<AccT>::infinity();
      for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
        axis_max = math::MaxFunctor<AccT>()(
            axis_max, static_cast<AccT>(input[base + d * dim_stride]));
      }
      if (blockDim.x > 1) {
        axis_max = BlockDimxReduceMax<AccT>(shared_buf, axis_max);
      }

      // 2. reduce sum of exp(x - max)
      AccT exp_sum = 0;
      for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
        exp_sum += std::exp(
            static_cast<AccT>(input[base + d * dim_stride]) - axis_max);
      }
      if (blockDim.x > 1) {
        exp_sum = BlockDimxReduceAdd<AccT>(shared_buf, exp_sum);
      }

      // 3. write input - max - log(sum) to output
      for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
        output[base + d * dim_stride] = static_cast<T>(
            static_cast<AccT>(input[base + d * dim_stride]) - axis_max -
            std::log(exp_sum));
      }
    }
  }
}

// Chooses the block shape: block.y covers inner_size (capped at 1024) and
// block.x gets the largest power of two such that the total thread count
// stays within the 1024-threads-per-block limit and block.x does not
// overshoot dim_size.
inline dim3 GetBlockSize(int dim_size, int inner_size) {
  const int inner_threads = std::min(inner_size, 1024);

  // Double until the limit is exceeded, then step back once.
  int dim_threads = 1;
  while (dim_threads * inner_threads <= 1024 && dim_threads <= dim_size) {
    dim_threads <<= 1;
  }
  dim_threads >>= 1;
  return dim3(dim_threads, inner_threads);
}

// Chooses the grid shape: grid.y first tiles inner_size, then grid.x tiles
// the outer dimension with whatever budget of active blocks remains, never
// exceeding max_active_blocks in total.
inline dim3 GetGridSize(dim3 block, int max_active_blocks, int outer_size,
                        int dim_size, int inner_size) {
  int inner_blocks = (inner_size + block.y - 1) / block.y;
  inner_blocks = std::min(inner_blocks, max_active_blocks);

  int outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
  outer_blocks = std::min(outer_blocks, outer_size);
  return dim3(outer_blocks, inner_blocks);
}

// Computes the launch configuration for kernel `k`. Block size is chosen
// first; the grid is then sized from the maximum number of resident blocks
// reported by the CUDA/HIP occupancy API, scaled by the SM count.
template <typename T, typename Kernel>
void ComputeLaunchConfigure(Kernel k, int outer_size, int dim_size,
                            int inner_size, dim3 &grid, dim3 &block,
                            int &shared_mem, int num_sm) {
  block = GetBlockSize(dim_size, inner_size);
  const int threads_per_block = block.x * block.y;
  // Shared memory is only needed when threads cooperate along block.x.
  shared_mem = (block.x == 1) ? 0 : threads_per_block * sizeof(T);

  int max_active_blocks;
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_CUDA_SUCCESS(hipOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks, k, threads_per_block, shared_mem));
#else
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks, k, threads_per_block, shared_mem));
#endif
  max_active_blocks *= num_sm;
  grid = GetGridSize(block, max_active_blocks, outer_size, dim_size,
                     inner_size);
}

// Host-side launcher for the not-last-axis forward kernel: derives the
// launch configuration from the occupancy API, then launches the kernel on
// `stream`. MPDType is the accumulation type (e.g. float for fp16 inputs).
template <typename T, typename MPDType>
void LaunchLogSoftmaxForwardCUDAKernelNotLastAxis(T *output_data,
                                                  const T *input_data,
                                                  int outer_size, int dim_size,
                                                  int inner_size, int num_sm,
                                                  gpuStream_t stream) {
  dim3 grid;
  dim3 block;
  int shared_mem;

  ComputeLaunchConfigure<MPDType>(
      &LogSoftmaxForwardCUDAKernelNotLastAxis<T, MPDType>, outer_size,
      dim_size, inner_size, grid, block, shared_mem, num_sm);

  LogSoftmaxForwardCUDAKernelNotLastAxis<
      T, MPDType><<<grid, block, shared_mem, stream>>>(
      output_data, input_data, outer_size, dim_size, inner_size);
}

template <typename T>
class LogSoftmaxKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
Expand All @@ -164,14 +329,15 @@ class LogSoftmaxKernel<platform::CUDADeviceContext, T>
}
int outer_size = SizeToAxis(axis, x->dims());
gpuStream_t stream = context.cuda_device_context().stream();
int num_sm = context.cuda_device_context().GetSMCount();

if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
LaunchSoftmaxForwardForLastAxis<T, MPDType>(output_data, input_data,
dim_size, outer_size, stream);
} else {
LogSoftmaxFunctor<platform::CUDADeviceContext, T>()(
context.template device_context<platform::CUDADeviceContext>(), x,
out, axis);
LaunchLogSoftmaxForwardCUDAKernelNotLastAxis<T, MPDType>(
output_data, input_data, outer_size, dim_size, inner_size, num_sm,
stream);
}
}
};
Expand All @@ -195,7 +361,7 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
int batch_id = blockDim.y * blockIdx.x + threadIdx.y;

int thread_in_warp_idx = threadIdx.x % kernel_warp_size;
int thread_in_warp_idx = threadIdx.x;
AshburnLee marked this conversation as resolved.
Show resolved Hide resolved

// 1.read data from global memory to registers
AccT output_register[warp_iter];
Expand All @@ -209,8 +375,8 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
grad_output_register[iter] = static_cast<AccT>(
grad_output[batch_id * element_count + element_index]);
} else {
output_register[iter] = AccT(0);
grad_output_register[iter] = AccT(0);
output_register[iter] = static_cast<AccT>(0);
grad_output_register[iter] = static_cast<AccT>(0);
}
}

Expand Down Expand Up @@ -271,13 +437,13 @@ class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
public:
void Compute(const framework::ExecutionContext &context) const override {
const auto *out = context.Input<framework::Tensor>("Out");
const auto *g_out =
const auto *d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *g_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto *d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));

const auto *out_data = out->data<T>();
const auto *g_out_data = g_out->data<T>();
auto *g_x_data = g_x->mutable_data<T>(context.GetPlace());
const auto *d_out_data = d_out->data<T>();
auto *d_x_data = d_x->mutable_data<T>(context.GetPlace());

const int rank = out->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
Expand All @@ -292,11 +458,11 @@ class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>

if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
LaunchSoftmaxBackwardForLastAxis<T, MPDType>(
g_x_data, g_out_data, out_data, dim_size, outer_size, stream);
d_x_data, d_out_data, out_data, dim_size, outer_size, stream);
} else {
LogSoftmaxGradFunctor<platform::CUDADeviceContext, T>()(
context.template device_context<platform::CUDADeviceContext>(), out,
g_out, g_x, axis);
d_out, d_x, axis);
}
}
};
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/operators/math/functors.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ struct AddFunctor {
inline HOSTDEVICE T operator()(T x, T y) { return x + y; }
};

// Binary functor returning the larger of its two operands.
template <typename T>
struct MaxFunctor {
  inline HOSTDEVICE T operator()(T lhs, T rhs) const {
    return lhs < rhs ? rhs : lhs;
  }
};

template <typename T>
struct AddGradFunctor {
inline HOSTDEVICE T Dx(T x, T y) { return static_cast<T>(1.); }
Expand Down