Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for playground v2.5 #3073

Merged
merged 5 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions args_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import ldm_patched.modules.args_parser as args_parser
import os

from tempfile import gettempdir

args_parser.parser.add_argument("--share", action='store_true', help="Set whether to share on Gradio.")

Expand Down
12 changes: 10 additions & 2 deletions ldm_patched/contrib/external_model_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "eps"],),
"sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
Expand All @@ -121,17 +121,25 @@ def INPUT_TYPES(s):
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()

latent_format = None
sigma_data = 1.0
if sampling == "eps":
sampling_type = ldm_patched.modules.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = ldm_patched.modules.model_sampling.V_PREDICTION
elif sampling == "edm_playground_v2.5":
sampling_type = ldm_patched.modules.model_sampling.EDM
sigma_data = 0.5
latent_format = ldm_patched.modules.latent_formats.SDXL_Playground_2_5()

class ModelSamplingAdvanced(ldm_patched.modules.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass

model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_sigma_range(sigma_min, sigma_max)
model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
m.add_object_patch("model_sampling", model_sampling)
if latent_format is not None:
m.add_object_patch("latent_format", latent_format)
return (m, )

class RescaleCFG:
Expand Down
65 changes: 65 additions & 0 deletions ldm_patched/modules/latent_formats.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch

class LatentFormat:
scale_factor = 1.0
Expand Down Expand Up @@ -34,6 +35,70 @@ def __init__(self):
]
self.taesd_decoder_name = "taesdxl_decoder"

class SDXL_Playground_2_5(LatentFormat):
def __init__(self):
self.scale_factor = 0.5
self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)

self.latent_rgb_factors = [
# R G B
[ 0.3920, 0.4054, 0.4549],
[-0.2634, -0.0196, 0.0653],
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder"

def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std

def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean


class SD_X4(LatentFormat):
def __init__(self):
self.scale_factor = 0.08333
self.latent_rgb_factors = [
[-0.2340, -0.3863, -0.3257],
[ 0.0994, 0.0885, -0.0908],
[-0.2833, -0.2349, -0.3741],
[ 0.2523, -0.0055, -0.1651]
]

class SC_Prior(LatentFormat):
def __init__(self):
self.scale_factor = 1.0
self.latent_rgb_factors = [
[-0.0326, -0.0204, -0.0127],
[-0.1592, -0.0427, 0.0216],
[ 0.0873, 0.0638, -0.0020],
[-0.0602, 0.0442, 0.1304],
[ 0.0800, -0.0313, -0.1796],
[-0.0810, -0.0638, -0.1581],
[ 0.1791, 0.1180, 0.0967],
[ 0.0740, 0.1416, 0.0432],
[-0.1745, -0.1888, -0.1373],
[ 0.2412, 0.1577, 0.0928],
[ 0.1908, 0.0998, 0.0682],
[ 0.0209, 0.0365, -0.0092],
[ 0.0448, -0.0650, -0.1728],
[-0.1658, -0.1045, -0.1308],
[ 0.0542, 0.1545, 0.1325],
[-0.0352, -0.1672, -0.2541]
]

class SC_B(LatentFormat):
def __init__(self):
self.scale_factor = 1.0 / 0.43
self.latent_rgb_factors = [
[ 0.1121, 0.2006, 0.1023],
[-0.2093, -0.0222, -0.0195],
[-0.3087, -0.1535, 0.0366],
[ 0.0290, -0.1574, -0.4078]
]
93 changes: 80 additions & 13 deletions ldm_patched/modules/model_sampling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
import numpy as np
from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule
import math

Expand All @@ -12,12 +11,28 @@ def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma

def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
noise = noise * sigma

noise += latent_image
return noise

def inverse_noise_scaling(self, sigma, latent):
return latent

class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5


class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
Expand All @@ -42,24 +57,23 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
alphas_cumprod = torch.cumprod(alphas, dim=0)

timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end

# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.set_sigmas(sigmas)
self.set_alphas_cumprod(alphas_cumprod.float())

def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())

def set_alphas_cumprod(self, alphas_cumprod):
self.register_buffer("alphas_cumprod", alphas_cumprod.float())
self.register_buffer('sigmas', sigmas.float())
self.register_buffer('log_sigmas', sigmas.log().float())

@property
def sigma_min(self):
Expand Down Expand Up @@ -94,18 +108,18 @@ def percent_to_sigma(self, percent):
class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
self.sigma_data = 1.0

if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}

sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max)
sigma_data = sampling_settings.get("sigma_data", 1.0)
self.set_parameters(sigma_min, sigma_max, sigma_data)

def set_sigma_range(self, sigma_min, sigma_max):
def set_parameters(self, sigma_min, sigma_max, sigma_data):
self.sigma_data = sigma_data
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()

