
Add features to the Dreambooth LoRA SDXL training script #5508

Merged

30 commits merged into huggingface:main from dreambooth_lora_xl on Nov 21, 2023

Conversation

linoytsaban (Collaborator)

What does this PR do?

@apolinario
Adds some popular LoRA SDXL tuning features to the DreamBooth LoRA SDXL training script (a sketch of how these fit together follows the list):

  • support for a different learning rate for the text encoder
  • support for the Prodigy optimizer
  • support for Min-SNR gamma weighting
  • support for custom captions and dataset loading from the Hub
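A rough sketch of how the first two features could be wired together (a hedged illustration, not the script's exact code; the LoRA parameter lists and args values below are stand-ins):

```python
import torch
from types import SimpleNamespace

# Stand-ins for the script's LoRA parameter lists and parsed CLI args (illustrative only).
unet_lora_parameters = [torch.nn.Parameter(torch.zeros(4))]
text_lora_parameters = [torch.nn.Parameter(torch.zeros(4))]
args = SimpleNamespace(optimizer="prodigy", learning_rate=1.0,
                       text_encoder_lr=1.0, adam_beta1=0.9, adam_beta2=0.999)

# A different LR for the text encoder is just a second optimizer param group.
params_to_optimize = [
    {"params": unet_lora_parameters, "lr": args.learning_rate},
    {"params": text_lora_parameters, "lr": args.text_encoder_lr},
]

if args.optimizer.lower() == "prodigy":
    from prodigyopt import Prodigy  # assumes `pip install prodigyopt`
    optimizer = Prodigy(params_to_optimize, lr=1.0)  # Prodigy adapts the step size itself
else:
    optimizer = torch.optim.AdamW(params_to_optimize,
                                  betas=(args.adam_beta1, args.adam_beta2))
```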

linoytsaban and others added 2 commits October 24, 2023 11:34
- support for different lr for text encoder
- support for Prodigy optimizer
- support for min snr gamma
- support for custom captions and dataset loading from the hub
@apolinario apolinario changed the title from "Additions:" to "Add features to the Dreambooth LoRA SDXL training script" Oct 24, 2023
@linoytsaban linoytsaban marked this pull request as ready for review October 24, 2023 13:09
@patrickvonplaten (Contributor)

Ok for me, even though I'm not sure it's a good idea to keep adding more and more features to our example scripts, because:

  • a) It's a never-ending story.
  • b) It makes the examples harder to understand and moves us more into the realm of a Trainer instead of an example.

Transformers has a nice philosophy here:

> While we strive to present as many use cases as possible, the example scripts are just that - examples. It is expected that they won't work out-of-the-box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data, allowing you to tweak and edit them as required.

From https://github.com/huggingface/transformers/tree/main/examples#examples

I'd advocate for moving bit by bit towards the same philosophy in diffusers, especially for examples that have existed for quite some time and are already quite big.

@sayakpaul (Member)

Thinking about it, I tend to agree with @patrickvonplaten. But in order to maximize the impact here for diffusers, and to show the community that it's possible to leverage the diffusers and HF ecosystem end to end to do great DreamBooth training runs for SDXL, I think this has a lot of merit.

> in order to maximize the impact here for diffusers

This will probably organically draw the community to use our script and base their implementations on top of that.

So, to that, I say we keep this script in community_examples like we do for ControlNet SDXL and make a note of it in the README. Who knows, maybe we end up publishing a tech report with that script: "Bag of tricks for improving DreamBooth on SDXL".

I guess this gives us a win from both worlds?

@apolinario (Collaborator) commented Nov 3, 2023

IMO some elements of this PR are very fundamental to training and dramatically improve what training is able to do, while not increasing complexity much (e.g. a separate LR for the text encoder vs. the UNet). While I agree with the Transformers philosophy, I think the features we are supporting here do not violate it.

Rather the opposite: they provide examples of what is today considered fundamental for this task: allowing custom captioning, a separate LR, auto-optimizers, etc.

Besides that, I agree with @sayakpaul that we should aim for the best of both worlds by having both training examples and community examples. But I very strongly think the changes introduced by this PR should live in the example script.

A new script we are working on, which includes pivotal tuning and some other more complex uses, could be a good fit for a community example.

return batch


def compute_snr(timesteps):
Contributor:

[Not relevant to this PR] @sayakpaul this function is now in 5+ training scripts I think. I'd be ok with moving it to training_utils.py

Member:

def compute_snr(noise_scheduler, timesteps):

:)

Contributor:

Can we import it here @linoytsaban ?
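For reference, a minimal sketch of what such a shared helper computes for a DDPM-style scheduler (the general idea only; the actual training_utils implementation may differ in details):

```python
import torch

def compute_snr(noise_scheduler, timesteps):
    # SNR(t) = alpha_t^2 / sigma_t^2 under the scheduler's noise schedule.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)
    alpha = alphas_cumprod[timesteps] ** 0.5
    sigma = (1.0 - alphas_cumprod[timesteps]) ** 0.5
    return (alpha / sigma) ** 2
```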


# Optimizer creation
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
    logger.warn(
        f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
Contributor:

nice!

safeguard_warmup=args.prodigy_safeguard_warmup,
)


Contributor:

Can we move this elif statement here: https://github.com/huggingface/diffusers/pull/5508/files#r1391214175
to cluster all the AdamW logic in the same spot?
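Roughly, the clustering being asked for could look like this (a sketch; the 8-bit branch assumes bitsandbytes is installed, and the variable names are illustrative):

```python
import torch

if args.optimizer.lower() == "adamw":
    if args.use_8bit_adam:
        import bitsandbytes as bnb  # assumes `pip install bitsandbytes`
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW
elif args.optimizer.lower() == "prodigy":
    from prodigyopt import Prodigy
    optimizer_class = Prodigy
else:
    raise ValueError(f"Unsupported optimizer: {args.optimizer}")
```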

# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
base_weight = (
    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[
Contributor:

The code style looks weird here

linoytsaban (Collaborator, Author) commented Nov 15, 2023:

I actually used the same logic that was used here (I thought it made sense, so we maintain consistency across examples); more specifically, I copied this code segment:

if args.with_prior_preservation:
    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)
    # Compute prior loss
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

# Compute instance loss
if args.snr_gamma is None:
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
    # This is discussed in Section 4.2 of the same paper.
    snr = compute_snr(noise_scheduler, timesteps)
    base_weight = (
        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
    )

    if noise_scheduler.config.prediction_type == "v_prediction":
        # Velocity objective needs to be floored to an SNR weight of one.
        mse_loss_weights = base_weight + 1
    else:
        # Epsilon and sample both use the same loss weights.
        mse_loss_weights = base_weight
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
    loss = loss.mean()

if args.with_prior_preservation:
    # Add the prior loss to the instance loss.
    loss = loss + args.prior_loss_weight * prior_loss
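For reference, the weighting this snippet implements (Min-SNR, https://arxiv.org/abs/2303.09556, with the v-prediction adjustment above) is

$$
w_t = \frac{\min(\mathrm{SNR}(t),\,\gamma)}{\mathrm{SNR}(t)} \quad \text{(epsilon/sample prediction)}, \qquad
w_t = \frac{\min(\mathrm{SNR}(t),\,\gamma)}{\mathrm{SNR}(t)} + 1 \quad \text{(v-prediction)},
$$

and the per-element MSE loss is scaled by $w_t$ before averaging.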

@HuggingFaceDocBuilderDev commented Nov 13, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor) commented Nov 13, 2023

> IMO some elements of this PR are very fundamental to training and dramatically improve what training is able to do, while not increasing complexity much (e.g. a separate LR for the text encoder vs. the UNet). While I agree with the Transformers philosophy, I think the features we are supporting here do not violate it.
>
> Rather the opposite: they provide examples of what is today considered fundamental for this task: allowing custom captioning, a separate LR, auto-optimizers, etc.
>
> Besides that, I agree with @sayakpaul that we should aim for the best of both worlds by having both training examples and community examples. But I very strongly think the changes introduced by this PR should live in the example script.
>
> A new script we are working on, which includes pivotal tuning and some other more complex uses, could be a good fit for a community example.

I'm ok to (exceptionally) go forward with this PR, if you say it adds big improvements to the training script.
I don't agree that it is in line with the above-linked philosophy, because:

  • It significantly increases complexity; e.g. I'm not able to see just from reviewing this PR whether everything is backward compatible, because too many if-else statements are added and too much code is moved around.
  • People will now open PRs to add similar features to other example scripts, where I don't think they necessarily make sense.
  • I consider neither "allow custom captioning" nor "Prodigy" fundamental for DreamBooth, because:
    • Neither is used in the original DreamBooth paper.
    • Neither has enough adoption to be considered fundamental IMO (at least not in the scientific community). Prodigy is far from being a standard optimizer that gets added to major PyTorch codebases; it does not have much traction: https://github.com/konstmish/prodigy .
    • I haven't seen an issue about either custom captioning or Prodigy.
  • I don't have a problem with adding a different learning rate for the text encoder, nor with min_snr: those are just a few lines of changes, they are already used extensively in the library, they make obvious sense, and we had issues/requests for them.

I'm ok with the changes if you feel strongly @linoytsaban @apolinario, but going forward I'm not too keen on "updating" the DreamBooth training script every 1-2 months.

@linoytsaban (Collaborator, Author) commented Nov 13, 2023

@patrickvonplaten @sayakpaul I see what you mean.

Since the support for custom captions required the most adjustments out of the features added in this PR (the rest required relatively few modifications), if we don't agree on all the above features, another option would be what @apolinario suggested, except we keep the custom-caption feature out of this PR and have it only in the community examples, like the pivotal tuning script. Would that make more sense?

@patrickvonplaten (Contributor) commented Nov 13, 2023

> @patrickvonplaten @sayakpaul I see what you mean.
>
> Since the support for custom captions required the most adjustments out of the features added in this PR (the rest required relatively few modifications), if we don't agree on all the above features, another option would be what @apolinario suggested, except we keep the custom-caption feature out of this PR and have it only in the community examples, like the pivotal tuning script. Would that make more sense?

I'm ok adding it here if you feel strongly (do you have any links/data showing that custom captions are indispensable by now?). But if it's not indispensable, then yes, I'd be happy if we could move it to the community / research training scripts.

@apolinario (Collaborator) commented Nov 13, 2023

> I'm ok adding it here if you feel strongly (do you have any links/data showing that custom captions are indispensable by now?).

IMO it is fairly indispensable for modern DreamBooth tuning. This started being discussed about a year ago (https://www.reddit.com/r/StableDiffusion/comments/zcr644/make_better_dreambooth_style_models_by_using/), and it has become the default for fine-tuners to use either custom hand-made captions or a captioning system like BLIP, with the trigger word appended. This brings the best of both worlds vs. an exclusive trigger word in a lot of scenarios, and it is indispensable for modern DreamBooth fine-tuning, including all scripts, tools, and UIs; almost all modern LoRAs/models use it.

Also, on the diffusers side, adding something like this is not a huge innovation on our end, as the structure/inspiration was taken from the ControlNet training script (https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py).

@patrickvonplaten (Contributor)

What's the difference between DreamBooth + custom captions and just text-to-image fine-tuning? Isn't the whole point of DreamBooth not to need captions?

> Also, on the diffusers side, adding something like this is not a huge innovation on our end, as the structure/inspiration was taken from the ControlNet training script (https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py).

I don't fully understand this. ControlNet doesn't have custom captions.

linoytsaban and others added 4 commits November 19, 2023 11:22
… avoid unnecessary dependency on datasets)-

1. user provides --dataset_name
2. user provides local dir --instance_data_dir that contains a metadata .jsonl file
3. user provides local dir --instance_data_dir that contains only images
in cases [1,2] we import datasets and use load_dataset method, in case [3] we process the data same as in the original script setting
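A minimal sketch of that branching (argument and variable names are assumed from the commit message, not the script's exact code):

```python
from pathlib import Path

if args.dataset_name is not None:
    # case 1: hub dataset -- datasets is imported lazily, only on this path
    from datasets import load_dataset
    dataset = load_dataset(args.dataset_name, split="train")
elif (Path(args.instance_data_dir) / "metadata.jsonl").exists():
    # case 2: local folder with a metadata .jsonl file
    from datasets import load_dataset
    dataset = load_dataset("imagefolder", data_dir=args.instance_data_dir, split="train")
else:
    # case 3: plain image folder -- no datasets dependency, original behaviour
    instance_images_path = list(Path(args.instance_data_dir).iterdir())
```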
    )
    load_as_dataset = True
else:
    if not self.instance_data_root.exists():
Contributor:

in this else: case we shouldn't use datasets. Can we make sure that images are loaded without using datasets?

Contributor:

E.g. using the logic from before the change:

        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)

img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
  parameters:
    negative_prompt: '-'
apolinario (Collaborator) commented Nov 21, 2023:

Suggested change:
-    negative_prompt: '-'

we can just remove this line if there's no negative prompt

linoytsaban (Collaborator, Author) commented Nov 21, 2023:

for some reason it was breaking when I removed it

linoytsaban (Collaborator, Author):

committing a fix now
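(A hedged guess at what such a fix might look like: only emit the parameters/negative_prompt block when a negative prompt is actually set. The negative_prompt variable is hypothetical; the actual committed fix may differ.)

```python
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' '}'
"""
if negative_prompt:  # hypothetical variable, only write the block when non-empty
    img_str += f"""  parameters:
    negative_prompt: '{negative_prompt}'
"""
```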


parser.add_argument("--adam_beta1", type=float, default=0.9,
help="The beta1 parameter for the Adam and Prodigy optimizers.")
parser.add_argument("--adam_beta2", type=float, default=0.999,


I have not tested this myself, but in an issue on the Prodigy repo some people reported that adam_beta2=0.99 works better when fine-tuning diffusion models with LoRA.

linoytsaban (Collaborator, Author):

Thanks!
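For anyone trying that suggestion, a sketch of the override (prodigyopt's betas follows the Adam convention; the 0.99 value is the community report above, not a verified default):

```python
from prodigyopt import Prodigy

optimizer = Prodigy(
    params_to_optimize,                   # param groups, as set up earlier in the script
    lr=1.0,                               # Prodigy adapts the effective step size itself
    betas=(args.adam_beta1, 0.99),        # reported to work better for LoRA diffusion fine-tuning
    weight_decay=args.adam_weight_decay,  # assumed to reuse the script's Adam weight decay arg
)
```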

try:
    from datasets import load_dataset
except ImportError:
    raise ImportError(
Contributor:

Very nice!

Weights for this model are available in Safetensors format.

[Download]({repo_id}/tree/main) them in the Files & versions tab.

Contributor:

Thanks for the model card improvements!

help="The column of the dataset containing the instance prompt for each image",
)

parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
Contributor:

[not related to this PR] @sayakpaul some example scripts use repeats and others use num_epochs => I think we should try to aim for consistent naming soon

Member:

From what I see (https://github.com/search?q=repo%3Ahuggingface%2Fdiffusers%20repeats&type=code), repeats was likely introduced in Textual Inversion, and num_epochs is what we follow in 90% of the training scripts.

Member:

repeats likely doesn't affect the other two. But for a regular practitioner, it can be quite confusing to see these three pop up in the same place.

It should be clarified in the README.
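To illustrate how the three knobs interact (assuming "the other two" are num_train_epochs and max_train_steps, and the usual bookkeeping in these scripts):

```python
import math

num_images = 10   # instance images on disk
repeats = 2       # --repeats: the dataset is materialized num_images * repeats items long
batch_size = 4    # --train_batch_size
num_epochs = 5    # --num_train_epochs

steps_per_epoch = math.ceil(num_images * repeats / batch_size)  # 5
max_train_steps = num_epochs * steps_per_epoch                  # 25
```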

@patrickvonplaten (Contributor)

Thanks a mille for bearing with me through the PR 🙏

@patrickvonplaten patrickvonplaten merged commit 6fac136 into huggingface:main Nov 21, 2023
20 checks passed
@jmaccall316

Hello, can anyone answer this?
If I use custom captions, will it still read the instance prompt? It would be interesting to use it as an overall class/style prompt to cover the subject, with the possibility of maybe not even having to mention the subject in the captions. This could be useful for subjects not already known or properly represented by the model.

affromero pushed a commit to affromero/diffusers that referenced this pull request Nov 24, 2023
…#5508)

* Additions:
- support for different lr for text encoder
- support for Prodigy optimizer
- support for min snr gamma
- support for custom captions and dataset loading from the hub

* adjusted --caption_column behaviour (to -not- use the second column of the dataset by default if --caption_column is not provided)

* fixed --output_dir / --model_dir_name confusion

* added --repeats, --adam_weight_decay_text_encoder
+ some fixes

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: Patrick von Platen <[email protected]>

* - import compute_snr from diffusers/training_utils.py
- cluster adamw together
- when using 'prodigy', if --train_text_encoder == True and --text_encoder_lr != --learning rate, changes the lr of the text encoders optimization params to be --learning_rate (otherwise errors)

* shape fixes when custom captions are used

* formatting and a little cleanup

* code styling

* --repeats default value fixed, changed to 1

* bug fix - removed redundant lines of embedding concatenation when using prior_preservation (that duplicated class_prompt embeddings)

* changed dataset loading logic according to the following usecases (to avoid unnecessary dependency on datasets)-
1. user provides --dataset_name
2. user provides local dir --instance_data_dir that contains a metadata .jsonl file
3. user provides local dir --instance_data_dir that contains only images
in cases [1,2] we import datasets and use load_dataset method, in case [3] we process the data same as in the original script setting

* styling fix

* arg name fix

* adjusted the --repeats logic

* -removed redundant arg and 'if' when loading local folder with prompts
-updated readme template
-some default val fixes
-custom caption tests

* image path fix for readme

* code style

* bug fix

* --caption_column arg

* readme fix

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
@linoytsaban linoytsaban deleted the dreambooth_lora_xl branch January 5, 2024 06:41
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024