diff --git a/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index b5a352c785ee..0e58701d93a7 100644 --- a/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -153,6 +153,8 @@ class Pix2PixInversionPipelineOutput(BaseOutput): >>> source_embeds = pipeline.get_embeds(source_prompts) >>> target_embeds = pipeline.get_embeds(target_prompts) >>> # the latents can then be used to edit a real image + >>> # when using Stable Diffusion 2 or other models that use v-prediction + >>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion >>> image = pipeline( ... caption, @@ -730,6 +732,23 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None return latents + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + def auto_corr_loss(self, hidden_states, generator=None): batch_size, channel, height, width = hidden_states.shape if batch_size > 1: @@ -1156,8 +1175,8 @@ def invert( # 7. Denoising loop where we obtain the cross-attention maps. num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order - with self.progress_bar(total=num_inference_steps - 2) as progress_bar: - for i, t in enumerate(timesteps[1:-1]): + with self.progress_bar(total=num_inference_steps - 1) as progress_bar: + for i, t in enumerate(timesteps[:-1]): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) @@ -1181,7 +1200,11 @@ def invert( if lambda_auto_corr > 0: for _ in range(num_auto_corr_rolls): var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) - l_ac = self.auto_corr_loss(var, generator=generator) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = self.auto_corr_loss(var_epsilon, generator=generator) l_ac.backward() grad = var.grad.detach() / num_auto_corr_rolls @@ -1190,7 +1213,10 @@ def invert( if lambda_kl > 0: var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) - l_kld = self.kl_divergence(var) + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = self.kl_divergence(var_epsilon) l_kld.backward() grad = var.grad.detach() diff --git a/schedulers/scheduling_ddim_inverse.py b/schedulers/scheduling_ddim_inverse.py index 7006bd133932..2c9fc036a027 100644 --- a/schedulers/scheduling_ddim_inverse.py +++ b/schedulers/scheduling_ddim_inverse.py @@ -23,7 +23,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin -from diffusers.utils import BaseOutput +from diffusers.utils import BaseOutput, deprecate @dataclass @@ -96,15 +96,17 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. clip_sample (`bool`, default `True`): - option to clip predicted sample between -1 and 1 for numerical stability. - set_alpha_to_one (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_zero (`bool`, default `True`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final - step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, - otherwise it uses the value of alpha at step 0. + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `0`, + otherwise it uses the value of alpha at step `num_train_timesteps - 1`. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. + `set_alpha_to_zero=False`, to make the last step use step `num_train_timesteps - 1` for the previous alpha + product. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 @@ -122,10 +124,18 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, - set_alpha_to_one: bool = True, + set_alpha_to_zero: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", + clip_sample_range: float = 1.0, + **kwargs, ): + if kwargs.get("set_alpha_to_one", None) is not None: + deprecation_message = ( + "The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead." + ) + deprecate("set_alpha_to_one", "1.0.0", deprecation_message, standard_warn=False) + set_alpha_to_zero = kwargs["set_alpha_to_one"] if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -144,11 +154,12 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this parameter simply to one or - # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # At every step in inverted ddim, we are looking into the next alphas_cumprod + # For the final step, there is no next alphas_cumprod, and the index is out of bounds + # `set_alpha_to_zero` decides whether we set this parameter simply to zero + # in this case, self.step() just output the predicted noise + # or whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -157,6 +168,7 @@ def __init__( self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64)) + # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -205,23 +217,52 @@ def step( variance_noise: Optional[torch.FloatTensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: - e_t = model_output - - x = sample + # 1. get previous step value (=t+1) prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps - a_t = self.alphas_cumprod[timestep - 1] - a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod + # 2. compute alphas, betas + # change original implementation to exactly match noise levels for analogous forward process + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_timestep] + if prev_timestep < self.config.num_train_timesteps + else self.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) - pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt() + # 4. Clip or threshold "predicted x_0" + if self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) - dir_xt = (1.0 - a_prev).sqrt() * e_t + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon - prev_sample = a_prev.sqrt() * pred_x0 + dir_xt + # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if not return_dict: - return (prev_sample, pred_x0) - return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0) + return (prev_sample, pred_original_sample) + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def __len__(self): return self.config.num_train_timesteps