Skip to content

Commit

Permalink
Add support for different model prediction types in DDIMInverseSchedu…
Browse files Browse the repository at this point in the history
…ler (huggingface#2619)

* Add support for different model prediction types in DDIMInverseScheduler
Resolve alpha_prod_t_prev index issue for final step of inversion

* Fix old bug introduced when prediction type is "sample"

* Add support for sample clipping for numerical stability and deprecate old kwarg

* Detach sample, alphas, betas

Derive predicted noise from model output before dist. regularization

Style cleanup

* Log loss for debugging

* Revert "Log loss for debugging"

This reverts commit 76ea9c8.

* Add comments

* Add inversion equivalence test

* Add expected data for Pix2PixZero pipeline tests with SD 2

* Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py

* Remove cruft and add more explanatory comments

---------

Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
clarencechen and patrickvonplaten authored Mar 14, 2023
1 parent 89dadde commit 68552ed
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ class Pix2PixInversionPipelineOutput(BaseOutput):
>>> source_embeds = pipeline.get_embeds(source_prompts)
>>> target_embeds = pipeline.get_embeds(target_prompts)
>>> # the latents can then be used to edit a real image
>>> # when using Stable Diffusion 2 or other models that use v-prediction
>>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion
>>> image = pipeline(
... caption,
Expand Down Expand Up @@ -730,6 +732,23 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None

return latents

def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
pred_type = self.inverse_scheduler.config.prediction_type
alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]

beta_prod_t = 1 - alpha_prod_t

if pred_type == "epsilon":
return model_output
elif pred_type == "sample":
return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)
elif pred_type == "v_prediction":
return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
)

def auto_corr_loss(self, hidden_states, generator=None):
batch_size, channel, height, width = hidden_states.shape
if batch_size > 1:
Expand Down Expand Up @@ -1156,8 +1175,8 @@ def invert(

# 7. Denoising loop where we obtain the cross-attention maps.
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
with self.progress_bar(total=num_inference_steps - 2) as progress_bar:
for i, t in enumerate(timesteps[1:-1]):
with self.progress_bar(total=num_inference_steps - 1) as progress_bar:
for i, t in enumerate(timesteps[:-1]):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
Expand All @@ -1181,7 +1200,11 @@ def invert(
if lambda_auto_corr > 0:
for _ in range(num_auto_corr_rolls):
var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
l_ac = self.auto_corr_loss(var, generator=generator)

# Derive epsilon from model output before regularizing to IID standard normal
var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)

l_ac = self.auto_corr_loss(var_epsilon, generator=generator)
l_ac.backward()

grad = var.grad.detach() / num_auto_corr_rolls
Expand All @@ -1190,7 +1213,10 @@ def invert(
if lambda_kl > 0:
var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)

l_kld = self.kl_divergence(var)
# Derive epsilon from model output before regularizing to IID standard normal
var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)

l_kld = self.kl_divergence(var_epsilon)
l_kld.backward()

grad = var.grad.detach()
Expand Down
87 changes: 64 additions & 23 deletions schedulers/scheduling_ddim_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput
from diffusers.utils import BaseOutput, deprecate


@dataclass
Expand Down Expand Up @@ -96,15 +96,17 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
option to clip predicted sample for numerical stability.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
set_alpha_to_zero (`bool`, default `True`):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `0`,
otherwise it uses the value of alpha at step `num_train_timesteps - 1`.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
`set_alpha_to_zero=False`, to make the last step use step `num_train_timesteps - 1` for the previous alpha
product.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
Expand All @@ -122,10 +124,18 @@ def __init__(
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
set_alpha_to_zero: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
clip_sample_range: float = 1.0,
**kwargs,
):
if kwargs.get("set_alpha_to_one", None) is not None:
deprecation_message = (
"The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead."
)
deprecate("set_alpha_to_one", "1.0.0", deprecation_message, standard_warn=False)
set_alpha_to_zero = kwargs["set_alpha_to_one"]
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
Expand All @@ -144,11 +154,12 @@ def __init__(
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# At every step in inverted ddim, we are looking into the next alphas_cumprod
# For the final step, there is no next alphas_cumprod, and the index is out of bounds
# `set_alpha_to_zero` decides whether we set this parameter simply to zero
# in this case, self.step() just output the predicted noise
# or whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1]

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
Expand All @@ -157,6 +168,7 @@ def __init__(
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))

# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
Expand Down Expand Up @@ -205,23 +217,52 @@ def step(
variance_noise: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
e_t = model_output

x = sample
# 1. get previous step value (=t+1)
prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

a_t = self.alphas_cumprod[timestep - 1]
a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod
# 2. compute alphas, betas
# change original implementation to exactly match noise levels for analogous forward process
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = (
self.alphas_cumprod[prev_timestep]
if prev_timestep < self.config.num_train_timesteps
else self.final_alpha_cumprod
)

beta_prod_t = 1 - alpha_prod_t

# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)

pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt()
# 4. Clip or threshold "predicted x_0"
if self.config.clip_sample:
pred_original_sample = pred_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)

dir_xt = (1.0 - a_prev).sqrt() * e_t
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon

prev_sample = a_prev.sqrt() * pred_x0 + dir_xt
# 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

if not return_dict:
return (prev_sample, pred_x0)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
return (prev_sample, pred_original_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

def __len__(self):
return self.config.num_train_timesteps

0 comments on commit 68552ed

Please sign in to comment.