Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for different model prediction types in DDIMInverseScheduler #2619

Merged
merged 16 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1156,8 +1156,8 @@ def invert(

# 7. Denoising loop where we obtain the cross-attention maps.
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
with self.progress_bar(total=num_inference_steps - 2) as progress_bar:
for i, t in enumerate(timesteps[1:-1]):
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
Expand Down
58 changes: 41 additions & 17 deletions src/diffusers/schedulers/scheduling_ddim_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
set_alpha_to_zero: bool = True,
clarencechen marked this conversation as resolved.
Show resolved Hide resolved
steps_offset: int = 0,
prediction_type: str = "epsilon",
):
Expand All @@ -144,11 +144,12 @@ def __init__(
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# At every step in inverted ddim, we are looking into the next alphas_cumprod
# For the final step, there is no next alphas_cumprod, and the index is out of bounds
# `set_alpha_to_zero` decides whether we set this parameter simply to zero
# in this case, self.step() just normalizes output by self.config.prediction_type
# or whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1]

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
Expand All @@ -157,6 +158,7 @@ def __init__(
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))

# Copy from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
clarencechen marked this conversation as resolved.
Show resolved Hide resolved
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
Expand Down Expand Up @@ -205,23 +207,45 @@ def step(
variance_noise: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
e_t = model_output

x = sample
# 1. get previous step value (=t+1)
prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

a_t = self.alphas_cumprod[timestep - 1]
a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod

pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt()
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = (
self.alphas_cumprod[prev_timestep]
if prev_timestep < self.config.num_train_timesteps
else self.final_alpha_cumprod
)

beta_prod_t = 1 - alpha_prod_t

# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)

dir_xt = (1.0 - a_prev).sqrt() * e_t
# 4. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon

prev_sample = a_prev.sqrt() * pred_x0 + dir_xt
# 5. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

if not return_dict:
return (prev_sample, pred_x0)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
return (prev_sample, pred_original_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

def __len__(self):
return self.config.num_train_timesteps