self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
Expand Down Expand Up @@ -134,3 +148,56 @@ def percent_to_sigma(self, percent):

log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)

class StableCascadeSampling(ModelSamplingDiscrete):
def __init__(self, model_config=None):
super().__init__()

if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}

self.set_parameters(sampling_settings.get("shift", 1.0))

def set_parameters(self, shift=1.0, cosine_s=8e-3):
self.shift = shift
self.cosine_s = torch.tensor(cosine_s)
self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2

#This part is just for compatibility with some schedulers in the codebase
self.num_timesteps = 10000
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
for x in range(self.num_timesteps):
t = (x + 1) / self.num_timesteps
sigmas[x] = self.sigma(t)

self.set_sigmas(sigmas)

def sigma(self, timestep):
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod)

if self.shift != 1.0:
var = alpha_cumprod
logSNR = (var/(1-var)).log()
logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift))
alpha_cumprod = logSNR.sigmoid()

alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5

def timestep(self, sigma):
var = 1 / ((sigma * sigma) + 1)
var = var.clamp(0, 1.0)
s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
return t

def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0

percent = 1.0 - percent
return self.sigma(torch.tensor(percent))
2 changes: 1 addition & 1 deletion ldm_patched/modules/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N

KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd"]
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd", "edm_playground_v2.5"]

class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
Expand Down
27 changes: 22 additions & 5 deletions modules/async_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,16 +828,33 @@ def handler(async_task):

if scheduler_name in ['lcm', 'tcd']:
final_scheduler_name = 'sgm_uniform'
if pipeline.final_unet is not None:
pipeline.final_unet = core.opModelSamplingDiscrete.patch(

def patch_discrete(unet):
return core.opModelSamplingDiscrete.patch(
pipeline.final_unet,
sampling=scheduler_name,
zsnr=False)[0]

if pipeline.final_unet is not None:
pipeline.final_unet = patch_discrete(pipeline.final_unet)
if pipeline.final_refiner_unet is not None:
pipeline.final_refiner_unet = core.opModelSamplingDiscrete.patch(
pipeline.final_refiner_unet,
pipeline.final_refiner_unet = patch_discrete(pipeline.final_refiner_unet)
print(f'Using {scheduler_name} scheduler.')
elif scheduler_name == 'edm_playground_v2.5':
final_scheduler_name = 'karras'

def patch_edm(unet):
return core.opModelSamplingContinuousEDM.patch(
unet,
sampling=scheduler_name,
zsnr=False)[0]
sigma_max=120.0,
sigma_min=0.002)[0]

if pipeline.final_unet is not None:
pipeline.final_unet = patch_edm(pipeline.final_unet)
if pipeline.final_refiner_unet is not None:
pipeline.final_refiner_unet = patch_edm(pipeline.final_refiner_unet)

print(f'Using {scheduler_name} scheduler.')

async_task.yields.append(['preview', (flags.preparation_step_count, 'Moving model to GPU ...', None)])
Expand Down
4 changes: 2 additions & 2 deletions modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from modules.util import get_file_from_folder_list
from ldm_patched.modules.lora import model_lora_keys_unet, model_lora_keys_clip
from modules.config import path_embeddings
from ldm_patched.contrib.external_model_advanced import ModelSamplingDiscrete

from ldm_patched.contrib.external_model_advanced import ModelSamplingDiscrete, ModelSamplingContinuousEDM

opEmptyLatentImage = EmptyLatentImage()
opVAEDecode = VAEDecode()
Expand All @@ -32,6 +31,7 @@
opControlNetApplyAdvanced = ControlNetApplyAdvanced()
opFreeU = FreeU_V2()
opModelSamplingDiscrete = ModelSamplingDiscrete()
opModelSamplingContinuousEDM = ModelSamplingContinuousEDM()


class StableDiffusionModel:
Expand Down
3 changes: 1 addition & 2 deletions modules/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@

KSAMPLER_NAMES = list(KSAMPLER.keys())

SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo",
"align_your_steps", "tcd"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps", "tcd", "edm_playground_v2.5"]
SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())

sampler_list = SAMPLER_NAMES
Expand Down
2 changes: 0 additions & 2 deletions modules/patch_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ def patched_register_schedule(self, given_betas=None, beta_schedule="linear", ti
self.linear_end = linear_end
sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
self.set_sigmas(sigmas)
alphas_cumprod = torch.tensor(alphas_cumprod, dtype=torch.float32)
self.set_alphas_cumprod(alphas_cumprod)
return


Expand Down
1 change: 1 addition & 0 deletions presets/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
!anime.json
!default.json
!lcm.json
!playground_v2.5.json
!realistic.json
!sai.json
Loading