
fix dtype error in StableDiffusion and DreamBooth training
LeoXing1996 committed May 29, 2023
1 parent 47a54de commit bd4584e
Showing 2 changed files with 24 additions and 10 deletions.
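For context, the fix below guards the UNet forward pass against mixed-precision inputs. A minimal standalone sketch of the kind of dtype mismatch PyTorch raises when half-precision activations meet fp32 weights (plain PyTorch, not mmagic code; the stub names are hypothetical):

import torch
import torch.nn as nn

fp32_unet_stub = nn.Linear(4, 4)         # stands in for the fp32 UNet
half_latents = torch.randn(2, 4).half()  # e.g. latents from an fp16 VAE

try:
    fp32_unet_stub(half_latents)
except RuntimeError as err:
    # PyTorch raises a dtype-mismatch error here, e.g.
    # "mat1 and mat2 must have the same dtype" (exact wording
    # varies by PyTorch version).
    print(err)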
10 changes: 6 additions & 4 deletions mmagic/models/editors/dreambooth/dreambooth.py
@@ -287,10 +287,12 @@ def train_step(self, data, optim_wrapper):
                              f'{self.scheduler.config.prediction_type}')
 
         # NOTE: we train unet in fp32, convert to float manually
-        model_output = self.unet(
-            noisy_latents.float(),
-            timesteps,
-            encoder_hidden_states=encoder_hidden_states.float())
+        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
+        with torch.autocast(device_type=device_type, dtype=torch.float32):
+            model_output = self.unet(
+                noisy_latents.float(),
+                timesteps,
+                encoder_hidden_states=encoder_hidden_states.float())
         model_pred = model_output['sample']
 
         loss_dict = dict()
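The fix is the same in both files: nest a float32 autocast inside whatever ambient autocast the training loop may have enabled, so the UNet forward always runs in fp32. A minimal sketch of that behavior, assuming a recent PyTorch (on CPU-only machines the autocast contexts may emit warnings and fall back to plain fp32 execution, which still satisfies the assertion):

import torch
import torch.nn as nn

net = nn.Linear(4, 4)  # fp32 weights, standing in for the UNet
x = torch.randn(2, 4)
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'

# The outer context mimics an AMP-enabled training step; the inner
# float32 context mirrors the guard added in train_step above.
with torch.autocast(device_type=device_type, dtype=torch.float16):
    with torch.autocast(device_type=device_type, dtype=torch.float32):
        out = net(x.float())

assert out.dtype == torch.float32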
24 changes: 18 additions & 6 deletions mmagic/models/editors/stable_diffusion/stable_diffusion.py
@@ -73,7 +73,17 @@ def __init__(self,
         default_args = dict()
         if dtype is not None:
             default_args['dtype'] = dtype
-        self.dtype = dtype
+
+        self.dtype = torch.float32
+        if dtype in ['float16', 'fp16', 'half']:
+            self.dtype = torch.float16
+        elif dtype == 'bf16':
+            self.dtype = torch.bfloat16
+        else:
+            assert dtype in [
+                'fp32', None
+            ], ('dtype must be one of \'fp32\', \'fp16\', \'bf16\' or None.')
+
         self.vae = build_module(vae, MODELS, default_args=default_args)
         self.unet = build_module(unet, MODELS, default_args=default_args)
         self.scheduler = build_module(scheduler, DIFFUSION_SCHEDULERS)
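For illustration, the new string-to-dtype mapping in isolation (parse_dtype is a hypothetical name, not part of mmagic; note that self.dtype now holds a torch dtype, while the raw string is still forwarded to the submodules through default_args):

import torch

def parse_dtype(dtype=None):
    """Map a config string to a torch dtype; fp32 is the default."""
    if dtype in ['float16', 'fp16', 'half']:
        return torch.float16
    if dtype == 'bf16':
        return torch.bfloat16
    assert dtype in ['fp32', None], (
        "dtype must be one of 'fp32', 'fp16', 'bf16' or None.")
    return torch.float32

assert parse_dtype() is torch.float32
assert parse_dtype('half') is torch.float16
assert parse_dtype('bf16') is torch.bfloat16

Only the literal 'bf16' selects bfloat16 here; a spelling such as 'bfloat16' would trip the assertion.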
@@ -627,11 +637,13 @@ def train_step(self, data, optim_wrapper_dict):
             raise ValueError('Unknown prediction type '
                              f'{self.scheduler.config.prediction_type}')
 
-        # NOTE: convert to float manually
-        model_output = self.unet(
-            noisy_latents.float(),
-            timesteps,
-            encoder_hidden_states=encoder_hidden_states.float())
+        # NOTE: we train unet in fp32, convert to float manually
+        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
+        with torch.autocast(device_type=device_type, dtype=torch.float32):
+            model_output = self.unet(
+                noisy_latents.float(),
+                timesteps,
+                encoder_hidden_states=encoder_hidden_states.float())
         model_pred = model_output['sample']
 
         loss_dict = dict()
