Skip to content

Commit

Permalink
Controlnet training (huggingface#2545)
Browse files Browse the repository at this point in the history
* Controlnet training code initial commit

Works with circle dataset: https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md

* Script for adding a controlnet to existing model

* Fix control image transform

Control image should be in 0..1 range.

* Add license header and remove more unused configs

* controlnet training readme

* Allow nonlocal model in add_controlnet.py

* Formatting

* Remove unused code

* Code quality

* Initialize controlnet in training script

* Formatting

* Address review comments

* doc style

* explicit constructor args and submodule names

* hub dataset

NOTE - not tested

* empty prompts

* add conditioning image

* rename

* remove instance data dir

* image_transforms -> [-1, 1]; conditioning_image_transforms -> [0, 1]

* nits

* remove local rank config

I think this isn't necessary in any of our training scripts

* validation images

* proportion_empty_prompts typo

* weight copying to controlnet bug

* call log validation fix

* fix

* gitignore wandb

* fix progress bar and resume from checkpoint iteration

* initial step fix

* log multiple images

* fix

* fixes

* tracker project name configurable

* misc

* add controlnet requirements.txt

* update docs

* image labels

* small fixes

* log validation using existing models for pipeline

* fix for deepspeed saving

* memory usage docs

* Update examples/controlnet/train_controlnet.py

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/controlnet/train_controlnet.py

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/controlnet/README.md

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/controlnet/README.md

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/controlnet/README.md

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/controlnet/README.md

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/controlnet/README.md

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/controlnet/README.md

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/controlnet/README.md

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/controlnet/README.md

Co-authored-by: Sayak Paul <[email protected]>

* remove extra is main process check

* link to dataset in intro paragraph

* remove unnecessary paragraph

* note on deepspeed

* Update examples/controlnet/README.md

Co-authored-by: Patrick von Platen <[email protected]>

* assert -> value error

* weights and biases note

* move images out of git

* remove .gitignore

---------

Co-authored-by: William Berman <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
4 people authored Mar 15, 2023
1 parent 68552ed commit 1af7dd4
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
55 changes: 55 additions & 0 deletions models/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
UNetMidBlock2DCrossAttn,
get_down_block,
)
from .unet_2d_condition import UNet2DConditionModel


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -257,6 +258,60 @@ def __init__(
upcast_attention=upcast_attention,
)

@classmethod
def from_unet(
    cls,
    unet: UNet2DConditionModel,
    controlnet_conditioning_channel_order: str = "rgb",
    conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
    load_weights_from_unet: bool = True,
):
    r"""
    Instantiate this ControlNet class from a `UNet2DConditionModel`.

    Parameters:
        unet (`UNet2DConditionModel`):
            UNet model whose weights are copied to the ControlNet. Note that all configuration options are also
            copied where applicable.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            Passed through unchanged to the constructor of `cls`.
        conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
            Passed through unchanged to the constructor of `cls`.
        load_weights_from_unet (`bool`, defaults to `True`):
            When `True`, the state dicts of the UNet's `conv_in`, `time_proj`, `time_embedding`,
            `class_embedding` (only if the new model has one), `down_blocks` and `mid_block` are loaded into the
            identically named submodules of the new model. When `False`, the new model keeps the weights it was
            constructed with.

    Returns:
        The newly constructed instance of `cls`.
    """
    # Mirror every architecture option shared with the UNet so that the
    # submodules created here can accept the UNet's state dicts below.
    controlnet = cls(
        in_channels=unet.config.in_channels,
        flip_sin_to_cos=unet.config.flip_sin_to_cos,
        freq_shift=unet.config.freq_shift,
        down_block_types=unet.config.down_block_types,
        only_cross_attention=unet.config.only_cross_attention,
        block_out_channels=unet.config.block_out_channels,
        layers_per_block=unet.config.layers_per_block,
        downsample_padding=unet.config.downsample_padding,
        mid_block_scale_factor=unet.config.mid_block_scale_factor,
        act_fn=unet.config.act_fn,
        norm_num_groups=unet.config.norm_num_groups,
        norm_eps=unet.config.norm_eps,
        cross_attention_dim=unet.config.cross_attention_dim,
        attention_head_dim=unet.config.attention_head_dim,
        use_linear_projection=unet.config.use_linear_projection,
        class_embed_type=unet.config.class_embed_type,
        num_class_embeds=unet.config.num_class_embeds,
        upcast_attention=unet.config.upcast_attention,
        resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
        projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
        controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
        conditioning_embedding_out_channels=conditioning_embedding_out_channels,
    )

    if load_weights_from_unet:
        controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
        controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
        controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

        # Explicit None check: the truth value of a module object is not a
        # reliable "is it present" test (objects defining __len__ can be
        # falsy), and the intent here is only to skip an absent embedding.
        if controlnet.class_embedding is not None:
            controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

        controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

    return controlnet

@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttnProcessor]:
Expand Down
14 changes: 11 additions & 3 deletions pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,9 +611,17 @@ def prepare_image(
image = [image]

if isinstance(image[0], PIL.Image.Image):
image = [
np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image
]
images = []

for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)

image = images

image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
Expand Down

0 comments on commit 1af7dd4

Please sign in to comment.