-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CodeCamp2023-645]Add dreambooth new cfg (#2042)
* new config of dreambooth * add dreambooth mmagic new_config * fix import name bug --------- Co-authored-by: YanxingLiu <[email protected]> Co-authored-by: rangoliu <[email protected]>
- Loading branch information
1 parent
9384f4b
commit 80f9120
Showing
5 changed files
with
300 additions
and
0 deletions.
There are no files selected for viewing
95 changes: 95 additions & 0 deletions
95
mmagic/configs/dreambooth/dreambooth-finetune_text_encoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .._base_.gen_default_runtime import * | ||
|
||
from mmengine.dataset.sampler import InfiniteSampler | ||
from torch.optim import AdamW | ||
|
||
from mmagic.datasets.dreambooth_dataset import DreamBoothDataset | ||
from mmagic.datasets.transforms.aug_shape import Resize | ||
from mmagic.datasets.transforms.formatting import PackInputs | ||
from mmagic.datasets.transforms.loading import LoadImageFromFile | ||
from mmagic.engine import VisualizationHook | ||
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor | ||
from mmagic.models.editors.disco_diffusion.clip_wrapper import ClipWrapper | ||
from mmagic.models.editors.dreambooth import DreamBooth | ||
|
||
# config for model | ||
stable_diffusion_v15_url = 'runwayml/stable-diffusion-v1-5' | ||
|
||
val_prompts = [ | ||
'a sks dog in basket', 'a sks dog on the mountain', | ||
'a sks dog beside a swimming pool', 'a sks dog on the desk', | ||
'a sleeping sks dog', 'a screaming sks dog', 'a man in the garden' | ||
] | ||
|
||
model = dict( | ||
type=DreamBooth, | ||
vae=dict( | ||
type='AutoencoderKL', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='vae'), | ||
unet=dict( | ||
type='UNet2DConditionModel', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='unet', | ||
), | ||
text_encoder=dict( | ||
type=ClipWrapper, | ||
clip_type='huggingface', | ||
pretrained_model_name_or_path=stable_diffusion_v15_url, | ||
subfolder='text_encoder'), | ||
tokenizer=stable_diffusion_v15_url, | ||
finetune_text_encoder=True, | ||
scheduler=dict( | ||
type='DDPMScheduler', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='scheduler'), | ||
test_scheduler=dict( | ||
type='DDIMScheduler', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='scheduler'), | ||
data_preprocessor=dict(type=DataPreprocessor), | ||
val_prompts=val_prompts) | ||
|
||
train_cfg = dict(max_iters=1000) | ||
|
||
optim_wrapper.update( | ||
modules='.*unet', | ||
optimizer=dict(type=AdamW, lr=5e-6), | ||
accumulative_counts=4 # batch size = 4 * 1 = 4 | ||
) | ||
|
||
pipeline = [ | ||
dict(type=LoadImageFromFile, key='img', channel_order='rgb'), | ||
dict(type=Resize, scale=(512, 512)), | ||
dict(type=PackInputs) | ||
] | ||
|
||
dataset = dict( | ||
type=DreamBoothDataset, | ||
data_root='./data/dreambooth', | ||
concept_dir='imgs', | ||
prompt='a photo of sks dog', | ||
pipeline=pipeline) | ||
train_dataloader = dict( | ||
dataset=dataset, | ||
num_workers=16, | ||
sampler=dict(type=InfiniteSampler, shuffle=True), | ||
persistent_workers=True, | ||
batch_size=1) | ||
val_cfg = val_evaluator = val_dataloader = None | ||
test_cfg = test_evaluator = test_dataloader = None | ||
|
||
# hooks | ||
default_hooks.update(dict(logger=dict(interval=10))) | ||
custom_hooks = [ | ||
dict( | ||
type=VisualizationHook, | ||
interval=50, | ||
fixed_input=True, | ||
vis_kwargs_list=dict(type='Data', name='fake_img'), | ||
n_samples=1) | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .dreambooth import * | ||
|
||
# config for model | ||
model.update(dict(prior_loss_weight=1, class_prior_prompt='a dog')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .._base_.gen_default_runtime import * | ||
|
||
from mmengine.dataset.sampler import InfiniteSampler | ||
from torch.optim import AdamW | ||
|
||
from mmagic.datasets.dreambooth_dataset import DreamBoothDataset | ||
from mmagic.datasets.transforms.aug_shape import Resize | ||
from mmagic.datasets.transforms.formatting import PackInputs | ||
from mmagic.datasets.transforms.loading import LoadImageFromFile | ||
from mmagic.engine import VisualizationHook | ||
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor | ||
from mmagic.models.editors.disco_diffusion.clip_wrapper import ClipWrapper | ||
from mmagic.models.editors.dreambooth import DreamBooth | ||
|
||
stable_diffusion_v15_url = 'runwayml/stable-diffusion-v1-5' | ||
|
||
val_prompts = [ | ||
'a sks dog in basket', 'a sks dog on the mountain', | ||
'a sks dog beside a swimming pool', 'a sks dog on the desk', | ||
'a sleeping sks dog', 'a screaming sks dog', 'a man in the garden' | ||
] | ||
|
||
model = dict( | ||
type=DreamBooth, | ||
vae=dict( | ||
type='AutoencoderKL', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='vae'), | ||
unet=dict( | ||
type='UNet2DConditionModel', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='unet', | ||
), | ||
text_encoder=dict( | ||
type=ClipWrapper, | ||
clip_type='huggingface', | ||
pretrained_model_name_or_path=stable_diffusion_v15_url, | ||
subfolder='text_encoder'), | ||
tokenizer=stable_diffusion_v15_url, | ||
scheduler=dict( | ||
type='DDPMScheduler', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='scheduler'), | ||
test_scheduler=dict( | ||
type='DDIMScheduler', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='scheduler'), | ||
data_preprocessor=dict(type=DataPreprocessor), | ||
val_prompts=val_prompts) | ||
|
||
train_cfg = dict(max_iters=1000) | ||
|
||
optim_wrapper.update( | ||
modules='.*unet', | ||
optimizer=dict(type=AdamW, lr=5e-6), | ||
accumulative_counts=4 # batch size = 4 * 1 = 4 | ||
) | ||
|
||
pipeline = [ | ||
dict(type=LoadImageFromFile, key='img', channel_order='rgb'), | ||
dict(type=Resize, scale=(512, 512)), | ||
dict(type=PackInputs) | ||
] | ||
|
||
dataset = dict( | ||
type=DreamBoothDataset, | ||
data_root='./data/dreambooth', | ||
concept_dir='imgs', | ||
prompt='a photo of sks dog', | ||
pipeline=pipeline) | ||
train_dataloader = dict( | ||
dataset=dataset, | ||
num_workers=16, | ||
sampler=dict(type=InfiniteSampler, shuffle=True), | ||
persistent_workers=True, | ||
batch_size=1) | ||
val_cfg = val_evaluator = val_dataloader = None | ||
test_cfg = test_evaluator = test_dataloader = None | ||
|
||
# hooks | ||
default_hooks.update(dict(logger=dict(interval=10))) | ||
custom_hooks = [ | ||
dict( | ||
type=VisualizationHook, | ||
interval=50, | ||
fixed_input=True, | ||
vis_kwargs_list=dict(type='Data', name='fake_img'), | ||
n_samples=1) | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .dreambooth_lora import * | ||
|
||
model.update(dict(prior_loss_weight=1, class_prior_prompt='a dog')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .._base_.gen_default_runtime import * | ||
|
||
from mmengine.dataset.sampler import InfiniteSampler | ||
from torch.optim import AdamW | ||
|
||
from mmagic.datasets.dreambooth_dataset import DreamBoothDataset | ||
from mmagic.datasets.transforms.aug_shape import Resize | ||
from mmagic.datasets.transforms.formatting import PackInputs | ||
from mmagic.datasets.transforms.loading import LoadImageFromFile | ||
from mmagic.engine import VisualizationHook | ||
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor | ||
from mmagic.models.editors.disco_diffusion.clip_wrapper import ClipWrapper | ||
from mmagic.models.editors.dreambooth import DreamBooth | ||
|
||
stable_diffusion_v15_url = 'runwayml/stable-diffusion-v1-5' | ||
|
||
val_prompts = [ | ||
'a sks dog in basket', 'a sks dog on the mountain', | ||
'a sks dog beside a swimming pool', 'a sks dog on the desk', | ||
'a sleeping sks dog', 'a screaming sks dog', 'a man in the garden' | ||
] | ||
lora_config = dict(target_modules=['to_q', 'to_k', 'to_v']) | ||
|
||
model = dict( | ||
type=DreamBooth, | ||
vae=dict( | ||
type='AutoencoderKL', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='vae'), | ||
unet=dict( | ||
type='UNet2DConditionModel', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='unet', | ||
), | ||
text_encoder=dict( | ||
type=ClipWrapper, | ||
clip_type='huggingface', | ||
pretrained_model_name_or_path=stable_diffusion_v15_url, | ||
subfolder='text_encoder'), | ||
tokenizer=stable_diffusion_v15_url, | ||
scheduler=dict( | ||
type='DDPMScheduler', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='scheduler'), | ||
test_scheduler=dict( | ||
type='DDIMScheduler', | ||
from_pretrained=stable_diffusion_v15_url, | ||
subfolder='scheduler'), | ||
data_preprocessor=dict(type=DataPreprocessor), | ||
prior_loss_weight=0, | ||
val_prompts=val_prompts, | ||
lora_config=lora_config) | ||
|
||
train_cfg = dict(max_iters=1000) | ||
|
||
optim_wrapper = dict( | ||
# Only optimize LoRA mappings | ||
modules='.*.lora_mapping', | ||
# NOTE: lr should be larger than dreambooth finetuning | ||
optimizer=dict(type=AdamW, lr=5e-4), | ||
accumulative_counts=1) | ||
|
||
pipeline = [ | ||
dict(type=LoadImageFromFile, key='img', channel_order='rgb'), | ||
dict(type=Resize, scale=(512, 512)), | ||
dict(type=PackInputs) | ||
] | ||
dataset = dict( | ||
type=DreamBoothDataset, | ||
data_root='./data/dreambooth', | ||
# TODO: rename to instance | ||
concept_dir='imgs', | ||
prompt='a photo of sks dog', | ||
pipeline=pipeline) | ||
train_dataloader = dict( | ||
dataset=dataset, | ||
num_workers=16, | ||
sampler=dict(type=InfiniteSampler, shuffle=True), | ||
persistent_workers=True, | ||
batch_size=1) | ||
val_cfg = val_evaluator = val_dataloader = None | ||
test_cfg = test_evaluator = test_dataloader = None | ||
|
||
# hooks | ||
default_hooks.update(dict(logger=dict(interval=10))) | ||
custom_hooks = [ | ||
dict( | ||
type=VisualizationHook, | ||
interval=50, | ||
fixed_input=True, | ||
vis_kwargs_list=dict(type='Data', name='fake_img'), | ||
n_samples=1) | ||
] |