Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve dropout #29465

Merged
merged 4 commits into from
Dec 11, 2020
Merged

improve dropout #29465

merged 4 commits into from
Dec 11, 2020

Conversation

zhangting2020
Copy link
Contributor

@zhangting2020 zhangting2020 commented Dec 8, 2020

PR types

Performance optimization

PR changes

OPs

Describe

improve dropout
  • 将原来OP里的3个kernel统一为了一个,因为原来的3个kernel没有太大差异:
    • 原来的RandomGenerator和RandomGeneratorWithSeed逻辑上无差异,只是函数的seed参数类型不同,前者是const int seed,后者是const int* seed。RandomGeneratorWithSeed用到的是Input(seed)。这两个函数完全可以统一为一个。
    • RandomGeneratorWithGenerator的逻辑和前两个也相同,只是该函数使用了CUDAGenerator产生的seed和offset去设置curand_init。函数多了一个increment参数,用来设置kernel中curand_init的offset参数。
    • 以上3个函数接口和逻辑上是可以进行统一的。
  • 代码逻辑修改:
    • curand_init用来生成state,但是原始的逻辑中将curand_init放在了for循环里,并没有起到实质性的意义,因为只是在多次迭代时,修改了curand_init的offset参数而已,在PR 的comments里有更详细的描述解释这个接口的原理。放在for循环内或者外,只是在产生随机数时从序列的哪个位置取数上有差异,但是这不影响op的功能。

Op-Benchmark Forward Results

image

Bert Forward Results

image

@paddle-bot-old
Copy link

paddle-bot-old bot commented Dec 8, 2020

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

