Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeCamp2023-648]MMagic 新 config 体验与适配 GuidedDiffusion #2005

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions mmagic/configs/_base_/datasets/imagenet_512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset.sampler import DefaultSampler

from mmagic.datasets.imagenet_dataset import ImageNet
from mmagic.datasets.transforms.aug_shape import Flip, Resize
from mmagic.datasets.transforms.crop import (CenterCropLongEdge,
RandomCropLongEdge)
from mmagic.datasets.transforms.formatting import PackInputs
from mmagic.datasets.transforms.loading import LoadImageFromFile

# dataset settings
dataset_type = ImageNet

# different from mmcls, we adopt the setting used in BigGAN.
# We use `RandomCropLongEdge` in training and `CenterCropLongEdge` in testing.
train_pipeline = [
dict(type=LoadImageFromFile, key='gt'),
dict(type=RandomCropLongEdge, keys='gt'),
dict(type=Resize, scale=(512, 512), keys='gt', backend='pillow'),
dict(type=Flip, keys='gt', flip_ratio=0.5, direction='horizontal'),
dict(type=PackInputs)
]

test_pipeline = [
dict(type=LoadImageFromFile, key='gt'),
dict(type=CenterCropLongEdge, keys='gt'),
dict(type=Resize, scale=(512, 512), keys='gt', backend='pillow'),
dict(type=PackInputs)
]

train_dataloader = dict(
batch_size=None,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='./data/imagenet/',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
persistent_workers=True)

val_dataloader = dict(
batch_size=None,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='./data/imagenet/',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=test_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)

test_dataloader = val_dataloader
55 changes: 55 additions & 0 deletions mmagic/configs/_base_/datasets/imagenet_64.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset.sampler import DefaultSampler

from mmagic.datasets.imagenet_dataset import ImageNet
from mmagic.datasets.transforms.aug_shape import Flip, Resize
from mmagic.datasets.transforms.crop import (CenterCropLongEdge,
RandomCropLongEdge)
from mmagic.datasets.transforms.formatting import PackInputs
from mmagic.datasets.transforms.loading import LoadImageFromFile

# dataset settings
dataset_type = ImageNet

# different from mmcls, we adopt the setting used in BigGAN.
# We use `RandomCropLongEdge` in training and `CenterCropLongEdge` in testing.
train_pipeline = [
dict(type=LoadImageFromFile, key='gt'),
dict(type=RandomCropLongEdge, keys='gt'),
dict(type=Resize, scale=(64, 64), keys='gt', backend='pillow'),
dict(type=Flip, keys='gt', flip_ratio=0.5, direction='horizontal'),
dict(type=PackInputs)
]

test_pipeline = [
dict(type=LoadImageFromFile, key='gt'),
dict(type=CenterCropLongEdge, keys='gt'),
dict(type=Resize, scale=(64, 64), keys='gt', backend='pillow'),
dict(type=PackInputs)
]

train_dataloader = dict(
batch_size=None,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='./data/imagenet/',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
persistent_workers=True)

val_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='./data/imagenet/',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=test_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)

test_dataloader = val_dataloader
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .adm_ddim250_8xb32_imagenet_256x256 import * # noqa: F401,F403

from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.editors.guided_diffusion.classifier import EncoderUNetModel

model.update(
dict(
classifier=dict(
type=EncoderUNetModel,
image_size=256,
in_channels=3,
model_channels=128,
out_channels=1000,
num_res_blocks=2,
attention_resolutions=(8, 16, 32),
channel_mult=(1, 1, 2, 2, 4, 4),
use_fp16=False,
num_head_channels=64,
use_scale_shift_norm=True,
resblock_updown=True,
pool='attention')))

metrics = [
dict(
type=FrechetInceptionDistance,
prefix='FID-Full-50k',
fake_nums=50000,
inception_style='StyleGAN',
sample_model='orig',
sample_kwargs=dict(
num_inference_steps=250, show_progress=True, classifier_scale=1.))
]

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .adm_ddim250_8xb32_imagenet_512x512 import * # noqa: F401,F403

from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.editors.guided_diffusion.classifier import EncoderUNetModel

model.update(
dict(
classifier=dict(
type=EncoderUNetModel,
image_size=512,
in_channels=3,
model_channels=128,
out_channels=1000,
num_res_blocks=2,
attention_resolutions=(16, 32, 64),
channel_mult=(0.5, 1, 1, 2, 2, 4, 4),
use_fp16=False,
num_head_channels=64,
use_scale_shift_norm=True,
resblock_updown=True,
pool='attention')))

