Non fully reproducible results on GPU #33
Comments
Hi @goingtosleep! There's a good chance this is due to TFDS. Here's what would be great: could you dump MNIST entirely to disk once, and then modify the dataset loading code to read from disk? I believe the training runs will then be fully reproducible. I'd also like to share this with the TFDS team, as I believe there should be a way to get reproducible dataset loaders, though we haven't yet done that ourselves.
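For reference, a rough sketch of the "dump once, then read from disk" idea described above; the file name and NPZ format are illustrative choices, not something specified in this thread:

import numpy as np
import tensorflow_datasets as tfds

def dump_mnist(path="mnist_train.npz"):
    # Materialize the whole split as NumPy arrays and save them to disk once.
    data = tfds.as_numpy(tfds.load("mnist", split="train", batch_size=-1))
    np.savez(path, image=data["image"], label=data["label"])

def load_mnist(path="mnist_train.npz"):
    # Later runs read the exact same bytes back, with no tf.data pipeline
    # (and hence no pipeline-level nondeterminism) involved.
    data = np.load(path)
    return {"image": data["image"], "label": data["label"]}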
An easy way to test this hypothesis would be to move the dataset pipeline outside of the train() function:

train_ds = tfds.load('mnist', split='train')
train_ds = list(tfds.as_numpy(train_ds.cache().batch(128)))
test_ds = tfds.as_numpy(tfds.load('mnist', split='test', batch_size=-1))

def train():
  ...
  for batch in train_ds:
    batch = dict(batch)  # Copy dict before mutating in-place
    ...

train()
train()

That way there is no more tf.data involved, as train_ds is just a list of NumPy batches.

Edit: Could you also share which versions of Flax, TFDS and Python you're using?
On what accelerator did you run this test? On GPU, reproducibility is never guaranteed, because JAX is currently not deterministic on that platform even when executing the exact same computations. On CPU I was unable to reproduce your issue, and TPU results should also be reproducible.
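As an aside, and as an assumption (it postdates the JAX/XLA versions discussed in this thread): newer XLA releases expose a flag that requests deterministic GPU kernels. A minimal sketch, assuming the flag exists in your installed version:

import os
# Must be set before JAX/XLA compiles anything, so set it before importing jax.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"

import jax
import jax.numpy as jnp

# With the flag above, XLA avoids the atomics-based GPU kernels (e.g. in
# scatters and some reductions) whose summation order can vary between runs,
# at some cost in speed.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4096, 256))
print(jnp.sum(x).item())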
Thank you all for the replies. I used a GPU runtime. I've done some tests so far:

First run:
Second run:
Third run:
Rerun after restarting the runtime:

[The loss/accuracy listings for these runs are not preserved here; each run produced slightly different values.]

I used a different network architecture, so the accuracy and loss can differ from the MNIST example, but the network is deterministic (no dropout), so no randomness is involved in the training process. I'm not sure whether the differences in the above results are due to floating-point precision; I tried enabling 64-bit mode with

from jax.config import config
config.update("jax_enable_x64", True)

but the weights are still of dtype float32.
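For what it's worth, a small illustration of why enabling x64 may not change the parameters (this explanation is an assumption, not something confirmed in the thread): the flag only changes the dtypes JAX picks by default, while anything created with an explicit float32 dtype, as some model and initializer defaults are, stays float32.

import jax
import jax.numpy as jnp

# Must run before any arrays are created.
jax.config.update("jax_enable_x64", True)

key = jax.random.PRNGKey(0)
print(jax.random.normal(key, (2,)).dtype)        # float64: default dtype follows x64
print(jnp.zeros((2,), dtype=jnp.float32).dtype)  # float32: explicit dtypes are kept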
Yes, indeed, at the moment XLA builds on GPU aren't fully reproducible, e.g. jax-ml/jax#565. I'll check with the JAX team to learn more.
Can we close this for now? I think both the XLA and JAX teams are aware of this issue and the fix is in progress.
If the fix isn't in, I don't think we should close the issue.
jax-ml/jax#565 is fixed. I've verified that @goingtosleep's code now outputs reproducible results on a TPU, so I am closing this issue.
Sorry, the fix isn't in. JAX is still not reproducible on GPU. We need to make sure there's an open ticket on the JAX GitHub tracker.
Hi @goingtosleep -- this is a bit late, but the
I confirm that for the MNIST example, this issue is solved: results are consistent between 2 runs (on Google Colab).
We're working on improving the JAX documentation on this in jax-ml/jax#4824. Feedback on that PR is welcome!
Although the random key is fixed (e.g. jax.random.PRNGKey(0)), the results of different runs are always different. My question is: how can one fix this random behavior? My expectation is that when I choose a fixed random key, all runs should produce the same result.
Thank you in advance.
Use the following code to reproduce the issue (I simply take the MNIST example with shuffle removed):
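The original reproduction code is not preserved in this copy of the thread. What follows is only a stand-in sketch in the same spirit (fixed PRNGKey, no shuffling, no dropout, final loss printed for two runs), not the Flax MNIST example itself; the model and hyperparameters are arbitrary illustrative choices.

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_datasets as tfds

def load_mnist_batches(batch_size=128):
    # Materialize the whole split as NumPy once and slice it deterministically,
    # so the data order is identical on every run.
    data = tfds.as_numpy(tfds.load("mnist", split="train", batch_size=-1))
    images = data["image"].reshape(-1, 784).astype(np.float32) / 255.0
    labels = data["label"]
    n = (len(labels) // batch_size) * batch_size
    return [(images[i:i + batch_size], labels[i:i + batch_size])
            for i in range(0, n, batch_size)]

def init_params(key):
    k1, k2 = jax.random.split(key)
    return {"w1": jax.random.normal(k1, (784, 256)) * 0.01, "b1": jnp.zeros(256),
            "w2": jax.random.normal(k2, (256, 10)) * 0.01, "b2": jnp.zeros(10)}

def loss_fn(params, x, y):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    logits = h @ params["w2"] + params["b2"]
    logp = jax.nn.log_softmax(logits)
    return -jnp.mean(logp[jnp.arange(y.shape[0]), y])

@jax.jit
def train_step(params, x, y):
    lr = 0.1  # plain SGD, purely for illustration
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

def run_once(batches):
    params = init_params(jax.random.PRNGKey(0))  # fixed seed, as in the issue
    for x, y in batches:
        params, loss = train_step(params, x, y)
    return float(loss)

batches = load_mnist_batches()
print(run_once(batches))
print(run_once(batches))  # matches on CPU/TPU; on GPU the two values may still differ slightly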