Skip to content

Commit

Permalink
Add multistep DPM-Solver discrete scheduler (#1132)
Browse files Browse the repository at this point in the history
* add dpmsolver discrete pytorch scheduler

* fix some typos in dpm-solver pytorch

* add dpm-solver pytorch in stable-diffusion pipeline

* add jax/flax version dpm-solver

* change code style

* change code style

* add docs

* add `add_noise` method for dpmsolver

* add pytorch unit test for dpmsolver

* add dummy object for pytorch dpmsolver

* Update src/diffusers/schedulers/scheduling_dpmsolver_discrete.py

Co-authored-by: Suraj Patil <[email protected]>

* Update tests/test_config.py

Co-authored-by: Suraj Patil <[email protected]>

* Update tests/test_config.py

Co-authored-by: Suraj Patil <[email protected]>

* resolve the code comments

* rename the file

* change class name

* fix code style

* add auto docs for dpmsolver multistep

* add more explanations for the stabilizing trick (for steps < 15)

* delete the dummy file

* change the API name of predict_epsilon, algorithm_type and solver_type

* add compatible lists

Co-authored-by: Suraj Patil <[email protected]>
  • Loading branch information
LuChengTHU and patil-suraj authored Nov 6, 2022
1 parent 08a6dc8 commit b4a1ed8
Show file tree
Hide file tree
Showing 17 changed files with 1,362 additions and 5 deletions.
6 changes: 6 additions & 0 deletions docs/source/api/schedulers.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ Original paper can be found [here](https://arxiv.org/abs/2010.02502).

[[autodoc]] DDPMScheduler

#### Multistep DPM-Solver

Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).

[[autodoc]] DPMSolverMultistepScheduler

#### Variance exploding, stochastic sampling from Karras et. al

Original paper can be found [here](https://arxiv.org/abs/2006.11239).
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
IPNDMScheduler,
Expand Down Expand Up @@ -92,6 +93,7 @@
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
from ...schedulers import (
FlaxDDIMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from ...utils import logging
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
Expand Down Expand Up @@ -43,7 +48,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
Expand All @@ -57,7 +63,9 @@ def __init__(
text_encoder: FlaxCLIPTextModel,
tokenizer: CLIPTokenizer,
unet: FlaxUNet2DConditionModel,
scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
scheduler: Union[
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
safety_checker: FlaxStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
dtype: jnp.dtype = jnp.float32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
Expand Down Expand Up @@ -59,7 +60,12 @@ def __init__(
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
if is_torch_available():
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from .scheduling_euler_discrete import EulerDiscreteScheduler
from .scheduling_ipndm import IPNDMScheduler
Expand All @@ -35,6 +36,7 @@
if is_flax_available():
from .scheduling_ddim_flax import FlaxDDIMScheduler
from .scheduling_ddpm_flax import FlaxDDPMScheduler
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]

@register_to_config
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]

@register_to_config
Expand Down
Loading

0 comments on commit b4a1ed8

Please sign in to comment.