diff --git a/examples/dreambooth/test_dreambooth_lora_sana.py b/examples/dreambooth/test_dreambooth_lora_sana.py
new file mode 100644
index 000000000000..dfceb09a9736
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_sana.py
@@ -0,0 +1,206 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRASANA(ExamplesTestsAccelerate):
+    instance_data_dir = "docs/source/en/imgs"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
+    script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
+    transformer_layer_type = "transformer_blocks.0.attn1.to_k"
+
+    def test_dreambooth_lora_sana(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --resolution 32
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --max_sequence_length 16
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names.
+            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_transformer)
+
+    def test_dreambooth_lora_latent_caching(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --resolution 32
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --cache_latents
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --max_sequence_length 16
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names.
+            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_transformer)
+
+    def test_dreambooth_lora_layers(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --resolution 32
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --cache_latents
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lora_layers {self.transformer_layer_type}
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --max_sequence_length 16
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names. In this test, only params of
+            # `self.transformer_layer_type` should be in the state dict.
+            starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
+            self.assertTrue(starts_with_transformer)
+
+    def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=6
+                --checkpoints_total_limit=2
+                --checkpointing_steps=2
+                --max_sequence_length 16
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --checkpointing_steps=2
+                --max_sequence_length 16
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+
+            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+            resume_run_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=8
+                --checkpointing_steps=2
+                --resume_from_checkpoint=checkpoint-4
+                --checkpoints_total_limit=2
+                --max_sequence_length 16
+                """.split()
+
+            resume_run_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + resume_run_args)
+
+            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index 4baa9f194feb..49c790ba04d7 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -943,7 +943,7 @@ def main(args):
 
     # Load scheduler and models
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="scheduler"
+        args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
     text_encoder = Gemma2Model.from_pretrained(
@@ -964,15 +964,6 @@ def main(args):
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
 
-    # Initialize a text encoding pipeline and keep it to CPU for now.
-    text_encoding_pipeline = SanaPipeline.from_pretrained(
-        args.pretrained_model_name_or_path,
-        vae=None,
-        transformer=None,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-    )
-
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
@@ -993,6 +984,15 @@ def main(args):
     # because Gemma2 is particularly suited for bfloat16.
     text_encoder.to(dtype=torch.bfloat16)
 
+    # Initialize a text encoding pipeline and keep it to CPU for now.
+    text_encoding_pipeline = SanaPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        vae=None,
+        transformer=None,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+    )
+
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
 
@@ -1182,6 +1182,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
             )
             if args.offload:
                 text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+            prompt_embeds = prompt_embeds.to(transformer.dtype)
         return prompt_embeds, prompt_attention_mask
 
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
@@ -1216,7 +1217,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     vae_config_scaling_factor = vae.config.scaling_factor
     if args.cache_latents:
         latents_cache = []
-        vae = vae.to("cuda")
+        vae = vae.to(accelerator.device)
         for batch in tqdm(train_dataloader, desc="Caching latents"):
             with torch.no_grad():
                 batch["pixel_values"] = batch["pixel_values"].to(
diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py
index 499ca89262a0..78f71527cb7e 100644
--- a/tests/lora/test_lora_layers_sana.py
+++ b/tests/lora/test_lora_layers_sana.py
@@ -16,7 +16,7 @@
 import unittest
 
 import torch
-from transformers import Gemma2ForCausalLM, GemmaTokenizer
+from transformers import Gemma2Model, GemmaTokenizer
 
 from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
 from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
@@ -73,7 +73,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     }
     vae_cls = AutoencoderDC
     tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
-    text_encoder_cls, text_encoder_id = Gemma2ForCausalLM, "hf-internal-testing/dummy-gemma-for-diffusers"
+    text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"
 
     @property
     def output_shape(self):
@@ -105,34 +105,34 @@ def get_dummy_inputs(self, with_generator=True):
         return noise, input_ids, pipeline_inputs
 
-    @unittest.skip("Not supported in Sana.")
+    @unittest.skip("Not supported in SANA.")
     def test_modify_padding_mode(self):
         pass
 
-    @unittest.skip("Not supported in Mochi.")
+    @unittest.skip("Not supported in SANA.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
 
-    @unittest.skip("Not supported in Mochi.")
+    @unittest.skip("Not supported in SANA.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_partial_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+    @unittest.skip("Text encoder LoRA is not supported in SANA.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
 
diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py
index f8551fff8447..21de4e04437a 100644
--- a/tests/pipelines/sana/test_sana.py
+++ b/tests/pipelines/sana/test_sana.py
@@ -18,7 +18,7 @@
 
 import numpy as np
 import torch
-from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
 
 from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
 from diffusers.utils.testing_utils import (
@@ -101,7 +101,7 @@ def get_dummy_components(self):
         torch.manual_seed(0)
         text_encoder_config = Gemma2Config(
             head_dim=16,
-            hidden_size=32,
+            hidden_size=8,
             initializer_range=0.02,
             intermediate_size=64,
             max_position_embeddings=8192,
@@ -112,7 +112,7 @@ def get_dummy_components(self):
             vocab_size=8,
             attn_implementation="eager",
         )
-        text_encoder = Gemma2ForCausalLM(text_encoder_config)
+        text_encoder = Gemma2Model(text_encoder_config)
         tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
 
         components = {