[Lora] correct lora saving & loading #2655

Merged · 3 commits · Mar 14, 2023
src/diffusers/loaders.py (48 changes: 23 additions & 25 deletions)
@@ -19,7 +19,7 @@

 from .models.cross_attention import LoRACrossAttnProcessor
 from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
+from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging


 if is_safetensors_available():
@@ -150,13 +150,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            if (is_safetensors_available() and weight_name is None) or weight_name.endswith(".safetensors"):
-                if weight_name is None:
-                    weight_name = LORA_WEIGHT_NAME_SAFE
+            # Let's first try to load .safetensors weights
+            if (is_safetensors_available() and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):

Comment on lines +154 to +156
Member: Makes sense to add a small comment highlighting that we always look for safetensors first just for easy readability?
Contributor Author: Good point

                 try:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path_or_dict,
-                        weights_name=weight_name,
+                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                         cache_dir=cache_dir,
                         force_download=force_download,
                         resume_download=resume_download,
@@ -169,14 +170,13 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     )
                     state_dict = safetensors.torch.load_file(model_file, device="cpu")
                 except EnvironmentError:
-                    if weight_name == LORA_WEIGHT_NAME_SAFE:
-                        weight_name = None
+                    # try loading non-safetensors weights
+                    pass
+
             if model_file is None:
-                if weight_name is None:
-                    weight_name = LORA_WEIGHT_NAME
                 model_file = _get_model_file(
                     pretrained_model_name_or_path_or_dict,
-                    weights_name=weight_name,
+                    weights_name=weight_name or LORA_WEIGHT_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     resume_download=resume_download,
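
For readability, here is a small, self-contained sketch of the lookup order these two hunks establish. `resolve_lora_file`, `fetch`, and `available` are illustrative stand-ins, not diffusers APIs; the real code routes through `_get_model_file` with the `LORA_WEIGHT_NAME_SAFE` / `LORA_WEIGHT_NAME` defaults.

# Sketch of the "safetensors first, then .bin" lookup used by load_attn_procs.
def resolve_lora_file(fetch, weight_name=None, safetensors_available=True):
    model_file = None
    # Try .safetensors when no name was given (and safetensors is usable), or when
    # the caller explicitly asked for a .safetensors file.
    if (safetensors_available and weight_name is None) or (
        weight_name is not None and weight_name.endswith(".safetensors")
    ):
        try:
            model_file = fetch(weight_name or "pytorch_lora_weights.safetensors")
        except EnvironmentError:
            # Fall through and try the non-safetensors weights instead.
            pass

    if model_file is None:
        model_file = fetch(weight_name or "pytorch_lora_weights.bin")
    return model_file


# Example: pretend only the .bin file exists in the cache.
available = {"pytorch_lora_weights.bin": "/cache/pytorch_lora_weights.bin"}

def fetch(name):
    if name not in available:
        raise EnvironmentError(f"{name} not found")
    return available[name]

print(resolve_lora_file(fetch))  # -> /cache/pytorch_lora_weights.bin

With `weight_name=None` both default file names are tried in order; an explicit `weight_name` is looked up as given.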
@@ -225,9 +225,10 @@ def save_attn_procs(
         self,
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
-        weights_name: str = None,
+        weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = False,
+        **kwargs,
     ):
         r"""
         Save an attention processor to a directory, so that it can be re-loaded using the
@@ -245,6 +246,12 @@ def save_attn_procs(
                 need to replace `torch.save` by another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
         """
+        weight_name = weight_name or deprecate(
+            "weights_name",
+            "0.18.0",
+            "`weights_name` is deprecated, please use `weight_name` instead.",
+            take_from=kwargs,
+        )
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
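
To make the back-compat shim concrete, a minimal runnable sketch follows. `_pop_deprecated` and `save_attn_procs_like` are hypothetical stand-ins for `deprecate(..., take_from=kwargs)` and `save_attn_procs`; in diffusers, `deprecate` is additionally version-aware, with `"0.18.0"` marking the release in which the old keyword is scheduled for removal.

import warnings

def _pop_deprecated(name, message, kwargs):
    # Stand-in for deprecate(..., take_from=kwargs): pop the old keyword out of
    # **kwargs, warn, and hand its value back to the caller.
    if name in kwargs:
        warnings.warn(message, FutureWarning)
        return kwargs.pop(name)
    return None

def save_attn_procs_like(weight_name=None, **kwargs):
    # Same shape as the shim in the diff: prefer the new keyword, otherwise fall
    # back to (and warn about) the deprecated `weights_name`.
    return weight_name or _pop_deprecated(
        "weights_name", "`weights_name` is deprecated, please use `weight_name` instead.", kwargs
    )

print(save_attn_procs_like(weights_name="pytorch_lora_weights.bin"))  # warns, returns the old value
print(save_attn_procs_like(weight_name="pytorch_lora_weights.bin"))   # no warning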
@@ -265,22 +272,13 @@ def save_function(weights, filename):
         # Save the model
         state_dict = model_to_save.state_dict()

-        # Clean the folder from a previous save
-        for filename in os.listdir(save_directory):
-            full_filename = os.path.join(save_directory, filename)
-            # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
-            # in distributed settings to avoid race conditions.
-            weights_no_suffix = weights_name.replace(".bin", "")
-            if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
-                os.remove(full_filename)
-

Contributor Author: this was a bad copy-paste and is not needed

-        if weights_name is None:
+        if weight_name is None:
             if safe_serialization:
-                weights_name = LORA_WEIGHT_NAME_SAFE
+                weight_name = LORA_WEIGHT_NAME_SAFE
             else:
-                weights_name = LORA_WEIGHT_NAME
+                weight_name = LORA_WEIGHT_NAME

         # Save the model
-        save_function(state_dict, os.path.join(save_directory, weights_name))
+        save_function(state_dict, os.path.join(save_directory, weight_name))

-        logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")