Skip to content

Commit

Permalink
Merge 7e5fc1a into 47a54de
Browse files · Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 authored May 29, 2023
2 parents 47a54de + 7e5fc1a commit a0ce951
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions mmagic/models/editors/stable_diffusion/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,19 @@ def __init__(self,
default_args = dict()
if dtype is not None:
default_args['dtype'] = dtype
self.dtype = dtype

self.dtype = torch.float32
if dtype in ['float16', 'fp16', 'half']:
self.dtype = torch.float16
elif dtype == 'bf16':
self.dtype = torch.bfloat16
else:
assert dtype in [
'fp32', None
], ('dtype must be one of \'fp32\', \'fp16\', \'bf16\' or None.')

self.vae = build_module(vae, MODELS, default_args=default_args)
self.unet = build_module(unet, MODELS, default_args=default_args)
self.unet = build_module(unet, MODELS) # NOTE: initialize unet as fp32
self.scheduler = build_module(scheduler, DIFFUSION_SCHEDULERS)
if test_scheduler is None:
self.test_scheduler = deepcopy(self.scheduler)
Expand Down Expand Up @@ -627,7 +637,7 @@ def train_step(self, data, optim_wrapper_dict):
raise ValueError('Unknown prediction type '
f'{self.scheduler.config.prediction_type}')

# NOTE: convert to float manually
# NOTE: we train unet in fp32, convert to float manually
model_output = self.unet(
noisy_latents.float(),
timesteps,
Expand Down

0 comments on commit a0ce951

Please sign in to comment.