Non fully reproducible results on GPU #33

Closed
goingtosleep opened this issue Mar 4, 2020 · 13 comments

@goingtosleep

goingtosleep commented Mar 4, 2020

Although the random key is fixed (e.g. jax.random.PRNGKey(0)), different runs always produce different results.

How can I make the results deterministic? My expectation is that with a fixed random key, all runs should produce the same result.

Thank you in advance.

Use the following code to reproduce the issue (I simply took the MNIST example and removed the shuffling):

import jax
import flax
import numpy as onp
import jax.numpy as jnp
import tensorflow_datasets as tfds

class CNN(flax.nn.Module):
  def apply(self, x):
    x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
    x = jax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = flax.nn.Dense(x, features=256)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=10)
    x = flax.nn.log_softmax(x)
    return x

@jax.vmap
def cross_entropy_loss(logits, label):
  return -logits[label]

def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}

@jax.jit
def train_step(optimizer, batch):
  def loss_fn(model):
    logits = model(batch['image'])
    loss = jnp.mean(cross_entropy_loss(
        logits, batch['label']))
    return loss, logits
  optimizer, _, _ = optimizer.optimize(loss_fn)
  return optimizer

@jax.jit
def eval(model, eval_ds):
  logits = model(eval_ds['image'] / 255.0)
  return compute_metrics(logits, eval_ds['label'])

def train():
  train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
  train_ds = train_ds.cache().batch(128)
  test_ds = tfds.as_numpy(tfds.load(
      'mnist', split=tfds.Split.TEST, batch_size=-1))

  _, model = CNN.create_by_shape(
      jax.random.PRNGKey(0),
      [((1, 28, 28, 1), jnp.float32)])

  optimizer = flax.optim.Momentum(
      learning_rate=0.1, beta=0.9).create(model)

  for epoch in range(10):
    for batch in tfds.as_numpy(train_ds):
      batch['image'] = batch['image'] / 255.0
      optimizer = train_step(optimizer, batch)

    metrics = eval(optimizer.target, test_ds)
    print('eval epoch: %d, loss: %.4f, accuracy: %.2f'
         % (epoch+1,
          metrics['loss'], metrics['accuracy'] * 100))

train()
train()
@avital
Contributor

avital commented Mar 6, 2020

Hi @goingtosleep! There's a good chance this is due to TFDS. Here's what would be great: Could you dump MNIST entirely to disk once, and then modify the dataset loading code to read from disk? I believe then the training runs will be fully reproducible.

I'd also like to share this with the TFDS team, as I believe there should be a way to get reproducible dataset loaders, though we haven't yet done that ourselves.
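A minimal sketch of the "dump once, then read from disk" idea, assuming NumPy's .npz format is an acceptable on-disk cache. `dump_dataset` and `load_batches` are hypothetical helper names, and the demo uses synthetic arrays in place of MNIST. Because the file is read back into plain arrays and sliced in index order, every run should see byte-identical batches in the same order, which takes TFDS/tf.data out of the picture.

```python
import os
import tempfile

import numpy as onp  # same alias as the code in the issue


def dump_dataset(path, images, labels):
    """Write the whole split to a single .npz file (hypothetical helper)."""
    onp.savez(path, image=images, label=labels)


def load_batches(path, batch_size):
    """Read the .npz file back and split it into batches in a fixed order."""
    with onp.load(path) as data:
        images, labels = data["image"], data["label"]
    return [
        {"image": images[i:i + batch_size], "label": labels[i:i + batch_size]}
        for i in range(0, len(images), batch_size)
    ]


# Demo with synthetic stand-in data. With TFDS, the full arrays could come from
# tfds.as_numpy(tfds.load('mnist', split='train', batch_size=-1)).
rng = onp.random.default_rng(0)
images = rng.integers(0, 256, size=(300, 28, 28, 1), dtype=onp.uint8)
labels = rng.integers(0, 10, size=(300,), dtype=onp.int64)

path = os.path.join(tempfile.mkdtemp(), "mnist_train.npz")
dump_dataset(path, images, labels)
first = load_batches(path, batch_size=128)   # "run 1"
second = load_batches(path, batch_size=128)  # "run 2"
```

Both loads yield the same three batches (128, 128 and 44 examples) with identical contents.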

@Conchylicultor
Member

Conchylicultor commented Mar 6, 2020

An easy way to test this hypothesis would be to move the input pipeline outside of the train() fn, convert it to NumPy, and call list() on the generator so the batches are materialized once:

train_ds = tfds.load('mnist', split='train')
train_ds = list(tfds.as_numpy(train_ds.cache().batch(128)))
test_ds = tfds.as_numpy(tfds.load('mnist', split='test', batch_size=-1))

def train():
  ...
  for batch in train_ds:
    batch = dict(batch)  # Copy dict before mutating in-place
    ...

train()
train()

That way tf.data is no longer involved, as train_ds is just a list of NumPy batches (List[Dict[str, np.ndarray]]).

Edit: Could you also share which versions of Flax, TFDS and Python you're using?
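To check whether the input pipeline is the culprit, one option (a sketch, not something TFDS provides) is to hash every batch a run consumes and compare the digests across runs. `fingerprint_batches` is a hypothetical helper, and the toy `toy_train_ds` below stands in for the materialized list of NumPy batch dicts. If the digests match across runs but the losses still differ, the nondeterminism most likely comes from the computation rather than the data.

```python
import hashlib

import numpy as onp


def fingerprint_batches(batches):
    """Hash every array of every batch, in iteration order, into one digest."""
    h = hashlib.sha256()
    for batch in batches:
        for key in sorted(batch):
            arr = onp.ascontiguousarray(batch[key])
            # Include key, dtype and shape so arrays that happen to share raw
            # bytes but differ in layout still produce different digests.
            h.update(key.encode())
            h.update(str(arr.dtype).encode())
            h.update(str(arr.shape).encode())
            h.update(arr.tobytes())
    return h.hexdigest()


# Stand-in for list(tfds.as_numpy(...)): a materialized List[Dict[str, ndarray]].
toy_train_ds = [
    {"image": onp.full((2, 28, 28, 1), i, dtype=onp.float32),
     "label": onp.array([i, i + 1])}
    for i in range(3)
]

run_1 = fingerprint_batches(toy_train_ds)
run_2 = fingerprint_batches(toy_train_ds)
reordered = fingerprint_batches(list(reversed(toy_train_ds)))
```

`run_1` and `run_2` are identical, while `reordered` differs, so printing the digest at the start of each training run is enough to spot a pipeline that changes order or content between runs.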

@jheek
Member

jheek commented Mar 6, 2020

On what accelerator did you run this test?

On GPU, reproducibility is never guaranteed because JAX is currently not deterministic on that platform, even when executing exactly the same computations.
BTW: We should probably mention this somewhere in the docs, and so should JAX!

On CPU I was unable to reproduce your issue and TPU results should also be reproducible.
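One plausible mechanism, stated here as an assumption rather than a description of XLA's internals: GPU kernels may accumulate sums in parallel, and the order of accumulation can change from one execution to the next. Floating-point addition is not associative, so a different order can change the last bits of the result. The plain-Python sketch below sums the same three values in two orders; `sum_in_order` is a hypothetical helper.

```python
def sum_in_order(values, order):
    """Accumulate `values` left to right, visiting indices in `order`."""
    total = 0.0
    for i in order:
        total += values[i]
    return total


values = [0.1, 0.2, 0.3]
forward = sum_in_order(values, [0, 1, 2])  # (0.1 + 0.2) + 0.3
rotated = sum_in_order(values, [1, 2, 0])  # (0.2 + 0.3) + 0.1
print(forward, rotated, forward == rotated)
```

Here `forward` evaluates to 0.6000000000000001 and `rotated` to 0.6. Over thousands of training steps, discrepancies of this size can feed back into the weights and grow.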

@goingtosleep
Author

goingtosleep commented Mar 9, 2020

Thank you all for the replies.

I used flax-0.0.1a0, jax-0.1.59, jaxlib-0.1.39 and tfds 2.0.0 on a Google Colab P100 instance (Python 3.6.9).

I've done some tests so far:

  • On CPU, with the random seed fixed, the result is reproducible on every run.

  • On GPU (Colab), I use the Keras datasets, which return NumPy ndarrays. I train the model on the first 25_000 data points (due to the memory limit) with no permutation and a single full batch. Test accuracy is calculated on the test set (10_000 data points) as usual. With this setup I believe there is no randomness involved except for the weight initialization. The results are as follows:

First run:

Epoch 1 in [17.32s], loss [4.726857], accuracy [34.37%]
Epoch 2 in [0.02s], loss [3.494419], accuracy [70.59%]
Epoch 3 in [0.02s], loss [2.420179], accuracy [84.71%]
Epoch 4 in [0.02s], loss [1.485580], accuracy [88.12%]
Epoch 5 in [0.02s], loss [0.923213], accuracy [91.93%]

Second run:

Epoch 1 in [9.93s], loss [4.726847], accuracy [34.37%]
Epoch 2 in [0.02s], loss [3.494333], accuracy [70.60%]
Epoch 3 in [0.02s], loss [2.420010], accuracy [84.70%]
Epoch 4 in [0.02s], loss [1.485356], accuracy [88.11%]
Epoch 5 in [0.02s], loss [0.923163], accuracy [91.93%]

Third run:

Epoch 1 in [9.81s], loss [4.726859], accuracy [34.40%]
Epoch 2 in [0.02s], loss [3.494327], accuracy [70.59%]
Epoch 3 in [0.02s], loss [2.419940], accuracy [84.70%]
Epoch 4 in [0.02s], loss [1.485139], accuracy [88.11%]
Epoch 5 in [0.02s], loss [0.923007], accuracy [91.94%]

Rerun after restarting the runtime:

Epoch 1 in [19.73s], loss [4.726849], accuracy [34.37%]
Epoch 2 in [0.02s], loss [3.494359], accuracy [70.59%]
Epoch 3 in [0.02s], loss [2.420025], accuracy [84.70%]
Epoch 4 in [0.02s], loss [1.485441], accuracy [88.10%]
Epoch 5 in [0.02s], loss [0.923208], accuracy [91.93%]

I used a different network architecture, so the accuracy and loss may differ from the MNIST example, but the network is deterministic (no dropout), so no randomness is involved in the training process.
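A quick way to quantify the run-to-run drift is the per-epoch loss spread (max minus min) across the four runs above; the numbers below are copied from those logs. The spread grows from about 1.2e-5 at epoch 1 to about 4.4e-4 at epoch 4, which is consistent with tiny rounding differences compounding over training steps, though a sample this small cannot prove that. `per_epoch_spread` is a hypothetical helper.

```python
# Per-epoch training losses from the four GPU runs above
# (runs 1-3 in one runtime, run 4 after restarting the runtime).
runs = [
    [4.726857, 3.494419, 2.420179, 1.485580, 0.923213],
    [4.726847, 3.494333, 2.420010, 1.485356, 0.923163],
    [4.726859, 3.494327, 2.419940, 1.485139, 0.923007],
    [4.726849, 3.494359, 2.420025, 1.485441, 0.923208],
]


def per_epoch_spread(runs):
    """Return max - min of the loss across runs, for each epoch."""
    return [max(epoch) - min(epoch) for epoch in zip(*runs)]


for epoch, spread in enumerate(per_epoch_spread(runs), start=1):
    print(f"epoch {epoch}: loss spread {spread:.6f}")
```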

I'm not sure if the differences in the above results are due to float32 precision. I tried

from jax.config import config
config.update("jax_enable_x64", True)

but the weights are still float32; maybe I will test this later. What do you think?
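On the float32 question, the sketch below emulates float32 addition in plain Python by rounding every intermediate result through `struct` (an assumption standing in for real float32 hardware arithmetic; `to_f32` and `sum_f32` are hypothetical helpers). Near 2**24, float32 can no longer represent every integer, so adding 1.0 twice is lost in one order but not in the other, while float64 gives the same answer for both orders. Higher precision makes order-dependent rounding much smaller, but float64 addition is still not associative, so enabling x64 would likely shrink the run-to-run differences rather than remove them.

```python
import struct


def to_f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sum_f32(values):
    """Sequential sum in which every intermediate result is rounded to float32."""
    total = 0.0
    for v in values:
        total = to_f32(total + to_f32(v))
    return total


big = 2.0 ** 24  # 16777216: above this, float32 spacing between values is 2
order_a = [big, 1.0, 1.0]  # each +1.0 rounds back down to 2**24
order_b = [1.0, 1.0, big]  # 1.0 + 1.0 = 2.0 first, then 2**24 + 2 is exact

f32_a, f32_b = sum_f32(order_a), sum_f32(order_b)
f64_a, f64_b = sum(order_a), sum(order_b)
print(f32_a, f32_b, f64_a, f64_b)
```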

@avital
Contributor

avital commented Mar 10, 2020

Yes, indeed at the moment XLA builds on GPU aren't fully reproducible, e.g. jax-ml/jax#565. I'll check with the JAX team to learn more.

@jheek
Member

jheek commented Mar 19, 2020

Can we close this for now? I think both the XLA and JAX teams are aware of this issue and a fix is in progress.

@avital
Contributor

avital commented Mar 19, 2020

If the fix isn't in, I don't think we should close the issue.

@marcvanzee
Collaborator

jax-ml/jax#565 is fixed. I've verified that @goingtosleep's code now outputs reproducible results on a TPU, so I am closing this issue.

@avital
Copy link
Contributor

avital commented Mar 27, 2020

Sorry, the fix isn't in. JAX is still not reproducible on GPU. We need to make sure there's an open ticket on the JAX GitHub tracker.

@avital avital reopened this Mar 27, 2020
@avital avital changed the title from "Random seed is not fixed" to "Non fully reproducible results on GPU" Apr 1, 2020
@avital
Contributor

avital commented Oct 27, 2020

Hi @goingtosleep, this is a bit late, but setting the environment variable XLA_FLAGS=--xla_gpu_deterministic_reductions should work (though perhaps not yet for all operations). I'd be curious to see if this solves the issue.
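If exporting the variable in a shell is inconvenient, it can also be set from Python. The sketch below is a hypothetical helper that appends a flag to XLA_FLAGS without clobbering flags that are already set. It assumes the variable has to be in place before JAX initializes its backend, so call it before `import jax`; later jaxlib releases may also no longer accept this particular flag.

```python
import os


def add_xla_flag(flag, env=None):
    """Append `flag` to XLA_FLAGS unless it is already present (hypothetical helper)."""
    env = os.environ if env is None else env
    flags = env.get("XLA_FLAGS", "").split()
    if flag not in flags:
        flags.append(flag)
    env["XLA_FLAGS"] = " ".join(flags)
    return env["XLA_FLAGS"]


# Typical use at the very top of the training script, before JAX is imported:
#   add_xla_flag("--xla_gpu_deterministic_reductions")
#   import jax
demo_env = {"XLA_FLAGS": "--some_existing_flag"}  # stand-in for os.environ
add_xla_flag("--xla_gpu_deterministic_reductions", demo_env)
```

Passing a plain dict keeps the demo side-effect free; omitting `env` modifies the real process environment.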

@goingtosleep
Author

goingtosleep commented Oct 28, 2020

I can confirm that this issue is solved for the MNIST example, using the following command:

!export XLA_FLAGS=--xla_gpu_deterministic_reductions && export TF_CUDNN_DETERMINISTIC=1 && echo $XLA_FLAGS, $TF_CUDNN_DETERMINISTIC && python main.py

results are consistent between 2 runs (on Google Colab):

--xla_gpu_deterministic_reductions, 1
2020-10-28 16:14:10.248318: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
eval epoch: 1, loss: 0.0563, accuracy: 98.14
eval epoch: 2, loss: 0.0561, accuracy: 98.21
eval epoch: 3, loss: 0.0401, accuracy: 98.79
eval epoch: 4, loss: 0.0365, accuracy: 98.89
eval epoch: 5, loss: 0.0359, accuracy: 98.95
eval epoch: 6, loss: 0.0360, accuracy: 98.94
eval epoch: 7, loss: 0.0303, accuracy: 99.16
eval epoch: 8, loss: 0.0418, accuracy: 98.93
eval epoch: 9, loss: 0.0406, accuracy: 99.03
eval epoch: 10, loss: 0.0326, accuracy: 99.18

eval epoch: 1, loss: 0.0563, accuracy: 98.14
eval epoch: 2, loss: 0.0561, accuracy: 98.21
eval epoch: 3, loss: 0.0401, accuracy: 98.79
eval epoch: 4, loss: 0.0365, accuracy: 98.89
eval epoch: 5, loss: 0.0359, accuracy: 98.95
eval epoch: 6, loss: 0.0360, accuracy: 98.94
eval epoch: 7, loss: 0.0303, accuracy: 99.16
eval epoch: 8, loss: 0.0418, accuracy: 98.93
eval epoch: 9, loss: 0.0406, accuracy: 99.03
eval epoch: 10, loss: 0.0326, accuracy: 99.18

@mattjj
Member

mattjj commented Nov 7, 2020

We're working on improving the JAX documentation on this in jax-ml/jax#4824. Feedback on that PR is welcome!

@mattjj
Member

mattjj commented Nov 11, 2021

It sounds like the --xla_gpu_deterministic_reductions flag is now gone (or it will be when we push an updated jaxlib) because it's now effectively always on by default. So hopefully this will get less surprising...
