[Sparse]Optimize BatchNorm1D forward in test mode #47736

Merged: 13 commits, Nov 11, 2022
paddle/phi/kernels/gpu/batch_norm_kernel.cu (85 changes: 72 additions & 13 deletions)
@@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x,
}
}

+template <typename T>
+static __global__ void InverseVariance(const BatchNormParamType<T> *variance,
+                                        const double epsilon,
+                                        const int C,
+                                        BatchNormParamType<T> *inv_variance) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < C) {
+    inv_variance[tid] = 1 / sqrt(variance[tid] + epsilon);
+  }
+}
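
// Why this helps: batch-norm inference computes
//   y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
// and the kernel above hoists the per-channel factor
//   inv_variance[c] = 1 / sqrt(variance[c] + epsilon)
// out of the per-element loop (one thread per channel), so the 1-D
// inference kernel below needs a single multiply per element instead of a
// sqrt and a divide.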

+template <typename T, phi::DataLayout layout>
+static __global__ void BN1DForwardInference(
+    const T *x,
+    const BatchNormParamType<T> *mean,
+    const BatchNormParamType<T> *inv_variance,
+    const BatchNormParamType<T> *scale,
+    const BatchNormParamType<T> *bias,
+    const int C,
+    const int N,
+    const int HxW,
+    const double epsilon,
+    T *y) {
+  int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int num = N * C * HxW;
+  for (int i = gid; i < num; i += stride) {
+    const int c = layout == phi::DataLayout::kNCHW ? i / HxW % C : i % C;
+    BatchNormParamType<T> x_sub_mean =
+        static_cast<BatchNormParamType<T>>(x[i]) - mean[c];
+    y[i] = static_cast<T>(scale[c] * x_sub_mean * inv_variance[c] + bias[c]);
+  }
+}
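
// A CPU reference for the NHWC kernel above (a minimal sketch for clarity;
// in this file BatchNormParamType<T> is float when T is float or half):
//
//   for (int64_t i = 0; i < int64_t(N) * C * HxW; ++i) {
//     int c = i % C;  // NHWC: channel is the innermost dimension
//     float x_hat = (float(x[i]) - mean[c]) * inv_variance[c];
//     y[i] = T(scale[c] * x_hat + bias[c]);
//   }
//
// Note that epsilon is unused here: it was already folded into
// inv_variance by the InverseVariance kernel above.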

template <typename T, int BlockDim, phi::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
    const T *x,
@@ -795,7 +829,7 @@ void BatchNormKernel(const Context &ctx,
// epsilon));
#else
    const bool use_native_kernel =
-       ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
+       (x_dims.size() == 2 ||
         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
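// With this change, 2-D (NC) input always takes the native inference
// kernels below; previously it fell back to cuDNN whenever
// N < CUDNN_PER_ACTIVATION_THRESHOLD.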
if (use_native_kernel) {
const int block_size = 256;
@@ -814,18 +848,43 @@
epsilon,
transformed_y.template data<T>());
} else {
-        BNForwardInference<T, DataLayout::kNHWC>
-            <<<grid_size, block_size, 0, ctx.stream()>>>(
-                transformed_x.template data<T>(),
-                est_mean->template data<BatchNormParamType<T>>(),
-                est_var->template data<BatchNormParamType<T>>(),
-                scale.template data<BatchNormParamType<T>>(),
-                bias.template data<BatchNormParamType<T>>(),
-                C,
-                N,
-                H * W * D,
-                epsilon,
-                transformed_y.template data<T>());
+        if (x_dims.size() == 2) {
+          DenseTensor inv_var = phi::Empty<BatchNormParamType<T>>(ctx, {C});
+          auto *inv_var_ptr = inv_var.data<BatchNormParamType<T>>();
+          const int threads = 512 > C ? C : 512;
+          const int blocks = (C + 511) / 512;
+          InverseVariance<T><<<blocks, threads, 0, ctx.stream()>>>(
+              est_var->template data<BatchNormParamType<T>>(),
+              epsilon,
+              C,
+              inv_var_ptr);
+          BN1DForwardInference<T, DataLayout::kNHWC>
+              <<<grid_size, block_size, 0, ctx.stream()>>>(
+                  transformed_x.template data<T>(),
+                  est_mean->template data<BatchNormParamType<T>>(),
+                  inv_var_ptr,
+                  scale.template data<BatchNormParamType<T>>(),
+                  bias.template data<BatchNormParamType<T>>(),
+                  C,
+                  N,
+                  H * W * D,
+                  epsilon,
+                  transformed_y.template data<T>());
+        } else {
+          BNForwardInference<T, DataLayout::kNHWC>
+              <<<grid_size, block_size, 0, ctx.stream()>>>(
+                  transformed_x.template data<T>(),
+                  est_mean->template data<BatchNormParamType<T>>(),
+                  est_var->template data<BatchNormParamType<T>>(),
+                  scale.template data<BatchNormParamType<T>>(),
+                  bias.template data<BatchNormParamType<T>>(),
+                  C,
+                  N,
+                  H * W * D,
+                  epsilon,
+                  transformed_y.template data<T>());
+        }
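// Launch shapes for the 1-D path above: InverseVariance uses one thread
// per channel, e.g. C = 2048 gives threads = 512 and
// blocks = (2048 + 511) / 512 = 4, while C = 100 gives a single block of
// 100 threads. BN1DForwardInference reuses the grid-stride
// grid_size/block_size configuration already computed for
// BNForwardInference.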
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(