Skip to content

Commit

Permalink
[train_unconditional] fix gradient accumulation. (open-mmlab#308)
Browse files Browse the repository at this point in the history
fix grad accum
  • Loading branch information
patil-suraj authored Sep 1, 2022
1 parent 4724250 commit 1b1d644
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions examples/unconditional_image_generation/train_unconditional.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import math
import os

import torch
Expand Down Expand Up @@ -29,6 +30,7 @@
def main(args):
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with="tensorboard",
logging_dir=logging_dir,
Expand Down Expand Up @@ -105,6 +107,8 @@ def transforms(examples):
model, optimizer, train_dataloader, lr_scheduler
)

num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)

if args.push_to_hub:
Expand All @@ -117,7 +121,7 @@ def transforms(examples):
global_step = 0
for epoch in range(args.num_epochs):
model.train()
progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
progress_bar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["input"]
Expand Down Expand Up @@ -146,13 +150,16 @@ def transforms(examples):
ema_model.step(model)
optimizer.zero_grad()

progress_bar.update(1)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
if args.use_ema:
logs["ema_decay"] = ema_model.decay
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
global_step += 1
progress_bar.close()

accelerator.wait_for_everyone()
Expand Down

0 comments on commit 1b1d644

Please sign in to comment.