From 1af7dd45f9cc2784c13656f1d1af876f56fba80c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Forst=C3=A9n?= Date: Wed, 15 Mar 2023 05:16:30 +0200 Subject: [PATCH] Controlnet training (#2545) * Controlnet training code initial commit Works with circle dataset: https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md * Script for adding a controlnet to existing model * Fix control image transform Control image should be in 0..1 range. * Add license header and remove more unused configs * controlnet training readme * Allow nonlocal model in add_controlnet.py * Formatting * Remove unused code * Code quality * Initialize controlnet in training script * Formatting * Address review comments * doc style * explicit constructor args and submodule names * hub dataset NOTE - not tested * empty prompts * add conditioning image * rename * remove instance data dir * image_transforms -> -1,1 . conditioning_image_transformers -> 0, 1 * nits * remove local rank config I think this isn't necessary in any of our training scripts * validation images * proportion_empty_prompts typo * weight copying to controlnet bug * call log validation fix * fix * gitignore wandb * fix progress bar and resume from checkpoint iteration * initial step fix * log multiple images * fix * fixes * tracker project name configurable * misc * add controlnet requirements.txt * update docs * image labels * small fixes * log validation using existing models for pipeline * fix for deepspeed saving * memory usage docs * Update examples/controlnet/train_controlnet.py Co-authored-by: Sayak Paul * Update examples/controlnet/train_controlnet.py Co-authored-by: Sayak Paul * Update examples/controlnet/README.md Co-authored-by: Sayak Paul * Update examples/controlnet/README.md Co-authored-by: Sayak Paul * Update examples/controlnet/README.md Co-authored-by: Sayak Paul * Update examples/controlnet/README.md Co-authored-by: Sayak Paul * Update examples/controlnet/README.md Co-authored-by: Sayak Paul * Update examples/controlnet/README.md Co-authored-by: Sayak Paul * Update examples/controlnet/README.md Co-authored-by: Sayak Paul * Update examples/controlnet/README.md Co-authored-by: Sayak Paul * remove extra is main process check * link to dataset in intro paragraph * remove unnecessary paragraph * note on deepspeed * Update examples/controlnet/README.md Co-authored-by: Patrick von Platen * assert -> value error * weights and biases note * move images out of git * remove .gitignore --------- Co-authored-by: William Berman Co-authored-by: Sayak Paul Co-authored-by: Patrick von Platen --- models/controlnet.py | 55 +++++++++++++++++++ .../pipeline_stable_diffusion_controlnet.py | 14 ++++- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/models/controlnet.py b/models/controlnet.py index a9686dfc31bf..5895ae4de5b9 100644 --- a/models/controlnet.py +++ b/models/controlnet.py @@ -29,6 +29,7 @@ UNetMidBlock2DCrossAttn, get_down_block, ) +from .unet_2d_condition import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -257,6 +258,60 @@ def __init__( upcast_attention=upcast_attention, ) + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + ): + r""" + Instantiate Controlnet class from UNet2DConditionModel. + + Parameters: + unet (`UNet2DConditionModel`): + UNet model which weights are copied to the ControlNet. Note that all configuration options are also + copied where applicable. + """ + controlnet = cls( + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttnProcessor]: diff --git a/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index d6440214df1e..08643c6b891a 100644 --- a/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -611,9 +611,17 @@ def prepare_image( image = [image] if isinstance(image[0], PIL.Image.Image): - image = [ - np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image - ] + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2)