Skip to content

Commit

Permalink
[SANA LoRA] sana lora training tests and misc. (#10296)
Browse files Browse the repository at this point in the history
* sana lora training tests and misc.

* remove push to hub

* Update examples/dreambooth/train_dreambooth_lora_sana.py

Co-authored-by: Aryan <[email protected]>

---------

Co-authored-by: Aryan <[email protected]>
  • Loading branch information
sayakpaul and a-r-r-o-w authored Dec 23, 2024
1 parent 02c777c commit 76e2727
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 24 deletions.
206 changes: 206 additions & 0 deletions examples/dreambooth/test_dreambooth_lora_sana.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRASANA(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
transformer_layer_type = "transformer_blocks.0.attn1.to_k"

def test_dreambooth_lora_sana(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()

test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)

def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()

test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)

def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()

test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# `self.transformer_layer_type` should be in the state dict.
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
self.assertTrue(starts_with_transformer)

def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
--max_sequence_length 16
""".split()

test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)

def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
--max_sequence_length 166
""".split()

test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)

self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--max_sequence_length 16
""".split()

resume_run_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + resume_run_args)

self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
23 changes: 12 additions & 11 deletions examples/dreambooth/train_dreambooth_lora_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ def main(args):

# Load scheduler and models
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
text_encoder = Gemma2Model.from_pretrained(
Expand All @@ -964,15 +964,6 @@ def main(args):
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = SanaPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=None,
transformer=None,
text_encoder=text_encoder,
tokenizer=tokenizer,
)

# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
Expand All @@ -993,6 +984,15 @@ def main(args):
# because Gemma2 is particularly suited for bfloat16.
text_encoder.to(dtype=torch.bfloat16)

# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = SanaPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=None,
transformer=None,
text_encoder=text_encoder,
tokenizer=tokenizer,
)

if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()

Expand Down Expand Up @@ -1182,6 +1182,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
)
if args.offload:
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
prompt_embeds = prompt_embeds.to(transformer.dtype)
return prompt_embeds, prompt_attention_mask

# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
Expand Down Expand Up @@ -1216,7 +1217,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
vae_config_scaling_factor = vae.config.scaling_factor
if args.cache_latents:
latents_cache = []
vae = vae.to("cuda")
vae = vae.to(accelerator.device)
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
Expand Down
20 changes: 10 additions & 10 deletions tests/lora/test_lora_layers_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import unittest

import torch
from transformers import Gemma2ForCausalLM, GemmaTokenizer
from transformers import Gemma2Model, GemmaTokenizer

from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
Expand Down Expand Up @@ -73,7 +73,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
vae_cls = AutoencoderDC
tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
text_encoder_cls, text_encoder_id = Gemma2ForCausalLM, "hf-internal-testing/dummy-gemma-for-diffusers"
text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"

@property
def output_shape(self):
Expand Down Expand Up @@ -105,34 +105,34 @@ def get_dummy_inputs(self, with_generator=True):

return noise, input_ids, pipeline_inputs

@unittest.skip("Not supported in Sana.")
@unittest.skip("Not supported in SANA.")
def test_modify_padding_mode(self):
pass

@unittest.skip("Not supported in Mochi.")
@unittest.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass

@unittest.skip("Not supported in Mochi.")
@unittest.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_partial_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_and_scale(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_fused(self):
pass

@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass
6 changes: 3 additions & 3 deletions tests/pipelines/sana/test_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np
import torch
from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer

from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
from diffusers.utils.testing_utils import (
Expand Down Expand Up @@ -101,7 +101,7 @@ def get_dummy_components(self):
torch.manual_seed(0)
text_encoder_config = Gemma2Config(
head_dim=16,
hidden_size=32,
hidden_size=8,
initializer_range=0.02,
intermediate_size=64,
max_position_embeddings=8192,
Expand All @@ -112,7 +112,7 @@ def get_dummy_components(self):
vocab_size=8,
attn_implementation="eager",
)
text_encoder = Gemma2ForCausalLM(text_encoder_config)
text_encoder = Gemma2Model(text_encoder_config)
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")

components = {
Expand Down

0 comments on commit 76e2727

Please sign in to comment.