-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SD3 Image-to-Image and Inpainting (#7295)
## Summary Add support for SD3 image-to-image and inpainting. Similar to FLUX, the implementation supports fractional denoise_start/denoise_end for more fine-grained denoise strength control, and a gradient mask adjustment schedule for smoother inpainting seams. ## Example Workflow <img width="1016" alt="image" src="https://github.com/user-attachments/assets/ee598d77-be80-4ca7-9355-c3cbefa2ef43"> Result ![image](https://github.com/user-attachments/assets/43953fa7-0e4e-42b5-84e8-85cfeeeee00b) ## QA Instructions - [x] Regression test of text-to-image - [x] Test image-to-image without mask - [x] Test that adjusting denoising_start allows fine-grained control of amount of change in image-to-image - [x] Test inpainting with mask - [x] Smoke test SD1, SDXL, FLUX image-to-image to make sure there was no regression with the frontend changes. ## Merge Plan <!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.--> ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [x] _Tests added / updated (if applicable)_ - [x] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
- Loading branch information
Showing
36 changed files
with
783 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import einops | ||
import torch | ||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL | ||
|
||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation | ||
from invokeai.app.invocations.fields import ( | ||
FieldDescriptions, | ||
ImageField, | ||
Input, | ||
InputField, | ||
WithBoard, | ||
WithMetadata, | ||
) | ||
from invokeai.app.invocations.model import VAEField | ||
from invokeai.app.invocations.primitives import LatentsOutput | ||
from invokeai.app.services.shared.invocation_context import InvocationContext | ||
from invokeai.backend.model_manager.load.load_base import LoadedModel | ||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor | ||
|
||
|
||
@invocation( | ||
"sd3_i2l", | ||
title="SD3 Image to Latents", | ||
tags=["image", "latents", "vae", "i2l", "sd3"], | ||
category="image", | ||
version="1.0.0", | ||
classification=Classification.Prototype, | ||
) | ||
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard): | ||
"""Generates latents from an image.""" | ||
|
||
image: ImageField = InputField(description="The image to encode") | ||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection) | ||
|
||
@staticmethod | ||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor: | ||
with vae_info as vae: | ||
assert isinstance(vae, AutoencoderKL) | ||
|
||
vae.disable_tiling() | ||
|
||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype) | ||
with torch.inference_mode(): | ||
image_tensor_dist = vae.encode(image_tensor).latent_dist | ||
# TODO: Use seed to make sampling reproducible. | ||
latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype) | ||
|
||
latents = vae.config.scaling_factor * latents | ||
|
||
return latents | ||
|
||
@torch.no_grad() | ||
def invoke(self, context: InvocationContext) -> LatentsOutput: | ||
image = context.images.get_pil(self.image.image_name) | ||
|
||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) | ||
if image_tensor.dim() == 3: | ||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") | ||
|
||
vae_info = context.models.load(self.vae.vae) | ||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor) | ||
|
||
latents = latents.to("cpu") | ||
name = context.tensors.save(tensor=latents) | ||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None) |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import torch | ||
|
||
|
||
class InpaintExtension: | ||
"""A class for managing inpainting with SD3.""" | ||
|
||
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor): | ||
"""Initialize InpaintExtension. | ||
Args: | ||
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). | ||
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be | ||
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the | ||
inpainted region with the background. | ||
noise (torch.Tensor): The noise tensor used to noise the init_latents. | ||
""" | ||
assert init_latents.dim() == inpaint_mask.dim() == noise.dim() == 4 | ||
assert init_latents.shape[-2:] == inpaint_mask.shape[-2:] == noise.shape[-2:] | ||
|
||
self._init_latents = init_latents | ||
self._inpaint_mask = inpaint_mask | ||
self._noise = noise | ||
|
||
def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor: | ||
"""Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep.""" | ||
# As we progress through the denoising process, we promote gradient regions of the mask to have a full weight of | ||
# 1.0. This helps to produce more coherent seams around the inpainted region. We experimented with a (small) | ||
# number of promotion strategies (e.g. gradual promotion based on timestep), but found that a simple cutoff | ||
# threshold worked well. | ||
# We use a small epsilon to avoid any potential issues with floating point precision. | ||
eps = 1e-4 | ||
mask_gradient_t_cutoff = 0.5 | ||
if t_prev > mask_gradient_t_cutoff: | ||
# Early in the denoising process, use the inpaint mask as-is. | ||
return self._inpaint_mask | ||
else: | ||
# After the cut-off, promote all non-zero mask values to 1.0. | ||
mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + eps), 1.0) | ||
|
||
return mask | ||
|
||
def merge_intermediate_latents_with_init_latents( | ||
self, intermediate_latents: torch.Tensor, t_prev: float | ||
) -> torch.Tensor: | ||
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e. | ||
update the intermediate latents to keep the regions that are not being inpainted on the correct noise | ||
trajectory. | ||
This function should be called after each denoising step. | ||
""" | ||
|
||
mask = self._apply_mask_gradient_adjustment(t_prev) | ||
|
||
# Noise the init latents for the current timestep. | ||
noised_init_latents = self._noise * t_prev + (1.0 - t_prev) * self._init_latents | ||
|
||
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask. | ||
return intermediate_latents * mask + noised_init_latents * (1.0 - mask) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.