
Fix cudnn Dropout reproducibility #17547

Merged: 12 commits into apache:master on Apr 10, 2020

Conversation

@roywei (Member) commented Feb 7, 2020:

Fix #15662
This will replace #16532 with an alternative solution.
Please refer to the discussion in #16532 (comment)

  1. Added GetSeed() to the GPU random resource, similar to the existing CPU version.
  2. During cudnn dropout forward, check whether the random seed has changed, and reset the cudnn dropout descriptor's seed if it has.

Benefit:

  1. This avoids generating a random int (a tensor on GPU) on the GPU, copying it to the CPU, and passing it to cudnnDropoutGetStatesSize.
  2. Getting the seed from the GPU random resource is fast, so we can afford to do it on every forward (see the sketch below).
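
To make the approach concrete, here is a minimal standalone C++ sketch of the check-and-reset pattern; only GetSeed() and the seed_/reset handshake with get_cudnn_dropout_desc() mirror what this PR adds, while the surrounding class and member names are hypothetical stand-ins:

// Minimal standalone sketch (not MXNet code) of the check-and-reset pattern
// described above. Only GetSeed() and the seed_/reset handshake mirror the PR;
// GpuRandomResource and DropoutOpState are hypothetical stand-ins.
#include <cstdint>
#include <iostream>

struct GpuRandomResource {
  uint64_t seed = 17;  // hypothetical default seed value
  uint64_t GetSeed() const { return seed; }
};

struct DropoutOpState {
  uint64_t seed_ = 0;      // seed the cudnn dropout descriptor was built with
  int reinit_count = 0;    // counts (expensive) descriptor re-initializations

  void Forward(const GpuRandomResource& rnd) {
    const uint64_t rng_seed = rnd.GetSeed();
    const bool reset = (seed_ != rng_seed);  // reset only if the seed changed
    seed_ = rng_seed;
    if (reset) {
      // In the PR this corresponds to get_cudnn_dropout_desc(..., seed_, reset)
      // regenerating the cudnn dropout state buffer.
      ++reinit_count;
    }
  }
};

int main() {
  GpuRandomResource rnd;
  DropoutOpState op;
  op.Forward(rnd);   // first forward: descriptor initialized once
  op.Forward(rnd);   // same seed: no re-initialization
  rnd.seed = 42;     // e.g. the user calls mx.random.seed(42)
  op.Forward(rnd);   // seed changed: descriptor re-initialized
  std::cout << "re-inits: " << op.reinit_count << "\n";  // prints 2
}

The point of passing reset explicitly is that re-initializing with an unchanged seed is cheap, while an actual seed change triggers the expensive cudnn dropout state regeneration.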

Drawback:
This is affected by #7410: by default, MXNet's random seed is fixed.

@roywei (Member, Author) commented Feb 17, 2020:

@DickJC123 @ptrendx could you help take a look again? Thanks

Resolved review threads (now outdated): 3rdparty/mshadow/mshadow/random.h (two threads), src/operator/nn/dropout-inl.h, src/operator/rnn.cc
@karan6181 (Contributor) commented:

@DickJC123 @ptrendx Can you please review this PR once again? Thanks!

@apeforest (Contributor) commented:

@roywei Could you please rebase and update this PR? Thanks!

@ChaiBapchya (Contributor) commented:

@roywei with the CI issue fixed, let's rebase to include the fix. Thanks!

@access2rohit (Contributor) commented:

Restarted the CI jobs. @roywei, can you get your PR merged once they pass CI?

@ChaiBapchya (Contributor) commented:

@access2rohit without rebasing, the fix from #17962 won't be included (for windows-gpu).

@roywei (Member, Author) commented Apr 7, 2020:

@ChaiBapchya @access2rohit Rebased! Thanks!

@@ -1495,7 +1500,7 @@ class RNNOp {
cudnnRNNInputMode_t input_mode_;
cudnnDropoutDescriptor_t dropout_desc_;
Storage::Handle reserve_space_;
uint64_t seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn)
A Contributor commented on this line:

This seed_ is used for rand_r in the forward implementations. That is very bad practice and not portable, and it should be fixed.
As this PR already passes CI, it's fine to fix it in a follow-up PR. In fact, #17984 is blocked by the usage of rand_r, so I may fix it there.

More info: https://channel9.msdn.com/Events/GoingNative/2013/rand-Considered-Harmful
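
For reference, a portable alternative would be to draw the seed from <random> once per object instead of calling rand()/rand_r. The sketch below is only illustrative (RNNOpSeedSketch is a hypothetical stand-in, and this is not necessarily the fix #17984 will adopt):

// Illustrative sketch only (not the actual fix in #17984): seed each object
// once from <random> instead of using rand()/rand_r. RNNOpSeedSketch is a
// hypothetical stand-in for the class holding seed_.
#include <cstdint>
#include <iostream>
#include <random>

struct RNNOpSeedSketch {
  uint64_t seed_;
  RNNOpSeedSketch() {
    std::random_device rd;                              // non-deterministic source
    std::mt19937_64 gen(rd());                          // local engine, no shared global state
    std::uniform_int_distribution<uint64_t> dist(17, 17 + 4095);
    seed_ = dist(gen);                                  // same range as 17 + rand() % 4096
  }
};

int main() {
  RNNOpSeedSketch op;
  std::cout << "seed_ = " << op.seed_ << "\n";
}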

@leezu (Contributor) left a review comment:

Thanks!

@leezu leezu merged commit 249b9a1 into apache:master Apr 10, 2020
Comment on lines +258 to +264
Random<xpu, unsigned> *prnd = ctx.requested[1].get_random<xpu, unsigned>(s);
uint64_t rng_seed = prnd->GetSeed();
// reset dropout descriptor if rng seed changed.
bool reset = seed_ != rng_seed;
seed_ = rng_seed;
ctx.requested[0].get_cudnn_dropout_desc(&dropout_desc_, s, 1.0f - this->pkeep_,
seed_, reset);
A Member commented:

@roywei there is more than one random number generator in the resources (4 by default), and they can have different seeds. The result of rotating random number generator seeds here is that the cudnn dropout state is reinitialized very often. @sxjscience observed a significant performance impact because of this.
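
A toy illustration of the effect (the 4 parallel generators come from the comment above; the seed values and round-robin assignment below are hypothetical): when consecutive forwards see generators with different seeds, the seed_ != rng_seed check fires nearly every call.

// Toy illustration (not MXNet code) of the performance problem described above:
// if successive forwards are handed a different one of the 4 parallel RNGs and
// those RNGs carry different (hypothetical) seeds, the cached-seed check sees a
// "changed" seed almost every time, so the cudnn dropout state is regenerated
// almost every forward.
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  const std::array<uint64_t, 4> rng_seeds = {101, 202, 303, 404};  // hypothetical per-RNG seeds
  uint64_t seed_ = 0;       // seed cached by the dropout operator
  int reinit_count = 0;

  for (int forward = 0; forward < 8; ++forward) {
    const uint64_t rng_seed = rng_seeds[forward % 4];  // round-robin over the 4 RNG resources
    if (seed_ != rng_seed) {
      ++reinit_count;       // expensive cudnn dropout state regeneration
    }
    seed_ = rng_seed;
  }
  std::cout << "re-inits in 8 forwards: " << reinit_count << "\n";  // prints 8
}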

A Contributor commented:

I believe @roywei is already working on this issue, since this regression was caught by our benchmarks last week.

@roywei (Member, Author) commented Apr 23, 2020:

I spent some time looking into it, and the actual problem is that Dropout's forward enters the cudnn dropout descriptor re-init logic on every forward even without my change. This is true for nd/np and Gluon, but not for symbolic Dropout. So it was already re-initializing on every forward; the reason this caused no performance regression before is that it always used the seed_ defined by uint64_t seed_ = 17 + rand() % 4096;, which never changes and does not listen to MXNet's PRNG (the very problem this PR tries to fix). If the seed does not change, re-initializing the cudnn dropout descriptor takes no time. My PR made the seed change, so the descriptor was actually rebuilt on every forward, hence the regression.

However, it works fine in the symbolic case. My guess is that for symbolic execution every forward uses the same Dropout node, so the state check takes effect and the code does not enter the re-init logic, whereas in the imperative case the state is always empty and the code falls into the re-init path.

Compare the following two snippets, one using Gluon and one using Symbol. If I add a log line to the re-initialization path in the original code (without my change), it prints on every forward for ND/NP/Gluon, but not for Symbol.

import mxnet as mx
data = mx.nd.ones((10, 200, 300, 500), ctx=mx.gpu(0))
dropout = mx.gluon.nn.Dropout(0.5)
# with or without hybridize is the same result
dropout.hybridize()
with mx.autograd.record():
    result1 = dropout(data)
    result2 = dropout(result1)

This prints twice:

re-init dropout desc
re-init dropout desc

Symbol:

import mxnet as mx
data = mx.nd.ones((10, 200, 300, 500), ctx=mx.gpu(0))
net = mx.sym.Variable("data")
net = mx.sym.Dropout(data=net, p=0.5, cudnn_off=False)
exe = net.simple_bind(mx.gpu(0), data=data.shape)
result1 = exe.forward(is_train=True, data=data)
result2 = exe.forward(is_train=True, data=result1[0])

This prints once:

re-init dropout desc

Given this situation, I don't have a good solution: checking the state handle size does not work in imperative mode, and checking the PRNG seed also does not work. I will revert this PR for now, as it causes regressions for models that use dropout.

cc @szha @sxjscience @apeforest

A Member commented:

@roywei the links in your comment don't seem to point to what they refer to, so I'm a bit confused. Where did you insert the 're-init dropout desc' log line?

zheyuye added a commit to zheyuye/incubator-mxnet that referenced this pull request Apr 23, 2020
roywei added a commit to roywei/incubator-mxnet that referenced this pull request Apr 23, 2020
@roywei roywei mentioned this pull request Apr 23, 2020
roywei added a commit to roywei/incubator-mxnet that referenced this pull request Apr 23, 2020
leezu pushed a commit that referenced this pull request Apr 23, 2020
* Revert "Fix cudnn Dropout reproducibility (#17547)"

This reverts commit 249b9a1.

* fix conflict
AntiZpvoh pushed a commit to AntiZpvoh/incubator-mxnet that referenced this pull request Jul 6, 2020
* Revert "Fix cudnn Dropout reproducibility (apache#17547)"

This reverts commit 249b9a1.

* fix conflict
chinakook pushed a commit to chinakook/mxnet that referenced this pull request Jul 24, 2020
* Revert "Fix cudnn Dropout reproducibility (apache#17547)"

This reverts commit 249b9a1.

* fix conflict

Successfully merging this pull request may close this issue: cudnn Dropout reproducibility (#15662)
8 participants