Skip to content

Commit

Permalink
Add support for different model prediction types in DDIMInverseScheduler
Browse files Browse the repository at this point in the history
Resolve alpha_prod_t_prev index issue for final step of inversion
  • Loading branch information
clarencechen committed Mar 9, 2023
1 parent 186689a commit fcee91a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1156,8 +1156,8 @@ def invert(

# 7. Denoising loop where we obtain the cross-attention maps.
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
with self.progress_bar(total=num_inference_steps - 2) as progress_bar:
for i, t in enumerate(timesteps[1:-1]):
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
Expand Down
57 changes: 40 additions & 17 deletions src/diffusers/schedulers/scheduling_ddim_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
set_alpha_to_zero: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
):
Expand All @@ -144,11 +144,12 @@ def __init__(
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# At every step of inverted DDIM, we look up the next alphas_cumprod.
# For the final step there is no next alphas_cumprod, because the index would be out of bounds.
# `set_alpha_to_zero` decides whether we simply set this parameter to zero
# (in that case, self.step() just normalizes the output according to self.config.prediction_type)
# or whether we reuse the last entry of alphas_cumprod instead.
self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1]

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
Expand All @@ -157,6 +158,7 @@ def __init__(
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))

# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
Expand Down Expand Up @@ -205,23 +207,44 @@ def step(
variance_noise: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
e_t = model_output

x = sample
# 1. get the "previous" step value, which in the inversion direction is the later timestep (=t+1)
prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

a_t = self.alphas_cumprod[timestep - 1]
a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod

pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt()
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = (
self.alphas_cumprod[prev_timestep]
if prev_timestep < self.config.num_train_timesteps
else self.final_alpha_cumprod
)

beta_prod_t = 1 - alpha_prod_t

# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)

dir_xt = (1.0 - a_prev).sqrt() * e_t
# 4. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output

prev_sample = a_prev.sqrt() * pred_x0 + dir_xt
# 5. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

if not return_dict:
return (prev_sample, pred_x0)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
return (prev_sample, pred_original_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

# Length of the scheduler: the number of training timesteps stored in its config.
def __len__(self):
return self.config.num_train_timesteps

0 comments on commit fcee91a

Please sign in to comment.