diff --git a/args_manager.py b/args_manager.py
index e023da276..5a2b37c97 100644
--- a/args_manager.py
+++ b/args_manager.py
@@ -1,7 +1,4 @@
 import ldm_patched.modules.args_parser as args_parser
-import os
-
-from tempfile import gettempdir
 
 args_parser.parser.add_argument("--share", action='store_true', help="Set whether to share on Gradio.")
 
diff --git a/ldm_patched/contrib/external_model_advanced.py b/ldm_patched/contrib/external_model_advanced.py
index 9b52c36b5..b9f0ebdca 100644
--- a/ldm_patched/contrib/external_model_advanced.py
+++ b/ldm_patched/contrib/external_model_advanced.py
@@ -108,7 +108,7 @@ class ModelSamplingContinuousEDM:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
-                              "sampling": (["v_prediction", "eps"],),
+                              "sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
                               "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
                               "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
                               }}
@@ -121,17 +121,25 @@ def INPUT_TYPES(s):
     def patch(self, model, sampling, sigma_max, sigma_min):
         m = model.clone()
 
+        latent_format = None
+        sigma_data = 1.0
         if sampling == "eps":
             sampling_type = ldm_patched.modules.model_sampling.EPS
         elif sampling == "v_prediction":
             sampling_type = ldm_patched.modules.model_sampling.V_PREDICTION
+        elif sampling == "edm_playground_v2.5":
+            sampling_type = ldm_patched.modules.model_sampling.EDM
+            sigma_data = 0.5
+            latent_format = ldm_patched.modules.latent_formats.SDXL_Playground_2_5()
 
         class ModelSamplingAdvanced(ldm_patched.modules.model_sampling.ModelSamplingContinuousEDM, sampling_type):
             pass
 
         model_sampling = ModelSamplingAdvanced(model.model.model_config)
-        model_sampling.set_sigma_range(sigma_min, sigma_max)
+        model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
         m.add_object_patch("model_sampling", model_sampling)
+        if latent_format is not None:
+            m.add_object_patch("latent_format", latent_format)
         return (m, )
 
 class RescaleCFG:
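
Annotation (not part of the patch): the hunks above only register the new "edm_playground_v2.5" choice on the existing ModelSamplingContinuousEDM node and route it to the EDM parameterization with sigma_data = 0.5 plus the Playground latent format. A minimal sketch of how that node path could be exercised directly, assuming a loaded MODEL object named `model` from the usual checkpoint loader (the variable name is hypothetical; the patch() signature is the one shown above):

    from ldm_patched.contrib.external_model_advanced import ModelSamplingContinuousEDM

    node = ModelSamplingContinuousEDM()
    patched_model, = node.patch(model, sampling="edm_playground_v2.5",
                                sigma_max=120.0, sigma_min=0.002)
    # The returned clone carries an EDM model_sampling object (sigma_data=0.5)
    # and an SDXL_Playground_2_5 latent_format object patch.
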
diff --git a/ldm_patched/modules/latent_formats.py b/ldm_patched/modules/latent_formats.py
index 2252a075e..1606793e0 100644
--- a/ldm_patched/modules/latent_formats.py
+++ b/ldm_patched/modules/latent_formats.py
@@ -1,3 +1,4 @@
+import torch
 
 class LatentFormat:
     scale_factor = 1.0
@@ -34,6 +35,70 @@ def __init__(self):
             ]
         self.taesd_decoder_name = "taesdxl_decoder"
 
+class SDXL_Playground_2_5(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 0.5
+        self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
+        self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
+
+        self.latent_rgb_factors = [
+            #   R        G        B
+            [ 0.3920,  0.4054,  0.4549],
+            [-0.2634, -0.0196,  0.0653],
+            [ 0.0568,  0.1687, -0.0755],
+            [-0.3112, -0.2359, -0.2076]
+        ]
+        self.taesd_decoder_name = "taesdxl_decoder"
+
+    def process_in(self, latent):
+        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+        latents_std = self.latents_std.to(latent.device, latent.dtype)
+        return (latent - latents_mean) * self.scale_factor / latents_std
+
+    def process_out(self, latent):
+        latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+        latents_std = self.latents_std.to(latent.device, latent.dtype)
+        return latent * latents_std / self.scale_factor + latents_mean
+
+
 class SD_X4(LatentFormat):
     def __init__(self):
         self.scale_factor = 0.08333
+        self.latent_rgb_factors = [
+            [-0.2340, -0.3863, -0.3257],
+            [ 0.0994,  0.0885, -0.0908],
+            [-0.2833, -0.2349, -0.3741],
+            [ 0.2523, -0.0055, -0.1651]
+        ]
+
+class SC_Prior(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 1.0
+        self.latent_rgb_factors = [
+            [-0.0326, -0.0204, -0.0127],
+            [-0.1592, -0.0427,  0.0216],
+            [ 0.0873,  0.0638, -0.0020],
+            [-0.0602,  0.0442,  0.1304],
+            [ 0.0800, -0.0313, -0.1796],
+            [-0.0810, -0.0638, -0.1581],
+            [ 0.1791,  0.1180,  0.0967],
+            [ 0.0740,  0.1416,  0.0432],
+            [-0.1745, -0.1888, -0.1373],
+            [ 0.2412,  0.1577,  0.0928],
+            [ 0.1908,  0.0998,  0.0682],
+            [ 0.0209,  0.0365, -0.0092],
+            [ 0.0448, -0.0650, -0.1728],
+            [-0.1658, -0.1045, -0.1308],
+            [ 0.0542,  0.1545,  0.1325],
+            [-0.0352, -0.1672, -0.2541]
+        ]
+
+class SC_B(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 1.0 / 0.43
+        self.latent_rgb_factors = [
+            [ 0.1121,  0.2006,  0.1023],
+            [-0.2093, -0.0222, -0.0195],
+            [-0.3087, -0.1535,  0.0366],
+            [ 0.0290, -0.1574, -0.4078]
+        ]
\ No newline at end of file
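
Annotation (not part of the patch): process_in and process_out in SDXL_Playground_2_5 above are exact inverses; process_in re-centres a raw Playground v2.5 VAE latent with the per-channel mean/std and the 0.5 scale factor so that its spread matches sigma_data = 0.5, and process_out undoes it. A self-contained sketch of the same arithmetic, assuming only torch (illustrative, not the class itself):

    import torch

    mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
    std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
    scale_factor = 0.5

    def process_in(latent):   # raw VAE latent -> normalised latent space
        return (latent - mean) * scale_factor / std

    def process_out(latent):  # normalised latent -> raw VAE latent space
        return latent * std / scale_factor + mean

    x = torch.randn(2, 4, 128, 128)
    assert torch.allclose(process_out(process_in(x)), x, atol=1e-5)
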
diff --git a/ldm_patched/modules/model_sampling.py b/ldm_patched/modules/model_sampling.py
index 57f51a000..bd8cb18c2 100644
--- a/ldm_patched/modules/model_sampling.py
+++ b/ldm_patched/modules/model_sampling.py
@@ -1,5 +1,4 @@
 import torch
-import numpy as np
 from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule
 import math
 
@@ -12,12 +11,28 @@ def calculate_denoised(self, sigma, model_output, model_input):
         sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
         return model_input - model_output * sigma
 
+    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+        if max_denoise:
+            noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
+        else:
+            noise = noise * sigma
+
+        noise += latent_image
+        return noise
+
+    def inverse_noise_scaling(self, sigma, latent):
+        return latent
 
 class V_PREDICTION(EPS):
     def calculate_denoised(self, sigma, model_output, model_input):
         sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
         return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
 
+class EDM(V_PREDICTION):
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+
 
 class ModelSamplingDiscrete(torch.nn.Module):
     def __init__(self, model_config=None):
@@ -42,24 +57,23 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
         else:
             betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
         alphas = 1. - betas
-        alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
-        # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+        alphas_cumprod = torch.cumprod(alphas, dim=0)
 
         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
         self.linear_start = linear_start
         self.linear_end = linear_end
 
+        # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
+        # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
+        # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
+
         sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
         self.set_sigmas(sigmas)
-        self.set_alphas_cumprod(alphas_cumprod.float())
 
     def set_sigmas(self, sigmas):
-        self.register_buffer('sigmas', sigmas)
-        self.register_buffer('log_sigmas', sigmas.log())
-
-    def set_alphas_cumprod(self, alphas_cumprod):
-        self.register_buffer("alphas_cumprod", alphas_cumprod.float())
+        self.register_buffer('sigmas', sigmas.float())
+        self.register_buffer('log_sigmas', sigmas.log().float())
 
     @property
     def sigma_min(self):
@@ -94,8 +108,6 @@ def percent_to_sigma(self, percent):
 class ModelSamplingContinuousEDM(torch.nn.Module):
     def __init__(self, model_config=None):
         super().__init__()
-        self.sigma_data = 1.0
-
         if model_config is not None:
             sampling_settings = model_config.sampling_settings
         else:
@@ -103,9 +115,11 @@ def __init__(self, model_config=None):
 
         sigma_min = sampling_settings.get("sigma_min", 0.002)
         sigma_max = sampling_settings.get("sigma_max", 120.0)
-        self.set_sigma_range(sigma_min, sigma_max)
+        sigma_data = sampling_settings.get("sigma_data", 1.0)
+        self.set_parameters(sigma_min, sigma_max, sigma_data)
 
-    def set_sigma_range(self, sigma_min, sigma_max):
+    def set_parameters(self, sigma_min, sigma_max, sigma_data):
+        self.sigma_data = sigma_data
         sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
 
         self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
@@ -134,3 +148,56 @@ def percent_to_sigma(self, percent):
 
         log_sigma_min = math.log(self.sigma_min)
         return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
+
+class StableCascadeSampling(ModelSamplingDiscrete):
+    def __init__(self, model_config=None):
+        super().__init__()
+
+        if model_config is not None:
+            sampling_settings = model_config.sampling_settings
+        else:
+            sampling_settings = {}
+
+        self.set_parameters(sampling_settings.get("shift", 1.0))
+
+    def set_parameters(self, shift=1.0, cosine_s=8e-3):
+        self.shift = shift
+        self.cosine_s = torch.tensor(cosine_s)
+        self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
+
+        #This part is just for compatibility with some schedulers in the codebase
+        self.num_timesteps = 10000
+        sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
+        for x in range(self.num_timesteps):
+            t = (x + 1) / self.num_timesteps
+            sigmas[x] = self.sigma(t)
+
+        self.set_sigmas(sigmas)
+
+    def sigma(self, timestep):
+        alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod)
+
+        if self.shift != 1.0:
+            var = alpha_cumprod
+            logSNR = (var/(1-var)).log()
+            logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift))
+            alpha_cumprod = logSNR.sigmoid()
+
+        alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
+        return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
+
+    def timestep(self, sigma):
+        var = 1 / ((sigma * sigma) + 1)
+        var = var.clamp(0, 1.0)
+        s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
+        t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
+        return t
+
+    def percent_to_sigma(self, percent):
+        if percent <= 0.0:
+            return 999999999.9
+        if percent >= 1.0:
+            return 0.0
+
+        percent = 1.0 - percent
+        return self.sigma(torch.tensor(percent))
\ No newline at end of file
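
Annotation (not part of the patch): the new EDM class differs from V_PREDICTION only in the sign in front of the model output, i.e. it reads the network output as F(x; σ) in Karras et al.'s D(x; σ) = c_skip(σ)·x + c_out(σ)·F(x; σ). A small illustrative restatement of the two formulas (sigma_data = 0.5 is what the node patch passes for edm_playground_v2.5; the tensors are arbitrary placeholders):

    import torch

    sigma_data = 0.5
    sigma = torch.tensor(3.0)

    c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)
    c_out = sigma * sigma_data / (sigma ** 2 + sigma_data ** 2) ** 0.5

    model_input = torch.randn(4)
    model_output = torch.randn(4)

    denoised_edm = model_input * c_skip + model_output * c_out  # EDM.calculate_denoised
    denoised_v = model_input * c_skip - model_output * c_out    # V_PREDICTION.calculate_denoised
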
diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py
index 35cb3d738..9ed1fcd28 100644
--- a/ldm_patched/modules/samplers.py
+++ b/ldm_patched/modules/samplers.py
@@ -523,7 +523,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N
 
 KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
-                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd"]
+                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd", "edm_playground_v2.5"]
 
 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
diff --git a/modules/async_worker.py b/modules/async_worker.py
index 9c16d6fcb..76e10f924 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -828,16 +828,33 @@ def handler(async_task):
 
         if scheduler_name in ['lcm', 'tcd']:
             final_scheduler_name = 'sgm_uniform'
-            if pipeline.final_unet is not None:
-                pipeline.final_unet = core.opModelSamplingDiscrete.patch(
+
+            def patch_discrete(unet):
+                return core.opModelSamplingDiscrete.patch(
                     pipeline.final_unet,
                     sampling=scheduler_name,
                     zsnr=False)[0]
+
+            if pipeline.final_unet is not None:
+                pipeline.final_unet = patch_discrete(pipeline.final_unet)
             if pipeline.final_refiner_unet is not None:
-                pipeline.final_refiner_unet = core.opModelSamplingDiscrete.patch(
-                    pipeline.final_refiner_unet,
+                pipeline.final_refiner_unet = patch_discrete(pipeline.final_refiner_unet)
+            print(f'Using {scheduler_name} scheduler.')
+        elif scheduler_name == 'edm_playground_v2.5':
+            final_scheduler_name = 'karras'
+
+            def patch_edm(unet):
+                return core.opModelSamplingContinuousEDM.patch(
+                    unet,
                     sampling=scheduler_name,
-                    zsnr=False)[0]
+                    sigma_max=120.0,
+                    sigma_min=0.002)[0]
+
+            if pipeline.final_unet is not None:
+                pipeline.final_unet = patch_edm(pipeline.final_unet)
+            if pipeline.final_refiner_unet is not None:
+                pipeline.final_refiner_unet = patch_edm(pipeline.final_refiner_unet)
+            print(f'Using {scheduler_name} scheduler.')
 
         async_task.yields.append(['preview', (flags.preparation_step_count, 'Moving model to GPU ...', None)])
diff --git a/modules/core.py b/modules/core.py
index 3ca4cc5b8..78c897592 100644
--- a/modules/core.py
+++ b/modules/core.py
@@ -21,8 +21,7 @@
 from modules.util import get_file_from_folder_list
 from ldm_patched.modules.lora import model_lora_keys_unet, model_lora_keys_clip
 from modules.config import path_embeddings
-from ldm_patched.contrib.external_model_advanced import ModelSamplingDiscrete
-
+from ldm_patched.contrib.external_model_advanced import ModelSamplingDiscrete, ModelSamplingContinuousEDM
 
 opEmptyLatentImage = EmptyLatentImage()
 opVAEDecode = VAEDecode()
@@ -32,6 +31,7 @@
 opControlNetApplyAdvanced = ControlNetApplyAdvanced()
 opFreeU = FreeU_V2()
 opModelSamplingDiscrete = ModelSamplingDiscrete()
+opModelSamplingContinuousEDM = ModelSamplingContinuousEDM()
 
 
 class StableDiffusionModel:
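
Annotation (not part of the patch): the handler change above keeps the user-facing scheduler name separate from the sigma schedule handed to the ksampler: 'edm_playground_v2.5' patches the UNet (and refiner, if present) with the continuous-EDM sampling object and then samples on 'karras' sigmas, mirroring how 'lcm'/'tcd' fall back to 'sgm_uniform'. A paraphrased sketch of just that name mapping (the helper name is hypothetical; the actual UNet patching via opModelSamplingDiscrete / opModelSamplingContinuousEDM is omitted):

    def resolve_final_scheduler(scheduler_name: str) -> str:
        if scheduler_name in ['lcm', 'tcd']:
            return 'sgm_uniform'
        if scheduler_name == 'edm_playground_v2.5':
            return 'karras'
        return scheduler_name
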
diff --git a/modules/flags.py b/modules/flags.py
index 25b0caaec..adaea1d19 100644
--- a/modules/flags.py
+++ b/modules/flags.py
@@ -48,8 +48,7 @@
 
 KSAMPLER_NAMES = list(KSAMPLER.keys())
 
-SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo",
-                   "align_your_steps", "tcd"]
+SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps", "tcd", "edm_playground_v2.5"]
 SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
 
 sampler_list = SAMPLER_NAMES
diff --git a/modules/patch_precision.py b/modules/patch_precision.py
index 22ffda0ad..83569bdd1 100644
--- a/modules/patch_precision.py
+++ b/modules/patch_precision.py
@@ -51,8 +51,6 @@ def patched_register_schedule(self, given_betas=None, beta_schedule="linear", ti
     self.linear_end = linear_end
 
     sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
     self.set_sigmas(sigmas)
-    alphas_cumprod = torch.tensor(alphas_cumprod, dtype=torch.float32)
-    self.set_alphas_cumprod(alphas_cumprod)
     return
 
diff --git a/presets/.gitignore b/presets/.gitignore
index 481930c56..27e74136a 100644
--- a/presets/.gitignore
+++ b/presets/.gitignore
@@ -2,5 +2,6 @@
 !anime.json
 !default.json
 !lcm.json
+!playground_v2.5.json
 !realistic.json
 !sai.json
\ No newline at end of file
diff --git a/presets/playground_v2.5.json b/presets/playground_v2.5.json
new file mode 100644
index 000000000..311bbc1dd
--- /dev/null
+++ b/presets/playground_v2.5.json
@@ -0,0 +1,51 @@
+{
+    "default_model": "playground-v2.5-1024px-aesthetic.fp16.safetensors",
+    "default_refiner": "None",
+    "default_refiner_switch": 0.5,
+    "default_loras": [
+        [
+            true,
+            "None",
+            1.0
+        ],
+        [
+            true,
+            "None",
+            1.0
+        ],
+        [
+            true,
+            "None",
+            1.0
+        ],
+        [
+            true,
+            "None",
+            1.0
+        ],
+        [
+            true,
+            "None",
+            1.0
+        ]
+    ],
+    "default_cfg_scale": 3,
+    "default_sample_sharpness": 2.0,
+    "default_sampler": "dpmpp_2m",
+    "default_scheduler": "edm_playground_v2.5",
+    "default_performance": "Speed",
+    "default_prompt": "",
+    "default_prompt_negative": "",
+    "default_styles": [
+        "Fooocus V2",
+        "Fooocus Enhance",
+        "Fooocus Sharp"
+    ],
+    "default_aspect_ratio": "1024*1024",
+    "checkpoint_downloads": {
+        "playground-v2.5-1024px-aesthetic.fp16.safetensors": "https://huggingface.co/mashb1t/fav_models/resolve/main/fav/playground-v2.5-1024px-aesthetic.fp16.safetensors"
+    },
+    "embeddings_downloads": {},
+    "lora_downloads": {},
+    "previous_default_models": []
+}
\ No newline at end of file
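
Annotation (not part of the patch): the .gitignore entry above only allows the new JSON to be committed next to the existing presets (anime.json, lcm.json, realistic.json, sai.json); the preset itself pins the Playground v2.5 checkpoint together with CFG 3, dpmpp_2m and the new edm_playground_v2.5 scheduler. A quick sanity check of the file, assuming it is run from the repository root; the assertions only restate values that appear in the JSON above:

    import json

    with open('presets/playground_v2.5.json') as f:
        preset = json.load(f)

    assert preset['default_scheduler'] == 'edm_playground_v2.5'
    assert preset['default_sampler'] == 'dpmpp_2m'
    assert preset['default_cfg_scale'] == 3
    assert preset['default_model'].startswith('playground-v2.5-1024px-aesthetic')
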