Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rishikksh20 authored Mar 27, 2021
1 parent a458359 commit d41cb23
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,9 @@ def train(local_rank, args, hp, model):
if (step + 1) % hp.train.accum_grad == 0:
losses.update(loss.item()*hp.train.accum_grad)
torch.nn.utils.clip_grad_norm_(model.parameters(), hp.train.grad_clip)
scheduler.step()

optimizer.step()
scheduler.step()
optimizer.zero_grad()
global_step += 1

Expand Down Expand Up @@ -275,4 +276,4 @@ def main():


if __name__ == "__main__":
main()
main()

0 comments on commit d41cb23

Please sign in to comment.