metrics = [
dict(
type=FrechetInceptionDistance,
prefix='FID-Full-50k',
fake_nums=50000,
inception_style='StyleGAN',
sample_model='orig',
sample_kwargs=dict(
num_inference_steps=250, show_progress=True, classifier_scale=1.))
]

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .adm_ddim250_8xb32_imagenet_64x64 import * # noqa: F401,F403

from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.editors.guided_diffusion.classifier import EncoderUNetModel

model.update(
dict(
classifier=dict(
type=EncoderUNetModel,
image_size=64,
in_channels=3,
model_channels=128,
out_channels=1000,
num_res_blocks=4,
attention_resolutions=(2, 4, 8),
channel_mult=(1, 2, 3, 4),
use_fp16=False,
num_head_channels=64,
use_scale_shift_norm=True,
resblock_updown=True,
pool='attention')))

metrics = [
dict(
type=FrechetInceptionDistance,
prefix='FID-Full-50k',
fake_nums=50000,
inception_style='StyleGAN',
sample_model='orig',
sample_kwargs=dict(
num_inference_steps=250, show_progress=True, classifier_scale=1.))
]

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .._base_.datasets.imagenet_64 import * # noqa: F401,F403
from .._base_.gen_default_runtime import * # noqa: F401,F403

from mmagic.engine.hooks.visualization_hook import VisualizationHook
from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor
from mmagic.models.diffusion_schedulers.ddim_scheduler import EditDDIMScheduler
from mmagic.models.editors.ddpm.denoising_unet import (DenoisingUnet,
MultiHeadAttentionBlock)
from mmagic.models.editors.guided_diffusion.adm import AblatedDiffusionModel

model = dict(
type=AblatedDiffusionModel,
data_preprocessor=dict(type=DataPreprocessor),
unet=dict(
type=DenoisingUnet,
image_size=256,
in_channels=3,
base_channels=256,
resblocks_per_downsample=2,
attention_res=(32, 16, 8),
norm_cfg=dict(type='GN32', num_groups=32),
dropout=0.1,
num_classes=1000,
use_fp16=False,
resblock_updown=True,
attention_cfg=dict(
type=MultiHeadAttentionBlock,
num_heads=4,
num_head_channels=64,
use_new_attention_order=False),
use_scale_shift_norm=True),
diffusion_scheduler=dict(
type=EditDDIMScheduler,
variance_type='learned_range',
beta_schedule='linear'),
rgb2bgr=True,
use_fp16=False)

test_dataloader.update(dict(batch_size=32, num_workers=8))
train_cfg = dict(max_iters=100000)
metrics = [
dict(
type=FrechetInceptionDistance,
prefix='FID-Full-50k',
fake_nums=50000,
inception_style='StyleGAN',
sample_model='orig',
sample_kwargs=dict(
num_inference_steps=250, show_progress=True, classifier_scale=1.))
]

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)

# VIS_HOOK
custom_hooks = [dict(type=VisualizationHook, interval=5000, fixed_input=True)]
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import read_base

with read_base():
from .._base_.datasets.imagenet_512 import * # noqa: F401,F403
from .._base_.gen_default_runtime import * # noqa: F401,F403

from mmagic.evaluation.metrics import FrechetInceptionDistance
from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor
from mmagic.models.diffusion_schedulers.ddim_scheduler import EditDDIMScheduler
from mmagic.models.editors.ddpm.denoising_unet import (DenoisingUnet,
MultiHeadAttentionBlock)
from mmagic.models.editors.guided_diffusion import AblatedDiffusionModel

model = dict(
type=AblatedDiffusionModel,
data_preprocessor=dict(type=DataPreprocessor),
unet=dict(
type=DenoisingUnet,
image_size=512,
in_channels=3,
base_channels=256,
resblocks_per_downsample=2,
attention_res=(32, 16, 8),
norm_cfg=dict(type='GN32', num_groups=32),
dropout=0.1,
num_classes=1000,
use_fp16=False,
resblock_updown=True,
attention_cfg=dict(
type=MultiHeadAttentionBlock,
num_heads=4,
num_head_channels=64,
use_new_attention_order=False),
use_scale_shift_norm=True),
diffusion_scheduler=dict(
type=EditDDIMScheduler,
variance_type='learned_range',
beta_schedule='linear'),
rgb2bgr=True,
use_fp16=False)

test_dataloader.update(dict(batch_size=32, num_workers=8))
train_cfg = dict(max_iters=100000)
metrics = [
dict(
type=FrechetInceptionDistance,
prefix='FID-Full-50k',
fake_nums=50000,
inception_style='StyleGAN',
sample_model='orig',
sample_kwargs=dict(
num_inference_steps=250, show_progress=True, classifier_scale=1.))
]

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
Loading