Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] register models and schedulers from diffusers #1692

Merged
merged 8 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
secondary_model = dict(type='SecondaryDiffusionImageNet2')

diffusion_scheduler = dict(
type='DDIMScheduler',
type='EditDDIMScheduler',
variance_type='learned_range',
beta_schedule='linear',
clip_sample=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
secondary_model = dict(type='SecondaryDiffusionImageNet2')

diffusion_scheduler = dict(
type='DDIMScheduler',
type='EditDDIMScheduler',
variance_type='learned_range',
beta_schedule='linear',
clip_sample=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
])

diffusion_scheduler = dict(
type='DDIMScheduler',
type='EditDDIMScheduler',
variance_type='learned_range',
beta_end=0.012,
beta_schedule='scaled_linear',
Expand Down
36 changes: 36 additions & 0 deletions mmedit/models/base_archs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
# To register Deconv
import warnings
from typing import List

from mmedit.utils import try_import
from .all_gather_layer import AllGatherLayer
from .aspp import ASPP
from .conv import * # noqa: F401, F403
Expand All @@ -19,6 +23,38 @@
from .upsample import PixelShufflePack
from .vgg import VGG16


def register_diffusers_models() -> List[str]:
    """Register the classes found in ``diffusers.models`` to ``MODELS``.

    Every attribute of ``diffusers.models`` that is a class is registered
    under its own attribute name. Specifically, the registered models from
    diffusers only define the network forward without training. See more
    details about diffusers in:
    https://huggingface.co/docs/diffusers/api/models.

    Returns:
        List[str]: Names of the classes registered from diffusers. The list
        is empty (and a warning is emitted) when ``try_import`` returns
        ``None`` for ``diffusers``.
    """
    import inspect

    from mmedit.registry import MODELS

    diffusers = try_import('diffusers')
    if diffusers is None:
        warnings.warn('Diffusion Models are not registered as expect. '
                      'If you want to use diffusion models, '
                      'please install diffusers>=0.12.0.')
        # Return an empty list rather than ``None`` so the declared
        # ``List[str]`` contract holds and callers can safely iterate over
        # or test membership in the result.
        return []

    registered_names = []
    for attr_name in dir(diffusers.models):
        candidate = getattr(diffusers.models, attr_name)
        # ``dir`` also lists sub-modules, functions and dunder attributes;
        # only classes are registered.
        if not inspect.isclass(candidate):
            continue
        MODELS.register_module(name=attr_name, module=candidate)
        registered_names.append(attr_name)
    return registered_names


# Run the registration once, when this package is imported, and keep the
# value returned by ``register_diffusers_models()`` for later lookups.
REGISTERED_DIFFUSERS_MODELS = register_diffusers_models()

__all__ = [
'ASPP', 'DepthwiseSeparableConvModule', 'SimpleGatedConvModule',
'LinearModule', 'conv2d', 'conv_transpose2d', 'pixel_unshuffle',
Expand Down
48 changes: 48 additions & 0 deletions mmedit/models/diffusion_schedulers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List

from mmedit.utils import try_import
from .ddim_scheduler import EditDDIMScheduler
from .ddpm_scheduler import EditDDPMScheduler


def register_diffusers_schedulers() -> List[str]:
    """Register schedulers in ``diffusers.schedulers`` to the
    ``DIFFUSION_SCHEDULERS`` registry.

    Specifically, the registered schedulers from diffusers define the
    methodology for iteratively adding noise to an image or for updating a
    sample based on model outputs. See more details about schedulers in
    diffusers here:
    https://huggingface.co/docs/diffusers/api/schedulers/overview.

    Only attributes whose name ends with ``Scheduler``, does not start with
    ``Flax`` and that are classes are registered, each under its own name.

    Returns:
        List[str]: Names of the schedulers registered from diffusers. The
        list is empty (and a warning is emitted) when ``try_import`` returns
        ``None`` for ``diffusers``.
    """
    import inspect

    from mmedit.registry import DIFFUSION_SCHEDULERS

    diffusers = try_import('diffusers')
    if diffusers is None:
        warnings.warn('Diffusion Schedulers are not registered as expect. '
                      'If you want to use diffusion models, '
                      'please install diffusers>=0.12.0.')
        # Return an empty list rather than ``None`` so the declared
        # ``List[str]`` contract holds and callers can safely iterate over
        # or test membership in the result.
        return []

    registered_names = []
    for attr_name in dir(diffusers.schedulers):
        # Filter on the name first so ``getattr`` is only called for
        # non-Flax ``*Scheduler`` attributes.
        if attr_name.startswith('Flax'):
            continue
        if not attr_name.endswith('Scheduler'):
            continue
        candidate = getattr(diffusers.schedulers, attr_name)
        if inspect.isclass(candidate):
            DIFFUSION_SCHEDULERS.register_module(
                name=attr_name, module=candidate)
            registered_names.append(attr_name)
    return registered_names


# Run the registration once, when this package is imported, and keep the
# value returned by ``register_diffusers_schedulers()`` for later lookups.
REGISTERED_DIFFUSERS_SCHEDULERS = register_diffusers_schedulers()

# Public names exported by this package.
__all__ = ['EditDDIMScheduler', 'EditDDPMScheduler']
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@


@DIFFUSION_SCHEDULERS.register_module()
class DDIMScheduler:
"""```DDIMScheduler``` support the diffusion and reverse process formulated
in https://arxiv.org/abs/2010.02502.
class EditDDIMScheduler:
"""```EditDDIMScheduler``` support the diffusion and reverse process
formulated in https://arxiv.org/abs/2010.02502.

The code is heavily influenced by https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py. # noqa
The difference is that we ensemble gradient-guided sampling in step function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@DIFFUSION_SCHEDULERS.register_module()
class DDPMScheduler:
class EditDDPMScheduler:

def __init__(self,
num_train_timesteps: int = 1000,
Expand All @@ -19,7 +19,7 @@ def __init__(self,
trained_betas: Optional[Union[np.array, list]] = None,
variance_type='fixed_small',
clip_sample=True):
"""```DDPMScheduler``` support the diffusion and reverse process
"""```EditDDPMScheduler``` support the diffusion and reverse process
formulated in https://arxiv.org/abs/2006.11239.

The code is heavily influenced by https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py. # noqa
Expand Down
8 changes: 3 additions & 5 deletions mmedit/models/editors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from .cain import CAIN, CAINNet
from .cyclegan import CycleGAN
from .dcgan import DCGAN
from .ddim import DDIMScheduler
from .ddpm import DDPMScheduler, DenoisingUnet
from .ddpm import DenoisingUnet
from .deepfillv1 import (ContextualAttentionModule, ContextualAttentionNeck,
DeepFillDecoder, DeepFillEncoder, DeepFillRefiner,
DeepFillv1Discriminators, DeepFillv1Inpaintor)
Expand Down Expand Up @@ -85,7 +84,6 @@
'ProgressiveGrowingGAN', 'SinGAN', 'AblatedDiffusionModel',
'DiscoDiffusion', 'IDLossModel', 'PESinGAN', 'MSPIEStyleGAN2',
'StyleGAN3Generator', 'InstColorization', 'NAFBaseline',
'NAFBaselineLocal', 'NAFNet', 'NAFNetLocal', 'DDIMScheduler',
'DDPMScheduler', 'DenoisingUnet', 'ClipWrapper', 'EG3D', 'Restormer',
'SwinIRNet', 'StableDiffusion'
'NAFBaselineLocal', 'NAFNet', 'NAFNetLocal', 'DenoisingUnet',
'ClipWrapper', 'EG3D', 'Restormer', 'SwinIRNet', 'StableDiffusion'
]
4 changes: 0 additions & 4 deletions mmedit/models/editors/ddim/__init__.py

This file was deleted.

3 changes: 1 addition & 2 deletions mmedit/models/editors/ddpm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ddpm_scheduler import DDPMScheduler
from .denoising_unet import DenoisingUnet

__all__ = ['DDPMScheduler', 'DenoisingUnet']
__all__ = ['DenoisingUnet']
2 changes: 1 addition & 1 deletion mmedit/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@
# modules for diffusion models that support adding noise and denoising
DIFFUSION_SCHEDULERS = Registry(
'diffusion scheduler',
locations=['mmedit.models'],
locations=['mmedit.models.diffusion_schedulers'],
)

#######################################################################
Expand Down
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
av
av==8.0.3; python_version < '3.7'
diffusers>=0.12.0
einops
face-alignment
facexlib
Expand Down
Binary file added tests/data/video_interpolation_result.mp4
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from torchvision.version import __version__ as TV_VERSION

from mmedit.apis.inferencers.text2image_inferencer import Text2ImageInferencer
from mmedit.models import DDIMScheduler, DenoisingUnet, DiscoDiffusion
from mmedit.models import DenoisingUnet, DiscoDiffusion
from mmedit.models.diffusion_schedulers import EditDDIMScheduler
from mmedit.utils import register_all_modules

register_all_modules()
Expand Down Expand Up @@ -66,7 +67,7 @@ def setUp(self):
# mock clip
self.clip_models = [clip_mock_wrapper(), clip_mock_wrapper()]
# diffusion_scheduler
self.diffusion_scheduler = DDIMScheduler(
self.diffusion_scheduler = EditDDIMScheduler(
variance_type='learned_range',
beta_schedule='linear',
clip_sample=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import pytest
import torch

from mmedit.models.editors.ddim.ddim_scheduler import DDIMScheduler
from mmedit.models.diffusion_schedulers.ddim_scheduler import EditDDIMScheduler


def test_ddim():
modelout = torch.rand((1, 8, 32, 32))
sample = torch.rand((1, 4, 32, 32))
ddim = DDIMScheduler(
ddim = EditDDIMScheduler(
num_train_timesteps=1000, variance_type='learned_range')
ddim.set_timesteps(10)
result = ddim.step(modelout, 980, sample)
Expand All @@ -22,22 +22,22 @@ def test_ddim():


def test_ddim_init():
ddim = DDIMScheduler(
ddim = EditDDIMScheduler(
num_train_timesteps=1000, beta_schedule='scaled_linear')

ddim = DDIMScheduler(
ddim = EditDDIMScheduler(
num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

assert isinstance(ddim, DDIMScheduler)
assert isinstance(ddim, EditDDIMScheduler)

with pytest.raises(Exception):
DDIMScheduler(num_train_timesteps=1000, beta_schedule='fake')
EditDDIMScheduler(num_train_timesteps=1000, beta_schedule='fake')


def test_ddim_step():
modelout = torch.rand((1, 8, 32, 32))
sample = torch.rand((1, 4, 32, 32))
ddim = DDIMScheduler(
ddim = EditDDIMScheduler(
num_train_timesteps=1000, variance_type='learned_range')
with pytest.raises(Exception):
ddim.step(modelout, 980, sample)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_models/test_editors/test_ddpm/test_ddpm_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import pytest
import torch

from mmedit.models.editors.ddpm.ddpm_scheduler import DDPMScheduler
from mmedit.models.diffusion_schedulers.ddpm_scheduler import EditDDPMScheduler


def test_ddpm():
modelout = torch.rand((1, 8, 32, 32))
sample = torch.rand((1, 4, 32, 32))
ddpm = DDPMScheduler(
ddpm = EditDDPMScheduler(
num_train_timesteps=1000, variance_type='learned_range')
result = ddpm.step(modelout, 980, sample)
assert result['prev_sample'].shape == (1, 4, 32, 32)
Expand All @@ -32,11 +32,11 @@ def test_ddpm():


def test_ddpm_init():
DDPMScheduler(trained_betas=1)
EditDDPMScheduler(trained_betas=1)

DDPMScheduler(beta_schedule='scaled_linear')
EditDDPMScheduler(beta_schedule='scaled_linear')

DDPMScheduler(beta_schedule='squaredcos_cap_v2')
EditDDPMScheduler(beta_schedule='squaredcos_cap_v2')

with pytest.raises(Exception):
DDPMScheduler(beta_schedule='tem')
EditDDPMScheduler(beta_schedule='tem')
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from mmengine.utils import digit_version
from torchvision.version import __version__ as TV_VERSION

from mmedit.models import DDIMScheduler, DenoisingUnet, DiscoDiffusion
from mmedit.models import DenoisingUnet, DiscoDiffusion
from mmedit.models.diffusion_schedulers import EditDDIMScheduler
from mmedit.utils import register_all_modules

register_all_modules()
Expand Down Expand Up @@ -66,7 +67,7 @@ def setUp(self):
# mock clip
self.clip_models = [clip_mock_wrapper(), clip_mock_wrapper()]
# diffusion_scheduler
self.diffusion_scheduler = DDIMScheduler(
self.diffusion_scheduler = EditDDIMScheduler(
variance_type='learned_range',
beta_schedule='linear',
clip_sample=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setup_class(cls):
use_new_attention_order=True),
use_scale_shift_norm=True),
diffusion_scheduler=dict(
type='DDIMScheduler',
type='EditDDIMScheduler',
variance_type='learned_range',
beta_schedule='squaredcos_cap_v2'),
rgb2bgr=True,
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_infer(self):
assert samples.shape == (1, 3, 64, 64)
# test with ddpm scheduler
scheduler_kwargs = dict(
type='DDPMScheduler',
type='EditDDPMScheduler',
variance_type='learned_range',
num_train_timesteps=5)
# test no label infer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
])

diffusion_scheduler = dict(
type='DDIMScheduler',
type='EditDDIMScheduler',
variance_type='learned_range',
beta_end=0.012,
beta_schedule='scaled_linear',
Expand Down