wangchaochaohu
wangchaochaohu previously approved these changes Dec 9, 2020
int step_size = 0;
// aligned vector generates vectorized load/store on CUDA
template <typename T, int Size>
struct alignas(sizeof(T) * Size) aligned_vector {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

类型名要用AxxBxx这种命名方式。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

mask_data[idx] = mask;
dst[idx] = dest;
template <typename T>
inline int VectorizedSize(char* pointer) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

输入参数不用传T*

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里修改为了T*

const float dropout_prob, const T* src,
MaskType* mask_data, T* dst,
bool is_upscale_in_train) {
__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确认一下函数名RandomGeneratorWithGenerator

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR将原来的3个kernel:RandomGenerator、RandomGeneratorWithGenerator、RandomGeneratorWithSeed,统一为了一个,现在函数名称改为了:RandomGenerator。

curandStatePhilox4_32_10_t state;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;
curand_init(seed, idx, increment, &state);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个修改是等价的吗?原来curand_init是放在for循环里面,每次迭代都会调用一次。increment参数是指什么?

Copy link
Contributor Author

@zhangting2020 zhangting2020 Dec 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我认为这样改是没有问题的,并不需要每次迭代里面去初始化,并且动态修改offset,以下详细解释:

curand_init (
    unsigned long long seed, unsigned long long subsequence,
    unsigned long long offset, curandState_t *state)

curand_init根据给定的seed、subsequence和offset设置初始的state。在生成随机数时,根据subsequence和offset来决定开始取随机数的位置。

|seq0  ...   2^67-1|seq1   ...   2^68-1|seq2   ... 
                           ^
                   |offset |                     (determined by offset parameter)
                           |
                           RNG begins here for given seed, sequence(=seq1), offset

根据原来的逻辑:它每次迭代重新init,只是重新设置了取随机数的位置,也就是(sequence的索引和offset)。

for (; idx < n; idx += blockDim.x * gridDim.x) {
    T s = src[idx];
    if (step_size == 0) {
      curand_init(seed, idx, idx, &state);
      step_size = blockDim.x * gridDim.x;
    } else {
      curand_init(seed, idx, step_size, &state);
    }
...

首先用一段程序解释curand_init 原来写在for循环里,和现在写在for循环外的区别:

  • curand_init写在for循环里:
_global__ void testrand1(unsigned long seed, float *a, int N){
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    curandState state;
    int step_size = 0;
    for (int i = idx; i < N; i += blockDim.x * gridDim.x) {
        if (step_size == 0) {
          curand_init(seed, idx, i, &state);
          step_size = blockDim.x * gridDim.x;
        } else {
          curand_init(seed, idx, step_size, &state);
        }
        a[i] = curand_uniform(&state);
    }
}
0 0.145468
1 0.926417
2 0.782640
3 0.535606
4 0.650189
5 0.629326
6 0.713179
7 0.448197
8 0.300772
9 0.136307
  • 下面这段是目前PR里的写法:
__global__ void testrand1(unsigned long seed, float *a, int N){
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    curandState state;
    curand_init(seed, idx, idx, &state);
    int step_size = 0;
    for (int i = idx; i < N; i += blockDim.x * gridDim.x) {
        a[i] = curand_uniform(&state);
    }
}

int main() {

    const int N = 10;

    float *h_a  = (float*)malloc(N*sizeof(float));
    float *d_a;
    cudaMalloc((void**)&d_a, N*sizeof(float));
    int thread = 2;
    // int grid = (N + thread - 1) / thread;
    int grid = 4;
    testrand1<<<grid, thread>>>(1234, d_a, N);
    cudaPeekAtLastError();
    cudaDeviceSynchronize();

    cudaMemcpy(h_a, d_a, N*sizeof(float), cudaMemcpyDeviceToHost);

    for (int i=0; i<N; i++) printf("%i %f\n", i, h_a[i]);

    getchar();
}
0 0.145468
1 0.926417
2 0.782640
3 0.535606
4 0.650189
5 0.629326
6 0.713179
7 0.448197
8 0.434899
9 0.511765

idx取值为0~7。
第一种写法是:

  • idx=0时
    • 第一次迭代:产生子序列sequence0,然后设置随机数a[0]=sequence0[0],同时offset=8
    • 第二次迭代:产生子序列sequence8,然后设置随机数a[8]=sequence8[8];
  • idx=1时
    • 第一次迭代:产生子序列sequence1,然后设置随机数a[1]=sequence1[1],同时offset=8
    • 第二次迭代:产生子序列sequence8,然后设置随机数a[9]=sequence9[8];

第二种写法是:

  • idx=0时,产生子序列sequence0,然后设置随机数a[0]=sequence0[0]和a[8]=sequence0[1];
  • idx=1时,产生子序列sequence1,然后设置随机数a[1]=sequence1[1]和a[9]=sequence1[2];

curand_uniform每调用一次,就往下取一个数。因此前8个数,2种写法都是相同的。从上面的分析可以看到,a[8]和a[9]两种写法取得数不同。

只是取的随机数不同而已,但是我理解并不影响这个op的功能。而且我认为原始的写法没有必要。

最后再说PR里的写法:

  • offset是在调用cuda kernel前算好的,计算一个increment用来设置offset,increment不能小于每个线程产生的随机数个数。每一次调用drop out后,会设置offset = offset + increament。 那么下一次调用drop out时,每个线程就会跳过前一次drop out时取过的那些随机数。例如:

  • 第一次调用drop out产生10个数,当grid=4, thread=2时,会将curand_init的当前的offset设置为0,同时会将线程的offset置为 offset= 0 + 2,因为本次每个线程最多会产生2个数。

  • 第二次调用drop out时,curand_init的offset就会使用2,然后再设置offset = offset + increament(increament根据这次drop out每个线程产生的随机数确定)。从上一次取过的随机数之后去取数,以免和上一次调用drop out产生的随机数出现重叠。

    • idx=0时,产生子序列sequence0,然后设置随机数a[0]=sequence0[2]和a[8]=sequence0[2+1];
    • idx=1时,产生子序列sequence1,然后设置随机数a[1]=sequence1[2]和a[9]=sequence1[2+1];

T* dst, bool is_upscale_in_train,
uint64_t increment) {
template <typename T, typename MaskType, int VecSize>
__global__ void VectorizedRandomGeneratorWithGenerator(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确认一下函数名VectorizedRandomGeneratorWithGenerator

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改为了VectorizedRandomGenerator

T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
T src_vec[VecSize];
LoadT* value = reinterpret_cast<LoadT*>(&src_vec);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有没有试过直接用float4 __ldg(const float4 *ptr)来加载?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我试了,这里不能这么使用,问了下wangchao,看上去是不支持,

*(reinterpret_cast<LoadT*>(&dst[i])) =
*reinterpret_cast<LoadT*>(&dest_vec[0]);
*(reinterpret_cast<MaskLoadT*>(&mask_data[i])) =
*reinterpret_cast<MaskLoadT*>(&mask_vec[0]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个写回好长。。。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个我好像也不确定还有什么更好的改法

*(reinterpret_cast<MaskLoadT*>(&mask_data[i])) =
*reinterpret_cast<MaskLoadT*>(&mask_vec[0]);

__syncthreads();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有看到线程之间有数据依赖和交互啊,这里为什么要加同步?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之前忘记删了,已删除。

((x_numel - 1) / (threads * grid * vec_size) + 1) * vec_size;
int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
.GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个随机数种子生成器,是说CPUGenerator生成的种子Tensor是在CPU上,CUDAGenerator生成的种子Tensor在GPU上吗?我们的种子看来是希望在CPU上访问的,那是不是可以直接用CPUGenerator?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该不是区分种子是否在CPU还是GPU上。 原来的RandomGeneratorWithGenerator kernel需要用到increment,来设置curand_init中的offset参数。op里会用到下面的IncrementOffset去获得当前的offset。因为dropout在训练中会多次调用,没调用一次,就应该改变offset,跳过这一次产生的这些随机数。这样下一次调用drop out产生的随机数和之前就不会发生重叠。

std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
    uint64_t increament_offset) {
  uint64_t cur_offset = this->state_.thread_offset;
#ifdef PADDLE_WITH_CUDA
  std::lock_guard<std::mutex> lock(this->mu_);

  this->state_.thread_offset += increament_offset;

#else
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Increment Offset only support in CUDA place"));
#endif
  return std::make_pair(static_cast<int>(this->state_.current_seed),
                        cur_offset);
}

framework::Tensor seed_cpu_tensor;
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
increment = offset;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉这个increment跟原来的设置不完全一样,这个参数的影响是什么?代码里面能不能加个注释说明下。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

其他的2个kernel并不会用到这个参数,为了将3个kernel进行统一,所以其他的情况也设置了。increment作用就是用来设置curand_init的offset,上面已经解释过了原理。

Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhangting2020 zhangting2020 merged commit 6702040 into PaddlePaddle:develop Dec 11, 2020
@zhangting2020 zhangting2020 deleted the dropout branch December 11, 2020 10:09
zhangting2020 added a commit to zhangting2020/Paddle that referenced this pull request Jan 11, 2021
* improve drop out

* add VectorizedRandomGeneratorWithGenerator

* fix bug

* modify according to comments
lanxianghit pushed a commit that referenced this pull request Jan 11, 2021
* improve dropout (#29465)

* improve drop out

* add VectorizedRandomGeneratorWithGenerator

* fix bug

* modify according to comments

* improve dropout grad (#29605)

* improve grad perf

* fix the bug of dropout_grad (#29813)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants