improve dropout #29465
Changes from all commits
@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <algorithm>
#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/dropout_op.h"

@@ -26,60 +27,35 @@ limitations under the License. */
namespace paddle {
namespace operators {

template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, const int seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask_data, T* dst,
                                bool is_upscale_in_train) {
  curandStatePhilox4_32_10_t state;
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  int step_size = 0;
// aligned vector generates vectorized load/store on CUDA
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
  T val[Size];
};

  MaskType mask;
  T dest;
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    T s = src[idx];
    if (step_size == 0) {
      curand_init(seed, idx, idx, &state);
      step_size = blockDim.x * gridDim.x;
    } else {
      curand_init(seed, idx, step_size, &state);
    }
    if (curand_uniform(&state) < dropout_prob) {
      mask = 0;
      dest = 0;
    } else {
      mask = 1;
      if (is_upscale_in_train) {
        dest = s / static_cast<T>(1.0f - dropout_prob);
      } else {
        dest = s;
      }
    }
    mask_data[idx] = mask;
    dst[idx] = dest;
template <typename T>
inline int VectorizedSize(const T* pointer) {
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
  if (address % vec4 == 0) {
    return 4;
  }
  return 1;
}

template <typename T, typename MaskType>
__global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
                                        const float dropout_prob, const T* src,
                                        MaskType* mask_data, T* dst,
                                        bool is_upscale_in_train) {
__global__ void RandomGenerator(const size_t n, uint64_t seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask_data, T* dst,
                                bool is_upscale_in_train, uint64_t increment) {
  curandStatePhilox4_32_10_t state;
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  int step_size = 0;
  curand_init(seed, idx, increment, &state);

  MaskType mask;
  T dest;
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    T s = src[idx];
    if (step_size == 0) {
      curand_init(seed[0], idx, idx, &state);
      step_size = blockDim.x * gridDim.x;
    } else {
      curand_init(seed[0], idx, step_size, &state);
    }
    if (curand_uniform(&state) < dropout_prob) {
      mask = 0;
      dest = 0;
@@ -96,39 +72,49 @@ __global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
  }
}

template <typename T, typename MaskType>
__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed,
                                             const float dropout_prob,
                                             const T* src, MaskType* mask_data,
                                             T* dst, bool is_upscale_in_train,
                                             uint64_t increment) {
template <typename T, typename MaskType, int VecSize>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                          const float dropout_prob,
                                          const T* src, MaskType* mask_data,
                                          T* dst, bool is_upscale_in_train,
                                          uint64_t increment) {
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  int step_size = 0;
  curand_init(seed, idx, increment, &state);

  MaskType mask;
  T dest;
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    T s = src[idx];
    if (step_size == 0) {
      curand_init(seed, idx, increment, &state);
      step_size = blockDim.x * gridDim.x;
    } else {
      curand_init(seed, idx, increment, &state);
    }
    if (curand_uniform(&state) < dropout_prob) {
      mask = 0;
      dest = 0;
    } else {
      mask = 1;
      if (is_upscale_in_train) {
        dest = s / static_cast<T>(1.0f - dropout_prob);
  using LoadT = AlignedVector<T, VecSize>;
  using MaskLoadT = AlignedVector<MaskType, VecSize>;
  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
  for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
    T src_vec[VecSize];
    LoadT* value = reinterpret_cast<LoadT*>(&src_vec);
Review comment: Have you tried using ... directly here?
Reply: I tried; it cannot be used that way here. I asked wangchao, and it appears this is not supported.

    *value = *reinterpret_cast<const LoadT*>(&src[i]);
    float4 rand = curand_uniform4(&state);

    T dest_vec[VecSize];
    MaskType mask_vec[VecSize];

#pragma unroll
    for (int ii = 0; ii < VecSize; ii++) {
      if ((&rand.x)[ii] < dropout_prob) {
        dest_vec[ii] = 0;
        mask_vec[ii] = 0;
      } else {
        dest = s;
        if (is_upscale_in_train) {
          dest_vec[ii] = src_vec[ii] * factor;
        } else {
          dest_vec[ii] = src_vec[ii];
        }
        mask_vec[ii] = 1;
      }
    }
    mask_data[idx] = mask;
    dst[idx] = dest;

    *(reinterpret_cast<LoadT*>(&dst[i])) =
        *reinterpret_cast<LoadT*>(&dest_vec[0]);
    *(reinterpret_cast<MaskLoadT*>(&mask_data[i])) =
        *reinterpret_cast<MaskLoadT*>(&mask_vec[0]);
Review comment: This write-back is rather long...
Reply: I am not sure there is a better way to write it either.
  }
}

@@ -170,36 +156,57 @@ class GPUDropoutKernel : public framework::OpKernel<T> {

      int threads = 512;
      int grid = (x_numel + threads - 1) / threads;
      const auto& dev_ctx = context.cuda_device_context();
      int blocks_per_sm =
          dev_ctx.GetMaxPhysicalThreadCount() / dev_ctx.GetSMCount() / threads;
      grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);

      // increment is used to set the args(offset) of curand_init, which defines
      // offset in subsequence.
      // The detail:
      // https://docs.nvidia.com/cuda/curand/device-api-overview.html
      // Increment should be at least the number of curand() random numbers used
      // in each thread to avoid the random number generated this time being the
      // same as the previous calls.
      uint64_t seed_data;
      uint64_t increment;
      int vec_size = VectorizedSize<T>(x_data);
      auto offset =
          ((x_numel - 1) / (threads * grid * vec_size) + 1) * vec_size;
      int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
                          .GetDeviceId();
      auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
Review comment: About this random seed generator: does it mean that the seed Tensor produced by CPUGenerator lives on the CPU while the one produced by CUDAGenerator lives on the GPU? Our seed apparently needs to be read on the CPU, so could we simply use CPUGenerator here?
Reply: The distinction here is not about whether the seed is on the CPU or the GPU. The original RandomGeneratorWithGenerator kernel needs increment to set the offset argument of curand_init, and the op uses IncrementOffset below to obtain the current offset. Because dropout is called many times during training, every call should advance the offset and skip the random numbers produced this time, so that the random numbers generated by the next dropout call do not overlap with the previous ones.
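A minimal sketch of the offset bookkeeping described in this reply (the kernel name draw, the helper next_offset, and the launch shape are made up for illustration; only curand_init's (seed, subsequence, offset) arguments and the IncrementOffset idea come from the PR):

#include <curand_kernel.h>

// Each thread owns subsequence `idx`; `offset` counts how many numbers of that
// subsequence were already consumed by earlier dropout calls.
__global__ void draw(uint64_t seed, uint64_t offset, float* out, size_t n) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, offset, &state);
  if (idx < n) out[idx] = curand_uniform(&state);
}

// Host side: advance the shared offset by at least the number of values each
// thread may draw, so the next call starts past them. This is the role that
// gen_cuda->IncrementOffset(offset) plays in the PR.
uint64_t next_offset(uint64_t current, uint64_t values_per_thread) {
  return current + values_per_thread;
}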

      if (seed && platform::is_gpu_place(seed->place())) {
        auto seed_gpu_data = seed->data<int>();
        RandomGeneratorWithSeed<T, uint8_t><<<grid, threads, 0, stream>>>(
            size, seed_gpu_data, dropout_prob, x_data, mask_data, y_data,
            upscale_in_train);
        return;
      }
      int seed_data;
      std::random_device rnd;
      if (seed) {
        seed_data = *(seed->data<int>());
        framework::Tensor seed_cpu_tensor;
        TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
        seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
        increment = offset;
Review comment: This increment does not seem to be exactly the same as the original setting. What effect does this parameter have? Could a comment be added in the code to explain it?
Reply: The other two kernels do not actually use this parameter; it is set in the other cases as well so that the three kernels can be unified. increment is used to set the offset of curand_init; the rationale was explained above.
      } else if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
        auto seed_offset = gen_cuda->IncrementOffset(offset);
        seed_data = seed_offset.first;
        increment = seed_offset.second;
      } else {
        seed_data =
            context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
        if (seed) {
          seed_data = *(seed->data<int>());
        } else {
          std::random_device rnd;
          seed_data = context.Attr<bool>("fix_seed") ? context.Attr<int>("seed")
                                                     : rnd();
        }
        increment = offset;
      }

      int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
                          .GetDeviceId();
      auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
      if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
        auto seed_offset = gen_cuda->IncrementOffset(1);
        RandomGeneratorWithGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
            size, seed_offset.first, dropout_prob, x_data, mask_data, y_data,
            upscale_in_train, seed_offset.second);
        return;
      if (vec_size == 4) {
        VectorizedRandomGenerator<T, uint8_t, 4><<<grid, threads, 0, stream>>>(
            size, seed_data, dropout_prob, x_data, mask_data, y_data,
            upscale_in_train, increment);
      } else {
        RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
            size, seed_data, dropout_prob, x_data, mask_data, y_data,
            upscale_in_train, increment);
      }

      RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train);
    } else {
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
Review comment: Is this change equivalent? Originally curand_init was placed inside the for loop and called on every iteration. What does the increment parameter refer to?
Reply: I believe this change is fine; there is no need to re-initialize inside every iteration and dynamically adjust the offset. A detailed explanation follows.
curand_init sets the initial state from the given seed, subsequence, and offset. When random numbers are generated, the subsequence and offset determine the position from which drawing starts.
Under the original logic, re-initializing on every iteration only resets the position from which random numbers are drawn, i.e. the subsequence index and the offset.
First, a short program illustrates the difference between writing curand_init inside the for loop (as before) and outside it (as now), with idx taking values 0 through 7.
The first way is:
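The code block under "the first way" was not preserved in this copy of the comment; a rough reconstruction of the pattern it contrasts (re-initializing inside the loop, simplified so that every thread starts at in-sequence position 0) could look like:

#include <curand_kernel.h>

__global__ void first_way(uint64_t seed, float* a, size_t n) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  int step_size = 0;
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    if (step_size == 0) {
      curand_init(seed, idx, 0, &state);          // first iteration
      step_size = blockDim.x * gridDim.x;
    } else {
      curand_init(seed, idx, step_size, &state);  // later iterations re-init
    }
    a[idx] = curand_uniform(&state);
  }
}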
The second way is:
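Likewise reconstructed (initializing once before the loop, as the new kernel does):

#include <curand_kernel.h>

__global__ void second_way(uint64_t seed, float* a, size_t n) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, 0, &state);  // initialize once per thread
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    a[idx] = curand_uniform(&state);  // consecutive draws from one state
  }
}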
Each call to curand_uniform draws the next number from the state, so the first 8 numbers are the same under both ways. From the analysis above, the values drawn for a[8] and a[9] differ between the two ways.
Only the random numbers drawn differ, and as I understand it this does not affect the functionality of this op; I also think the original way of writing it is unnecessary.
Finally, about the way it is written in this PR:
The offset is computed before launching the CUDA kernel; an increment is computed to set the offset, and the increment must not be smaller than the number of random numbers each thread produces. After every dropout call, offset = offset + increment is applied, so on the next dropout call each thread skips the random numbers it drew during the previous call. For example:
The first dropout call produces 10 numbers. With grid = 4 and threads = 2, the current offset passed to curand_init is 0, and the generator's offset is then set to offset = 0 + 2, because each thread produces at most 2 numbers in this call.
On the second dropout call, curand_init uses offset 2, and afterwards offset = offset + increment is applied again (the increment being determined by how many random numbers each thread produces in this dropout call). Drawing starts after the numbers taken last time, so there is no overlap with the random numbers produced by the previous dropout call.
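For reference, the same arithmetic written out with the PR's offset formula (a tiny host-side check; the toy numbers and vec_size = 1 follow the example above):

#include <cstdint>
#include <cstdio>

int main() {
  // Toy case from the example: 10 values, grid = 4, threads = 2.
  uint64_t n = 10, threads = 2, grid = 4, vec_size = 1;
  // Offset/increment formula used in GPUDropoutKernel in this PR.
  uint64_t increment = ((n - 1) / (threads * grid * vec_size) + 1) * vec_size;
  // First call:  curand_init offset 0, generator offset advances to 0 + 2.
  // Second call: curand_init offset 2, skipping the first call's numbers.
  std::printf("increment per call = %llu\n",
              static_cast<unsigned long long>(increment));  // prints 2
  return 0;
}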