Integrate fp16/bf16 support into SDXL model loading #791

Merged: 2 commits, Sep 3, 2023
19 changes: 11 additions & 8 deletions library/sdxl_model_util.py
@@ -160,7 +160,7 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None):

def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
# model_version is reserved for future use
-# dtype is reserved for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
+# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching

# Load the state dict
if model_util.is_safetensors(ckpt_path):
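The body of _load_state_dict_on_device is not part of this diff, so the snippet below is only a sketch of what a device- and dtype-aware loader can look like. It assumes accelerate's set_module_tensor_to_device; the helper name and return value are illustrative and not taken from the repository:

import torch
from accelerate.utils import set_module_tensor_to_device

def load_state_dict_on_device_sketch(model, state_dict, device, dtype=None):
    # Materialize each checkpoint tensor directly on the target device,
    # optionally casting floating-point tensors to fp16/bf16 on the way in.
    # This also works for modules built under init_empty_weights(), whose
    # parameters start out on the "meta" device with no storage.
    for name, tensor in state_dict.items():
        cast = dtype if tensor.is_floating_point() else None
        set_module_tensor_to_device(model, name, device, value=tensor, dtype=cast)
    return f"loaded {len(state_dict)} tensors on {device}"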
@@ -193,7 +193,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
for k in list(state_dict.keys()):
if k.startswith("model.diffusion_model."):
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
-info = _load_state_dict_on_device(unet, unet_sd, device=map_location)
+info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
print("U-Net: ", info)

# Text Encoders
Expand Down Expand Up @@ -221,7 +221,8 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
-text_model1 = CLIPTextModel._from_config(text_model1_cfg)
+with init_empty_weights():
+    text_model1 = CLIPTextModel._from_config(text_model1_cfg)

# Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
# Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
@@ -246,7 +247,8 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
-text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
+with init_empty_weights():
+    text_model2 = CLIPTextModelWithProjection(text_model2_cfg)

print("loading text encoders from checkpoint")
te1_sd = {}
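Both text encoders (and the VAE further down) are now constructed under accelerate's init_empty_weights(), so their parameters live on the "meta" device and take no memory until real tensors are loaded into them. A generic illustration of that API, unrelated to the PR's actual modules:

import torch
from accelerate import init_empty_weights

with init_empty_weights():
    layer = torch.nn.Linear(4, 4)   # parameters are created on the "meta" device

print(layer.weight.device)          # meta: shape and dtype exist, but no storage yet
# Real weights are materialized later, e.g. by a dtype-aware loader like the
# sketch shown earlier, instead of allocating fp32 tensors twice.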
@@ -257,21 +259,22 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)

-info1 = text_model1.load_state_dict(te1_sd)
+info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
print("text encoder 1:", info1)

converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
-info2 = text_model2.load_state_dict(converted_sd)
+info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
print("text encoder 2:", info2)

# prepare vae
print("building VAE")
vae_config = model_util.create_vae_diffusers_config()
-vae = AutoencoderKL(**vae_config) # .to(device)
+with init_empty_weights():
+    vae = AutoencoderKL(**vae_config)

print("loading VAE from checkpoint")
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
-info = vae.load_state_dict(converted_vae_checkpoint)
+info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
print("VAE:", info)

ckpt_info = (epoch, global_step) if epoch is not None else None
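Taken together, a caller that wants the U-Net and VAE stored in half precision passes the new dtype argument. The call below is a hypothetical usage sketch: the model-version string, checkpoint filename, and the exact order of the returned tuple are assumptions based on this diff, not verified against the rest of the repository.

import torch
from library import sdxl_model_util

text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info = (
    sdxl_model_util.load_models_from_sdxl_checkpoint(
        "sdxl_base_v1.0",                 # model_version (illustrative value)
        "sd_xl_base_1.0.safetensors",     # ckpt_path (example file)
        "cpu",                            # map_location
        dtype=torch.float16,              # storage dtype for U-Net/VAE; None keeps fp32
    )
)
print(next(unet.parameters()).dtype)           # torch.float16
print(next(text_encoder1.parameters()).dtype)  # torch.float32 (kept fp32 for CPU caching)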
33 changes: 29 additions & 4 deletions library/sdxl_train_util.py
@@ -18,6 +18,7 @@

def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
+model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
@@ -36,6 +37,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
model_version,
weight_dtype,
accelerator.device if args.lowram else "cpu",
+model_dtype,
)

# work on low-ram device
@@ -54,7 +56,8 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info


-def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
+def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None):
+    # model_dtype only work with full fp16/bf16
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
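Two different dtypes flow through this function: weight_dtype is the mixed-precision compute dtype, while the new model_dtype is the storage dtype for the U-Net and VAE and stays None unless full fp16/bf16 training is requested. A small sketch of that reading, with illustrative flag values:

import torch

full_fp16 = False                    # illustrative value of args.full_fp16
weight_dtype = torch.float16         # compute dtype from --mixed_precision fp16
model_dtype = weight_dtype if full_fp16 else None
print(model_dtype)                   # None -> U-Net/VAE weights stay fp32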

@@ -67,7 +70,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
unet,
logit_scale,
ckpt_info,
-) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype)
+) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
else:
# Diffusers model is loaded to CPU
from diffusers import StableDiffusionXLPipeline
@@ -77,7 +80,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
try:
try:
pipe = StableDiffusionXLPipeline.from_pretrained(
-name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None
+name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
)
except EnvironmentError as ex:
if variant is not None:
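The Diffusers branch now forwards model_dtype to from_pretrained, so a full fp16/bf16 run loads the pipeline weights directly in that precision. A generic Diffusers usage example; the repo id and dtype are illustrative, not taken from this PR:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative model id
    torch_dtype=torch.float16,   # what model_dtype feeds into; None keeps the default precision
    variant="fp16",              # prefer fp16 weight files when the repo provides them
    tokenizer=None,              # tokenizers are loaded separately in this codebase
)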
@@ -93,6 +96,13 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version

text_encoder1 = pipe.text_encoder
text_encoder2 = pipe.text_encoder_2

+# convert to fp32 for cache text_encoders outputs
+if text_encoder1.dtype != torch.float32:
+    text_encoder1 = text_encoder1.to(dtype=torch.float32)
+if text_encoder2.dtype != torch.float32:
+    text_encoder2 = text_encoder2.to(dtype=torch.float32)

vae = pipe.vae
unet = pipe.unet
del pipe
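Note that even when the pipeline is loaded in fp16/bf16, the text encoders are cast back to fp32 before use, since their outputs are cached on the CPU where half-precision ops can be unsupported or very slow. A generic, self-contained illustration of that failure mode and the cast, not code from the PR:

import torch

layer = torch.nn.Linear(8, 8).half()          # stand-in for an fp16 text encoder on CPU
x = torch.randn(1, 8, dtype=torch.float16)
try:
    layer(x)                                  # may raise on CPU, depending on the torch build
except RuntimeError as err:
    print("fp16 on CPU failed:", err)

layer = layer.to(dtype=torch.float32)         # the same cast the diff applies to the encoders
print(layer(x.float()).dtype)                 # torch.float32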
@@ -101,7 +111,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
-sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device)
+sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
print("U-Net converted to original U-Net")

logit_scale = None
@@ -146,6 +156,21 @@ def load_tokenizers(args: argparse.Namespace):
return tokeniers


+def match_mixed_precision(args, weight_dtype):
+    if args.full_fp16:
+        assert (
+            weight_dtype == torch.float16
+        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+        return weight_dtype
+    elif args.full_bf16:
+        assert (
+            weight_dtype == torch.bfloat16
+        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+        return weight_dtype
+    else:
+        return None


def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
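Finally, match_mixed_precision() only returns a storage dtype when the matching full-precision flag is set; otherwise it returns None and the U-Net/VAE stay in fp32. A usage sketch, where the Namespace fields mirror the existing --full_fp16 / --full_bf16 options and the import path assumes the file shown above:

import argparse
import torch
from library.sdxl_train_util import match_mixed_precision

weight_dtype = torch.float16  # what --mixed_precision fp16 resolves to

args = argparse.Namespace(full_fp16=True, full_bf16=False)
print(match_mixed_precision(args, weight_dtype))   # torch.float16 -> store U-Net/VAE in fp16

args = argparse.Namespace(full_fp16=False, full_bf16=False)
print(match_mixed_precision(args, weight_dtype))   # None -> weights stay fp32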