Add support for different model prediction types in DDIMInverseScheduler #2619

Merged · 16 commits · Mar 14, 2023
@@ -153,6 +153,8 @@ class Pix2PixInversionPipelineOutput(BaseOutput):
>>> source_embeds = pipeline.get_embeds(source_prompts)
>>> target_embeds = pipeline.get_embeds(target_prompts)
>>> # the latents can then be used to edit a real image
+>>> # when using Stable Diffusion 2 or other models that use v-prediction
+>>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion

>>> image = pipeline(
... caption,
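Since the hunk above truncates the docstring example, here is a fuller hedged sketch of the workflow the new comment refers to. It mirrors the tests added later in this PR; `raw_image`, `source_prompts`, and `target_prompts` are assumed to be defined, and the model ID is illustrative:

import torch
from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionPix2PixZeroPipeline

pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
)
# Both schedulers are built from one config, so `prediction_type`
# ("v_prediction" for Stable Diffusion 2.x) is propagated to the inverse scheduler.
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
pipeline.enable_model_cpu_offload()

caption = "a photography of a cat with flowers"
generator = torch.manual_seed(0)
inv_latents = pipeline.invert(caption, image=raw_image, generator=generator)[0]

source_embeds = pipeline.get_embeds(source_prompts)
target_embeds = pipeline.get_embeds(target_prompts)

# Per the new docstring note: with v-prediction models, keep
# `cross_attention_guidance_amount` at 0.01 or less to avoid input latent
# gradient explosion.
image = pipeline(
    caption,
    source_embeds=source_embeds,
    target_embeds=target_embeds,
    cross_attention_guidance_amount=0.01,
    generator=generator,
    latents=inv_latents,
    negative_prompt=caption,
    output_type="np",
).images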
@@ -730,6 +732,23 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None

return latents

+def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
+    pred_type = self.inverse_scheduler.config.prediction_type
+    alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]
+    beta_prod_t = 1 - alpha_prod_t
+
+    if pred_type == "epsilon":
+        return model_output
+    elif pred_type == "sample":
+        return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)
+    elif pred_type == "v_prediction":
+        return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+    else:
+        raise ValueError(
+            f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
+        )

def auto_corr_loss(self, hidden_states, generator=None):
batch_size, channel, height, width = hidden_states.shape
if batch_size > 1:
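A quick consistency check on the algebra in `get_epsilon` above; this is a standalone sketch, not part of the PR, using the standard parameterizations x_t = sqrt(alpha_prod_t) * x0 + sqrt(1 - alpha_prod_t) * eps and v = sqrt(alpha_prod_t) * eps - sqrt(1 - alpha_prod_t) * x0:

import torch

alpha_prod_t = torch.tensor(0.7)  # arbitrary cumulative alpha for one timestep
beta_prod_t = 1 - alpha_prod_t
x0, eps = torch.randn(4), torch.randn(4)

sample = alpha_prod_t**0.5 * x0 + beta_prod_t**0.5 * eps   # noisy sample x_t
v = alpha_prod_t**0.5 * eps - beta_prod_t**0.5 * x0        # v-prediction target

# "sample" branch: recover eps from a model that predicts x0 directly
assert torch.allclose((sample - alpha_prod_t**0.5 * x0) / beta_prod_t**0.5, eps, atol=1e-6)
# "v_prediction" branch: the cross terms in x0 cancel and eps is recovered exactly,
# because alpha_prod_t + beta_prod_t = 1
assert torch.allclose(alpha_prod_t**0.5 * v + beta_prod_t**0.5 * sample, eps, atol=1e-6)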
@@ -1156,8 +1175,8 @@ def invert(

# 7. Denoising loop where we obtain the cross-attention maps.
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
-with self.progress_bar(total=num_inference_steps - 2) as progress_bar:
-    for i, t in enumerate(timesteps[1:-1]):
+with self.progress_bar(total=num_inference_steps - 1) as progress_bar:
+    for i, t in enumerate(timesteps[:-1]):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
@@ -1181,7 +1200,11 @@
if lambda_auto_corr > 0:
for _ in range(num_auto_corr_rolls):
var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
-l_ac = self.auto_corr_loss(var, generator=generator)
+
+# Derive epsilon from model output before regularizing to IID standard normal
+var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
+
+l_ac = self.auto_corr_loss(var_epsilon, generator=generator)
l_ac.backward()

grad = var.grad.detach() / num_auto_corr_rolls
@@ -1190,7 +1213,10 @@
if lambda_kl > 0:
var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)

-l_kld = self.kl_divergence(var)
+# Derive epsilon from model output before regularizing to IID standard normal
+var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
+
+l_kld = self.kl_divergence(var_epsilon)
l_kld.backward()

grad = var.grad.detach()
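Both regularizers assume their input is approximately IID standard normal, which is true of epsilon but not of a `sample` or `v_prediction` output; that is why `get_epsilon` is now applied to `var` first. The body of the pipeline's `kl_divergence` helper is not shown in this diff, so treat the following as an assumption-labeled sketch of what a KL penalty against N(0, 1) typically looks like:

import torch

def kl_to_standard_normal(latents: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the pipeline's `kl_divergence` helper:
    # KL(N(mu, var) || N(0, 1)) for a Gaussian fitted to the flattened latents.
    mu, var = latents.mean(), latents.var()
    return 0.5 * (var + mu**2 - 1.0 - torch.log(var + 1e-7))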
87 changes: 64 additions & 23 deletions src/diffusers/schedulers/scheduling_ddim_inverse.py
@@ -23,7 +23,7 @@

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from diffusers.utils import BaseOutput
+from diffusers.utils import BaseOutput, deprecate


@dataclass
@@ -96,15 +96,17 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`):
-    option to clip predicted sample between -1 and 1 for numerical stability.
-set_alpha_to_one (`bool`, default `True`):
+    option to clip predicted sample for numerical stability.
+clip_sample_range (`float`, default `1.0`):
+    the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+set_alpha_to_zero (`bool`, default `True`):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
-    step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
-    otherwise it uses the value of alpha at step 0.
+    step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `0`,
+    otherwise it uses the value of alpha at step `num_train_timesteps - 1`.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
-    `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
-    stable diffusion.
+    `set_alpha_to_zero=False`, to make the last step use step `num_train_timesteps - 1` for the previous alpha
+    product.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
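For reference, a minimal construction sketch using the options documented above; the argument values are illustrative, not prescribed by the PR:

from diffusers import DDIMInverseScheduler

inverse_scheduler = DDIMInverseScheduler(
    num_train_timesteps=1000,
    prediction_type="v_prediction",  # match the model's training objective
    clip_sample=False,               # latent-space models usually disable clipping
    clip_sample_range=1.0,           # only consulted when clip_sample=True
    set_alpha_to_zero=True,          # final inversion step falls back to alpha_prod = 0
)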
@@ -122,10 +124,18 @@ def __init__(
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
-set_alpha_to_one: bool = True,
+set_alpha_to_zero: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
+clip_sample_range: float = 1.0,
+**kwargs,
):
+if kwargs.get("set_alpha_to_one", None) is not None:
+    deprecation_message = (
+        "The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead."
+    )
+    deprecate("set_alpha_to_one", "1.0.0", deprecation_message, standard_warn=False)
+    set_alpha_to_zero = kwargs["set_alpha_to_one"]
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
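The deprecation shim a few lines up keeps old call sites working. A small hedged illustration of its effect, with the behavior inferred from the diff rather than taken from additional PR code:

from diffusers import DDIMInverseScheduler

# Old-style kwarg: accepted via **kwargs, emits a deprecation warning, and is
# routed to `set_alpha_to_zero` before the constructor body runs.
sched = DDIMInverseScheduler(set_alpha_to_one=False)

# With set_alpha_to_zero=False, the fallback alpha is the last training alpha
# rather than zero (see `final_alpha_cumprod` below).
assert sched.final_alpha_cumprod == sched.alphas_cumprod[-1]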
@@ -144,11 +154,12 @@
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

-# At every step in ddim, we are looking into the previous alphas_cumprod
-# For the final step, there is no previous alphas_cumprod because we are already at 0
-# `set_alpha_to_one` decides whether we set this parameter simply to one or
-# whether we use the final alpha of the "non-previous" one.
-self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+# At every step in inverted ddim, we are looking into the next alphas_cumprod
+# For the final step, there is no next alphas_cumprod, and the index is out of bounds
+# `set_alpha_to_zero` decides whether we set this parameter simply to zero
+# in this case, self.step() just output the predicted noise
+# or whether we use the final alpha of the "non-previous" one.
+self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1]

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
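Concretely, with `set_alpha_to_zero=True` the final inversion step uses alpha_prod_t_prev = 0, and the DDIM update in `step()` below collapses to the predicted noise, which is what the comment above means:

# prev_sample = sqrt(alpha_prod_t_prev) * pred_original_sample
#             + sqrt(1 - alpha_prod_t_prev) * pred_epsilon
# with alpha_prod_t_prev = 0:
# prev_sample = 0 * pred_original_sample + 1 * pred_epsilon = pred_epsilon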
@@ -157,6 +168,7 @@
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))

+# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -205,23 +217,52 @@ def step(
variance_noise: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
-e_t = model_output
-x = sample
+# 1. get previous step value (=t+1)
prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

-a_t = self.alphas_cumprod[timestep - 1]
-a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod
+# 2. compute alphas, betas
+# change original implementation to exactly match noise levels for analogous forward process
+alpha_prod_t = self.alphas_cumprod[timestep]
+alpha_prod_t_prev = (
+    self.alphas_cumprod[prev_timestep]
+    if prev_timestep < self.config.num_train_timesteps
+    else self.final_alpha_cumprod
+)
+
+beta_prod_t = 1 - alpha_prod_t
+
+# 3. compute predicted original sample from predicted noise also called
+# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+if self.config.prediction_type == "epsilon":
+    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+    pred_epsilon = model_output
+elif self.config.prediction_type == "sample":
+    pred_original_sample = model_output
+    pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+elif self.config.prediction_type == "v_prediction":
+    pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+    pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+else:
+    raise ValueError(
+        f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+        " `v_prediction`"
+    )

-pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt()
+# 4. Clip or threshold "predicted x_0"
+if self.config.clip_sample:
+    pred_original_sample = pred_original_sample.clamp(
+        -self.config.clip_sample_range, self.config.clip_sample_range
+    )

-dir_xt = (1.0 - a_prev).sqrt() * e_t
+# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon

-prev_sample = a_prev.sqrt() * pred_x0 + dir_xt
+# 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

if not return_dict:
-    return (prev_sample, pred_x0)
-return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
+    return (prev_sample, pred_original_sample)
+return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

def __len__(self):
return self.config.num_train_timesteps
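To put the rewritten `step()` in context, a minimal inversion loop might look like the following sketch; `unet`, `prompt_embeds`, and `image_latents` are assumed to exist, and the checkpoint path is illustrative:

import torch
from diffusers import DDIMInverseScheduler

inverse_scheduler = DDIMInverseScheduler.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="scheduler"
)
inverse_scheduler.set_timesteps(50)

latents = image_latents  # latents of the clean input image, e.g. shape (1, 4, 64, 64)
for t in inverse_scheduler.timesteps:
    with torch.no_grad():
        model_output = unet(latents, t, encoder_hidden_states=prompt_embeds).sample
    # Each step runs the analogous forward (noising) process: x_t -> x_{t+1}
    latents = inverse_scheduler.step(model_output, t, latents).prev_sample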
@@ -347,7 +347,6 @@ def test_stable_diffusion_pix2pix_inversion(self):
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
-pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

caption = "a photography of a cat with flowers"
@@ -366,6 +365,28 @@

assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 5e-2

+def test_stable_diffusion_2_pix2pix_inversion(self):
+    pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
+    )
+    pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
+
+    caption = "a photography of a cat with flowers"
+    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+    pipe.enable_model_cpu_offload()
+    pipe.set_progress_bar_config(disable=None)
+
+    generator = torch.manual_seed(0)
+    output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
+    inv_latents = output[0]
+
+    image_slice = inv_latents[0, -3:, -3:, -1].flatten()
+
+    assert inv_latents.shape == (1, 4, 64, 64)
+    expected_slice = np.array([0.7515, -0.2397, 0.4922, -0.9736, -0.7031, 0.4846, -1.0781, 1.1309, -0.6973])
+
+    assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 5e-2

def test_stable_diffusion_pix2pix_full(self):
# numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png
expected_image = load_numpy(
@@ -375,7 +396,6 @@ def test_stable_diffusion_pix2pix_full(self):
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
-pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

caption = "a photography of a cat with flowers"
@@ -407,3 +427,44 @@

max_diff = np.abs(expected_image - image).mean()
assert max_diff < 0.05

+def test_stable_diffusion_2_pix2pix_full(self):
+    # numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog_2.png
+    expected_image = load_numpy(
+        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog_2.npy"
[Review thread on the line above]
Member: I believe this needs to be in https://huggingface.co/datasets/hf-internal-testing/diffusers-images/tree/main/pix2pix? Feel free to open a PR to the repository to add it.
Contributor: Indeed, could you maybe open a PR to add the numpy image? This would be very nice!
Contributor: Thanks a lot for opening it - merged it!
+    )
+
+    pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
+    )
+    pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
+
+    caption = "a photography of a cat with flowers"
+    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+    pipe.enable_model_cpu_offload()
+    pipe.set_progress_bar_config(disable=None)
+
+    generator = torch.manual_seed(0)
+    output = pipe.invert(caption, image=self.raw_image, generator=generator)
+    inv_latents = output[0]
+
+    source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
+    target_prompts = 4 * ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]
+
+    source_embeds = pipe.get_embeds(source_prompts)
+    target_embeds = pipe.get_embeds(target_prompts)
+
+    image = pipe(
+        caption,
+        source_embeds=source_embeds,
+        target_embeds=target_embeds,
+        num_inference_steps=125,
+        cross_attention_guidance_amount=0.015,
+        generator=generator,
+        latents=inv_latents,
+        negative_prompt=caption,
+        output_type="np",
+    ).images
+
+    max_diff = np.abs(expected_image - image).mean()
+    assert max_diff < 0.05