diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py index 8b66000d7..1e95c0366 100755 --- a/src/interactive_conditional_samples.py +++ b/src/interactive_conditional_samples.py @@ -9,13 +9,13 @@ import model, sample, encoder def interact_model( - model_name='124M', + model_name='774M', seed=None, nsamples=1, batch_size=1, - length=None, + length=900, temperature=1, - top_k=0, + top_k=40, top_p=1, models_dir='models', ): @@ -65,7 +65,7 @@ def interact_model( temperature=temperature, top_k=top_k, top_p=top_p ) - saver = tf.train.Saver() + saver = tf.compat.v1.train.Saver() ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name)) saver.restore(sess, ckpt)