Baseline results are far off from Lua version #21

Closed

vince62s opened this issue Nov 12, 2017 · 16 comments

@vince62s
Member

@guillaumekln I redid my comparison; it is not so good.

Baseline 1M enfr:
Lua (brnn 2x500): BLEU on the test set without replace_unk = 35.65 (beam 1), 37.00 (beam 5)
62 min per epoch on a GTX1080

TF (2x500 rnn) optim sgd: BLEU 26.47 after 200k steps
1h24 per 15K steps (approx 1 epoch) on a GTX1080ti

TF (2x500 rnn) optim noam: BLEU 28.87 after 200k steps

TF (transformer) optim noam: BLEU 33.23 after 100K steps

@gsoul

gsoul commented Nov 12, 2017

I'm not sure if this is related, but I noticed that each evaluation step (which by default happens every 5h), after the model is restored, throws the loss back somewhat, and it then takes ~10k-15k steps to get back to the same numbers as before the evaluation.
For now I have increased "eval_delay" from 5h to 50h; we'll see what happens...

@guillaumekln
Contributor

Thanks for the comparison.

Can you confirm that you compared RNN with BRNN? I have already identified some improvements to make, for both speed and performance.

How does the transformer result compare with other implementations?

@vince62s
Member Author

Yes, it is rnn vs brnn, but the Lua rnn is only slightly off vs its brnn; I just took the brnn numbers by mistake, so no major change.
vs t2t it is very far off: t2t reaches 38.4 at 130k steps.

@vince62s
Member Author

Just one strange thing: when I ran the SGD optim, I set 150K steps with decay every 15K steps after that, but there was no "gap" in the loss decrease. I thought it would have been more obvious.

@gsoul

gsoul commented Nov 21, 2017

For Fr->En I get 22.5 BLEU after 500k steps with batch size 32, and 19.5 BLEU after 100k steps with batch size 58.
I used multi-bleu.perl to calculate it, but I have a somewhat custom tokenization: #24.
vocabulary_size = 80000

@guillaumekln
Contributor

guillaumekln commented Nov 21, 2017

Thanks for testing.

I recently pushed this commit that could have a non-negligible impact on the training. With the master version, I easily reached about 37 BLEU with the Transformer on the 1M ENFR baseline (by the way, a major speedup for this model is coming very soon!).

Regarding RNNs, the default attention layer is tf.contrib.seq2seq.LuongAttention, which implements the dot variant and is less powerful than the attention layer implemented in OpenNMT-lua. Maybe a better default would be tf.contrib.seq2seq.BahdanauAttention, or to replicate the GNMTAttention used in tensorflow/nmt.
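
As a rough sketch (assuming the TF 1.x contrib API; num_units and the LSTM decoder cell below are placeholders, not the actual OpenNMT-tf code), the two variants differ only in which attention mechanism is wrapped around the decoder cell:

import tensorflow as tf

def build_attention_cell(memory, memory_length, num_units=500, additive=False):
    # memory: encoder outputs [batch, src_len, depth]; memory_length: source lengths.
    if additive:
        # Additive (Bahdanau) scoring, one possible better default.
        mechanism = tf.contrib.seq2seq.BahdanauAttention(
            num_units, memory, memory_sequence_length=memory_length)
    else:
        # Dot-product (Luong) scoring, the current default.
        mechanism = tf.contrib.seq2seq.LuongAttention(
            num_units, memory, memory_sequence_length=memory_length)
    cell = tf.contrib.rnn.LSTMCell(num_units)
    return tf.contrib.seq2seq.AttentionWrapper(
        cell, mechanism, attention_layer_size=num_units)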


For now I have increased "eval_delay" from 5h to 50h; we'll see what happens...

This could be a real issue with some dataset and is covered in the TensorFlow documentation:

Overfitting: In order to avoid overfitting, it is recommended to set up the training input_fn to shuffle the training data properly. It is also recommended to train the model a little longer, say multiple epochs, before performing evaluation, as the input pipeline starts from scratch for each training. It is particularly important for local training and evaluation.

That means if eval_delay is set to 5 hours but one epoch actually takes 20 hours, the training will only see 1/4 of the training data (if we assume buffer_size is small compared to the complete training data).

So either:

  • eval_delay should be set to a duration greater than a complete epoch
  • buffer_size should be set to a size greater than the number of examples in the training data so the shuffling operates on the whole dataset.

Neither case is convenient, and a simple misconfiguration can significantly impact the training. This should be revised, maybe by enforcing the evaluation delay.
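
To make the buffer_size point concrete, here is a minimal sketch of a tf.data pipeline with a bounded shuffle buffer (the file name and sizes are hypothetical): each output element is drawn from a window of at most buffer_size lines, so a small buffer only shuffles locally and a pipeline restarted before a full pass keeps revisiting the head of the file.

import tensorflow as tf

# Hypothetical pipeline; buffer_size must approach the corpus size
# (or the files be pre-shuffled) to get a near-uniform permutation.
dataset = tf.data.TextLineDataset("train.tok.src")
dataset = dataset.shuffle(buffer_size=500000)  # samples within a 500k-line window
dataset = dataset.repeat()
dataset = dataset.batch(64)
next_lines = dataset.make_one_shot_iterator().get_next()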

@gsoul

gsoul commented Nov 21, 2017

I recently pushed this commit that could have a non-negligible impact on the training.

Yeah, I incorporated it in the middle of the 500k training, and the 100k+ run is already using it.

I easily reached about 37 BLEU with the Transformer on the 1M ENFR baseline

I use the WMT14 ENFR corpus for FR->EN, but the results are not as promising so far.

This could be a real issue with some dataset and is covered in the TensorFlow documentation:

Yeah, I did a stupid thing, thanks for pointing that out! But for the 100k+ training I had decreased that number to 10h, and I have just returned it to 5h. So we'll see what happens.

but one epoch actually takes 20 hours

I believe it strongly depends on the GPU. If you use a TitanX or a 1080ti, it might be 1.5-2 times faster than my 1070.

Thank you for your fast reply and great insights into the config settings! I'll try them out.

@gsoul

gsoul commented Nov 21, 2017

(by the way, major speedup for this model is coming very soon!)

Great news! Would it involve the training phase, the inference phase, or both?

@guillaumekln
Contributor

Both. I'm preparing the PR.

@gsoul

gsoul commented Nov 22, 2017

The models became much faster indeed: training is about 1.5 times faster, and inference is "a hell of a lot faster", I don't know, maybe 10 times or even more.

At the same time, after 220k steps my Fr->En translator gives a 21.34 BLEU score, which seems a bit too low. All the settings seem to be default, except maybe the tokenization, though I don't think there's an issue there.

Any thoughts/ideas on what might be wrong? I don't think that Fr->En is so much more complicated than En->Fr. I'm not sure if BPE gives a big accuracy boost...

@guillaumekln
Contributor

What configuration(s) are you actually using? In particular, do you use config/optim/adam_with_noam_decay.yml?

You should get good results by now.

@gsoul

gsoul commented Nov 23, 2017

Here are my configs:

config/opennmt-transf.yml

# Default OpenNMT parameters.

params:
  optimizer: GradientDescentOptimizer
  learning_rate: 1.0
  clip_gradients: 5.0
  decay_type: exponential_decay
  decay_rate: 0.7
  decay_steps: 100000
  start_decay_steps: 500000
  beam_width: 5
  maximum_iterations: 250

train:
  batch_size: 56
  save_checkpoints_steps: 5000
  save_summary_steps: 1000
  train_steps: 2000000
  eval_delay: 18000 # Every 5 hours (value in seconds).
  maximum_features_length: 50
  maximum_labels_length: 50

infer:
  batch_size: 30

config/optim/adam_with_noam_decay.yml

# Uses the learning rate schedule defined in https://arxiv.org/abs/1706.03762.

params:
  optimizer: AdamOptimizer
  learning_rate: 1.0 # The scale constant.
  decay_type: noam_decay
  decay_rate: 512 # Model dimension.
  decay_steps: 16000 # Warmup steps.
  start_decay_steps: 0
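
For reference, a quick sketch of the schedule these noam parameters define, using the formula from the paper (not necessarily the exact code path in the toolkit):

def noam_learning_rate(step, scale=1.0, model_dim=512, warmup_steps=16000):
    # lr = scale * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)
    return scale * model_dim ** -0.5 * min(step ** -0.5,
                                           step * warmup_steps ** -1.5)

# The rate ramps up linearly during warmup, peaks around 3.5e-4 at step 16000,
# then decays as 1/sqrt(step) (about 1e-4 at step 200000).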

config/data/enfr.yml

model_dir: enfr

data:
  train_features_file: data/enfr/giga-fren.release2.token.fr
  train_labels_file: data/enfr/giga-fren.release2.token.en
  eval_features_file: data/enfr/newstest2013.token.fr
  eval_labels_file: data/enfr/newstest2013.token.en
  source_words_vocabulary: data/enfr/src-token.txt
  target_words_vocabulary: data/enfr/tgt-token.txt

command to run:

CUDA_VISIBLE_DEVICES=0 nohup unbuffer ~/anaconda2/bin/python -m bin.main train --model config/models/transformer.py --config config/opennmt-transf.yml config/optim/adam_with_noam_decay.yml config/data/enfr.yml > /home/soul/projects/opennmt-tf/transformer_token_6_6_default.txt &

@gsoul

gsoul commented Nov 23, 2017

Tensorboard:

[screenshot: TensorBoard training curves, 2017-11-23]

@guillaumekln
Contributor

guillaumekln commented Nov 23, 2017

The configurations look good. I'm starting a comparative run on the same dataset and will check.

@gsoul

gsoul commented Nov 23, 2017

Great! I'm looking forward to hearing about your results.

@guillaumekln
Contributor

guillaumekln commented Nov 30, 2017

Regarding the initial issue, I recently pushed some commits that make the implementation and training even closer to the Lua version.

On the 1M ENFR baseline, I obtain 36.73 BLEU after 200k steps with a brnn 2x500 and the configuration below:

model_dir: run/baseline-enfr-rnn

data:
  train_features_file: /training/Users/klein/baseline-1M-enfr/baseline-1M_train.en.light_tok
  train_labels_file: /training/Users/klein/baseline-1M-enfr/baseline-1M_train.fr.light_tok
  eval_features_file: /training/Users/klein/baseline-1M-enfr/baseline-1M_test.en.light_tok
  eval_labels_file: /training/Users/klein/baseline-1M-enfr/baseline-1M_test.fr.light_tok
  source_words_vocabulary: /training/Users/klein/baseline-1M-enfr/en-vocab.txt
  target_words_vocabulary: /training/Users/klein/baseline-1M-enfr/fr-vocab.txt

params:
  optimizer: GradientDescentOptimizer
  learning_rate: 1.0
  clip_gradients: 5.0
  param_init: 0.1
  decay_type: exponential_decay
  decay_rate: 0.7
  decay_steps: 20000
  start_decay_steps: 140000
  beam_width: 5
  maximum_iterations: 250

train:
  batch_size: 64
  save_checkpoints_steps: 5000
  save_summary_steps: 200
  train_steps: 200000
  eval_delay: 7200 # Every 2 hours.
  maximum_features_length: 50
  maximum_labels_length: 50
  save_eval_predictions: true
  external_evaluators: BLEU

infer:
  batch_size: 30
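
A minimal sketch of the learning rate this SGD schedule produces, assuming a staircase decay applied every decay_steps once start_decay_steps is reached (my approximation, not necessarily the exact implementation):

def sgd_learning_rate(step, base_lr=1.0, decay_rate=0.7,
                      decay_steps=20000, start_decay_steps=140000):
    # Staircase exponential decay: constant until start_decay_steps,
    # then multiplied by decay_rate every decay_steps.
    if step <= start_decay_steps:
        return base_lr
    return base_lr * decay_rate ** ((step - start_decay_steps) // decay_steps)

# Under this reading the rate is 1.0 until step 140k, then decays by a factor
# of 0.7 every 20k steps, reaching about 0.34 by step 200k.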

I would now call that on par with the Lua version, so I'm closing this issue.


@gsoul Regarding the giga-fren training, an important aspect appears to be proper data shuffling. Manually shuffling the data before training is a good start to increase learning efficiency. Otherwise, I introduced a new sample_buffer_size option that defines how many examples to sample from. You can set this to the corpus size for a uniformly random permutation, but it will obviously increase the CPU memory usage. I set this value to 1,000,000 by default.
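
For anyone who wants to pre-shuffle, a small sketch of shuffling a parallel corpus while keeping source and target lines aligned (the function and the .shuf suffix are just placeholders, not part of the toolkit):

import random

def shuffle_parallel(src_path, tgt_path, seed=1234):
    # Read both sides, shuffle the line pairs with the same permutation,
    # and write *.shuf copies next to the originals.
    with open(src_path) as f_src, open(tgt_path) as f_tgt:
        pairs = list(zip(f_src.readlines(), f_tgt.readlines()))
    random.Random(seed).shuffle(pairs)
    with open(src_path + ".shuf", "w") as f_src, open(tgt_path + ".shuf", "w") as f_tgt:
        for src_line, tgt_line in pairs:
            f_src.write(src_line)
            f_tgt.write(tgt_line)

# e.g. shuffle_parallel("giga-fren.release2.token.fr", "giga-fren.release2.token.en")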

My Transformer training on this dataset is now at 25.70 BLEU after 380k steps. Does anyone know what score we should expect on newstest2013 for FREN?

However, the Transformer is another piece of work. As @vince62s shared on the tensor2tensor repository, the model is sensitive to hyper-parameters, in particular the batch size. In any case, let's keep this out of the comparison with the Lua version and maybe open a new issue to track the tuning of this model.
