Model averaging #337

Open
danpovey opened this issue Apr 28, 2022 · 3 comments

Comments

@danpovey (Collaborator) commented Apr 28, 2022

OK, we have some results locally (thanks, @yaozengwei!) showing that model averaging over finely spaced checkpoints is a bit better than averaging over epochs. The gain is only about 0.05% on test-clean and 0.15% on test-other, at around 3%/7% WER, but it's probably still worth doing since at this point we are picking up pennies in WER.

... OK, here's the idea: we always store a separate version of the model, say model_avg, in which each floating-point parameter holds the average, from the start of training, of that parameter's values. We update this every average_period batches, for, say, average_period = 10 or 100 (it could be every batch, but this is for speed). Each time we average, we do:

    model_avg = model_avg * ((cur_batch_idx_train-average_period) / cur_batch_idx_train) + model * (average_period/cur_batch_idx_train)

[This is not the syntax we'd use; we'd have to write a function to do this weighted average.]
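A minimal sketch of what that function might look like, operating directly on state_dicts (the function name and arguments here are just placeholders, not the final API):

```python
import torch


def update_averaged_model(
    model_avg: torch.nn.Module,
    model: torch.nn.Module,
    batch_idx_train: int,
    average_period: int,
) -> None:
    """Update model_avg in place so that it holds the running mean, from the
    start of training, of model's floating-point parameters, sampled every
    average_period batches."""
    # Weight on the current model; on the first call
    # (batch_idx_train == average_period) this is 1.0, i.e. a straight copy.
    weight_cur = average_period / batch_idx_train
    weight_avg = 1.0 - weight_cur

    avg_state = model_avg.state_dict()
    cur_state = model.state_dict()
    with torch.no_grad():
        for key, avg_param in avg_state.items():
            if avg_param.dtype.is_floating_point:
                avg_param.mul_(weight_avg).add_(cur_state[key], alpha=weight_cur)
```

It would be called every average_period batches, right after the optimizer step, with batch_idx_train being the total number of batches seen so far.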
I propose that we include this averaged version inside the checkpoints-*.pt and epoch-*.pt as a separate key in the dict.
Then the way we would implement decoding epoch-29.pt with --avg 5 would be something like the following.

   Let epoch_30 = epoch-29.pt['averaged']   # Unfortunately we have to add 1: epoch-29.pt was really
                                            # trained for 30 epochs!!  I wish there was a way to fix this without disruption.
       epoch_25 = epoch-24.pt['averaged']
   epochs_25_to_30 = (epoch_30 * 30 - epoch_25 * 25) / 5   # Note, we don't actually need to use the epoch numbers
                                            # in this formula.  We can use the cur_batch_idx_train from the dicts,
                                            # which will be more robust if we change the dataset definition halfway through training.
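In code, the decode-time combination could look roughly like the following sketch, which uses batch counts rather than epoch numbers as suggested above (the checkpoint keys 'averaged' and 'batch_idx_train' are placeholders for whatever we end up storing):

```python
import torch


def average_checkpoints_range(ckpt_start: dict, ckpt_end: dict) -> dict:
    """Given two checkpoints that each store a cumulative running average
    under 'averaged' and the number of batches seen under 'batch_idx_train',
    return a state_dict averaging the model over the batches in between."""
    n_start = ckpt_start["batch_idx_train"]
    n_end = ckpt_end["batch_idx_train"]
    assert n_end > n_start, (n_start, n_end)

    avg_start = ckpt_start["averaged"]
    avg_end = ckpt_end["averaged"]
    out = {}
    for key, p_end in avg_end.items():
        if p_end.dtype.is_floating_point:
            out[key] = (p_end * n_end - avg_start[key] * n_start) / (n_end - n_start)
        else:
            out[key] = p_end.clone()
    return out


# E.g. decoding epoch-29.pt with --avg 5:
# end = torch.load("epoch-29.pt", map_location="cpu")
# start = torch.load("epoch-24.pt", map_location="cpu")
# model.load_state_dict(average_checkpoints_range(start, end))
```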
@yaozengwei (Collaborator) commented
I will write the function.

@pzelasko (Collaborator) commented
I think PyTorch has something similar here: https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging
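For reference, usage looks roughly like this sketch (`model` and `train_dl` are placeholders; by default AveragedModel keeps an equal-weighted running mean of every snapshot passed to update_parameters):

```python
from torch.optim.swa_utils import AveragedModel

swa_model = AveragedModel(model)  # equal-weighted running mean by default

for batch_idx, batch in enumerate(train_dl, start=1):
    ...  # forward / backward / optimizer step on `model`
    if batch_idx % 100 == 0:
        swa_model.update_parameters(model)
```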

@danpovey (Collaborator, Author) commented
OK that's interesting. I think we can still write our own though, because it doesn't look to me like it's easy to use that PyTorch thing with our batch_idx_train, which allows us to choose the period for averaging at decode time.
