Skip to content

Commit

Permalink
[CodeCamp2023-645]Add dreambooth new cfg (#2042)
Browse files Browse the repository at this point in the history
* new config of dreambooth

* add dreambooth mmagic new_config

* fix import name bug

---------

Co-authored-by: YanxingLiu <[email protected]>
Co-authored-by: rangoliu <[email protected]>
  • Loading branch information
3 people authored Oct 19, 2023
1 parent 9384f4b commit 80f9120
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 0 deletions.
95 changes: 95 additions & 0 deletions mmagic/configs/dreambooth/dreambooth-finetune_text_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .._base_.gen_default_runtime import *

from mmengine.dataset.sampler import InfiniteSampler
from torch.optim import AdamW

from mmagic.datasets.dreambooth_dataset import DreamBoothDataset
from mmagic.datasets.transforms.aug_shape import Resize
from mmagic.datasets.transforms.formatting import PackInputs
from mmagic.datasets.transforms.loading import LoadImageFromFile
from mmagic.engine import VisualizationHook
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor
from mmagic.models.editors.disco_diffusion.clip_wrapper import ClipWrapper
from mmagic.models.editors.dreambooth import DreamBooth

# config for model
stable_diffusion_v15_url = 'runwayml/stable-diffusion-v1-5'

val_prompts = [
'a sks dog in basket', 'a sks dog on the mountain',
'a sks dog beside a swimming pool', 'a sks dog on the desk',
'a sleeping sks dog', 'a screaming sks dog', 'a man in the garden'
]

model = dict(
type=DreamBooth,
vae=dict(
type='AutoencoderKL',
from_pretrained=stable_diffusion_v15_url,
subfolder='vae'),
unet=dict(
type='UNet2DConditionModel',
from_pretrained=stable_diffusion_v15_url,
subfolder='unet',
),
text_encoder=dict(
type=ClipWrapper,
clip_type='huggingface',
pretrained_model_name_or_path=stable_diffusion_v15_url,
subfolder='text_encoder'),
tokenizer=stable_diffusion_v15_url,
finetune_text_encoder=True,
scheduler=dict(
type='DDPMScheduler',
from_pretrained=stable_diffusion_v15_url,
subfolder='scheduler'),
test_scheduler=dict(
type='DDIMScheduler',
from_pretrained=stable_diffusion_v15_url,
subfolder='scheduler'),
data_preprocessor=dict(type=DataPreprocessor),
val_prompts=val_prompts)

train_cfg = dict(max_iters=1000)

optim_wrapper.update(
modules='.*unet',
optimizer=dict(type=AdamW, lr=5e-6),
accumulative_counts=4 # batch size = 4 * 1 = 4
)

pipeline = [
dict(type=LoadImageFromFile, key='img', channel_order='rgb'),
dict(type=Resize, scale=(512, 512)),
dict(type=PackInputs)
]

dataset = dict(
type=DreamBoothDataset,
data_root='./data/dreambooth',
concept_dir='imgs',
prompt='a photo of sks dog',
pipeline=pipeline)
train_dataloader = dict(
dataset=dataset,
num_workers=16,
sampler=dict(type=InfiniteSampler, shuffle=True),
persistent_workers=True,
batch_size=1)
val_cfg = val_evaluator = val_dataloader = None
test_cfg = test_evaluator = test_dataloader = None

# hooks
default_hooks.update(dict(logger=dict(interval=10)))
custom_hooks = [
dict(
type=VisualizationHook,
interval=50,
fixed_input=True,
vis_kwargs_list=dict(type='Data', name='fake_img'),
n_samples=1)
]
8 changes: 8 additions & 0 deletions mmagic/configs/dreambooth/dreambooth-prior_pre.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .dreambooth import *

# config for model
model.update(dict(prior_loss_weight=1, class_prior_prompt='a dog'))
93 changes: 93 additions & 0 deletions mmagic/configs/dreambooth/dreambooth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .._base_.gen_default_runtime import *

from mmengine.dataset.sampler import InfiniteSampler
from torch.optim import AdamW

from mmagic.datasets.dreambooth_dataset import DreamBoothDataset
from mmagic.datasets.transforms.aug_shape import Resize
from mmagic.datasets.transforms.formatting import PackInputs
from mmagic.datasets.transforms.loading import LoadImageFromFile
from mmagic.engine import VisualizationHook
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor
from mmagic.models.editors.disco_diffusion.clip_wrapper import ClipWrapper
from mmagic.models.editors.dreambooth import DreamBooth

stable_diffusion_v15_url = 'runwayml/stable-diffusion-v1-5'

val_prompts = [
'a sks dog in basket', 'a sks dog on the mountain',
'a sks dog beside a swimming pool', 'a sks dog on the desk',
'a sleeping sks dog', 'a screaming sks dog', 'a man in the garden'
]

model = dict(
type=DreamBooth,
vae=dict(
type='AutoencoderKL',
from_pretrained=stable_diffusion_v15_url,
subfolder='vae'),
unet=dict(
type='UNet2DConditionModel',
from_pretrained=stable_diffusion_v15_url,
subfolder='unet',
),
text_encoder=dict(
type=ClipWrapper,
clip_type='huggingface',
pretrained_model_name_or_path=stable_diffusion_v15_url,
subfolder='text_encoder'),
tokenizer=stable_diffusion_v15_url,
scheduler=dict(
type='DDPMScheduler',
from_pretrained=stable_diffusion_v15_url,
subfolder='scheduler'),
test_scheduler=dict(
type='DDIMScheduler',
from_pretrained=stable_diffusion_v15_url,
subfolder='scheduler'),
data_preprocessor=dict(type=DataPreprocessor),
val_prompts=val_prompts)

train_cfg = dict(max_iters=1000)

optim_wrapper.update(
modules='.*unet',
optimizer=dict(type=AdamW, lr=5e-6),
accumulative_counts=4 # batch size = 4 * 1 = 4
)

pipeline = [
dict(type=LoadImageFromFile, key='img', channel_order='rgb'),
dict(type=Resize, scale=(512, 512)),
dict(type=PackInputs)
]

dataset = dict(
type=DreamBoothDataset,
data_root='./data/dreambooth',
concept_dir='imgs',
prompt='a photo of sks dog',
pipeline=pipeline)
train_dataloader = dict(
dataset=dataset,
num_workers=16,
sampler=dict(type=InfiniteSampler, shuffle=True),
persistent_workers=True,
batch_size=1)
val_cfg = val_evaluator = val_dataloader = None
test_cfg = test_evaluator = test_dataloader = None

# hooks
default_hooks.update(dict(logger=dict(interval=10)))
custom_hooks = [
dict(
type=VisualizationHook,
interval=50,
fixed_input=True,
vis_kwargs_list=dict(type='Data', name='fake_img'),
n_samples=1)
]
7 changes: 7 additions & 0 deletions mmagic/configs/dreambooth/dreambooth_lora-prior_pre.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .dreambooth_lora import *

model.update(dict(prior_loss_weight=1, class_prior_prompt='a dog'))
97 changes: 97 additions & 0 deletions mmagic/configs/dreambooth/dreambooth_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .._base_.gen_default_runtime import *

from mmengine.dataset.sampler import InfiniteSampler
from torch.optim import AdamW

from mmagic.datasets.dreambooth_dataset import DreamBoothDataset
from mmagic.datasets.transforms.aug_shape import Resize
from mmagic.datasets.transforms.formatting import PackInputs
from mmagic.datasets.transforms.loading import LoadImageFromFile
from mmagic.engine import VisualizationHook
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor
from mmagic.models.editors.disco_diffusion.clip_wrapper import ClipWrapper
from mmagic.models.editors.dreambooth import DreamBooth

stable_diffusion_v15_url = 'runwayml/stable-diffusion-v1-5'

val_prompts = [
'a sks dog in basket', 'a sks dog on the mountain',
'a sks dog beside a swimming pool', 'a sks dog on the desk',
'a sleeping sks dog', 'a screaming sks dog', 'a man in the garden'
]
lora_config = dict(target_modules=['to_q', 'to_k', 'to_v'])

model = dict(
type=DreamBooth,
vae=dict(
type='AutoencoderKL',
from_pretrained=stable_diffusion_v15_url,
subfolder='vae'),
unet=dict(
type='UNet2DConditionModel',
from_pretrained=stable_diffusion_v15_url,
subfolder='unet',
),
text_encoder=dict(
type=ClipWrapper,
clip_type='huggingface',
pretrained_model_name_or_path=stable_diffusion_v15_url,
subfolder='text_encoder'),
tokenizer=stable_diffusion_v15_url,
scheduler=dict(
type='DDPMScheduler',
from_pretrained=stable_diffusion_v15_url,
subfolder='scheduler'),
test_scheduler=dict(
type='DDIMScheduler',
from_pretrained=stable_diffusion_v15_url,
subfolder='scheduler'),
data_preprocessor=dict(type=DataPreprocessor),
prior_loss_weight=0,
val_prompts=val_prompts,
lora_config=lora_config)

train_cfg = dict(max_iters=1000)

optim_wrapper = dict(
# Only optimize LoRA mappings
modules='.*.lora_mapping',
# NOTE: lr should be larger than dreambooth finetuning
optimizer=dict(type=AdamW, lr=5e-4),
accumulative_counts=1)

pipeline = [
dict(type=LoadImageFromFile, key='img', channel_order='rgb'),
dict(type=Resize, scale=(512, 512)),
dict(type=PackInputs)
]
dataset = dict(
type=DreamBoothDataset,
data_root='./data/dreambooth',
# TODO: rename to instance
concept_dir='imgs',
prompt='a photo of sks dog',
pipeline=pipeline)
train_dataloader = dict(
dataset=dataset,
num_workers=16,
sampler=dict(type=InfiniteSampler, shuffle=True),
persistent_workers=True,
batch_size=1)
val_cfg = val_evaluator = val_dataloader = None
test_cfg = test_evaluator = test_dataloader = None

# hooks
default_hooks.update(dict(logger=dict(interval=10)))
custom_hooks = [
dict(
type=VisualizationHook,
interval=50,
fixed_input=True,
vis_kwargs_list=dict(type='Data', name='fake_img'),
n_samples=1)
]

0 comments on commit 80f9120

Please sign in to comment.