HiRes Latent upscaler #96

Merged
merged 1 commit on Jan 24, 2024
62 changes: 54 additions & 8 deletions scripts/openvino_accelerate.py
@@ -53,6 +53,7 @@
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
ControlNetModel,
StableDiffusionLatentUpscalePipeline,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
@@ -535,6 +536,23 @@ class NoWatermark:
def apply_watermark(self, img):
return img


def get_diffusers_upscaler(upscaler: str):
torch._dynamo.reset()
openvino_clear_caches()
model_name = "stabilityai/sd-x2-latent-upscaler"
print("OpenVINO Script: loading upscaling model: " + model_name)
sd_model = StableDiffusionLatentUpscalePipeline.from_pretrained(model_name, torch_dtype=torch.float32)
sd_model.safety_checker = None
sd_model.cond_stage_key = functools.partial(cond_stage_key, shared.sd_model)
sd_model.unet = torch.compile(sd_model.unet, backend="openvino")
sd_model.vae.decode = torch.compile(sd_model.vae.decode, backend="openvino")
shared.sd_diffusers_model = sd_model
del sd_model

return shared.sd_diffusers_model
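
For context, a minimal standalone sketch of driving the same sd-x2-latent-upscaler pipeline through diffusers, outside the webui plumbing above. The input file, prompt, and seed here are placeholder assumptions; the OpenVINO torch.compile wrapping from get_diffusers_upscaler is omitted for brevity:

import torch
from diffusers import StableDiffusionLatentUpscalePipeline
from PIL import Image

# Load the same 2x latent upscaler used by get_diffusers_upscaler above.
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float32
)

low_res = Image.open("base_render.png").convert("RGB")  # hypothetical 512x512 base render
result = upscaler(
    prompt="a photo of a castle, highly detailed",       # hypothetical prompt
    image=low_res,
    num_inference_steps=10,
    guidance_scale=0,        # the latent upscaler is commonly run with low or zero CFG
    generator=torch.manual_seed(42),
).images[0]
result.save("base_render_2x.png")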


def get_diffusers_sd_model(model_config, vae_ckpt, sampler_name, enable_caching, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac):
if (model_state.recompile == 1):
model_state.partition_id = 0
@@ -770,7 +788,8 @@ def init_new(self, all_prompts, all_seeds, all_subseeds):
else:
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

def process_images_openvino(p: StableDiffusionProcessing, model_config, vae_ckpt, sampler_name, enable_caching, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac) -> Processed:
def process_images_openvino(p: StableDiffusionProcessing, model_config, vae_ckpt, sampler_name, enable_caching, override_hires, upscaler, hires_steps, d_strength, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
if (mode == 0 and p.enable_hr):
@@ -1092,6 +1111,23 @@ def callback(iter, t, latents):

devices.torch_gc()

# High resolution mode
if override_hires:
if upscaler == "Latent":
model_state.mode = -1
shared.sd_diffusers_model = get_diffusers_upscaler(upscaler)
img_idx = slice(len(output_images)) if p.batch_size == 1 else slice(1, len(output_images))
output_images[img_idx] = shared.sd_diffusers_model(
image=output_images[img_idx],
prompt=p.prompts,
negative_prompt=p.negative_prompts,
num_inference_steps=hires_steps,
guidance_scale=p.cfg_scale,
generator=generator,
callback=callback,
callback_steps=1,
).images
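
The slice above appears to account for webui's image grid: with a batch size above one, output_images[0] is typically the composite grid, so only the individual renders from index 1 onward are sent to the upscaler. A minimal sketch of that selection logic, with stand-in values:

def images_to_upscale(output_images, batch_size):
    # Single image: upscale everything; batched: skip the grid at index 0.
    idx = slice(len(output_images)) if batch_size == 1 else slice(1, len(output_images))
    return output_images[idx]

assert images_to_upscale(["img0"], 1) == ["img0"]
assert images_to_upscale(["grid", "img0", "img1"], 2) == ["img0", "img1"]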

res = Processed(
p,
images_list=output_images,
@@ -1102,7 +1138,8 @@ def callback(iter, t, latents):
index_of_first_image=index_of_first_image,
infotexts=infotexts,
)

if override_hires:
res.info = res.info + f", Hires upscaler: {upscaler}, Denoising strength: {d_strength}"
res.info = res.info + ", Warm up time: " + str(round(warmup_duration, 2)) + " secs "

if (generation_rate >= 1.0):
@@ -1116,6 +1153,9 @@

return res

def on_change(mode):
return gr.update(visible=mode)
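
on_change is the standard Gradio pattern for toggling a component group's visibility from a checkbox, as wired up further down with override_hires.change. A self-contained sketch of the same wiring, with illustrative component names:

import gradio as gr

def on_change(mode):
    return gr.update(visible=mode)

with gr.Blocks() as demo:
    toggle = gr.Checkbox(label="Show advanced options", value=False)
    with gr.Group(visible=False) as advanced:
        gr.Slider(1, 150, value=10, step=1, label="Steps")
    # The checkbox value feeds on_change; its return value updates the group.
    toggle.change(on_change, toggle, advanced)

demo.launch()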

class Script(scripts.Script):
def title(self):
return "Accelerate with OpenVINO"
@@ -1170,6 +1210,12 @@ def get_refiner_list():
override_sampler = gr.Checkbox(label="Override the sampling selection from the main UI (Recommended, as only the sampling methods below have been validated for OpenVINO)", value=True)
sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"], value="Euler a")
enable_caching = gr.Checkbox(label="Cache the compiled models on disk for faster model load in subsequent launches (Recommended)", value=True, elem_id=self.elem_id("enable_caching"))
override_hires = gr.Checkbox(label="Override the Hires.fix selection from the main UI (Recommended, as only the upscalers below have been validated for OpenVINO)", value=False, visible=self.is_txt2img)
with gr.Group(visible=False) as hires:
with gr.Row():
upscaler = gr.Dropdown(label="Upscaler", choices=["Latent"], value="Latent")
hires_steps = gr.Slider(1, 150, value=10, step=1, label="Steps")
d_strength = gr.Slider(0, 1, value=0.5, step=0.01, label="Strength")
warmup_status = gr.Textbox(label="Device", interactive=False, visible=False)
vae_status = gr.Textbox(label="VAE", interactive=False, visible=False)
gr.Markdown(
@@ -1184,6 +1230,8 @@ def get_refiner_list():
So it's normal for the first inference after a settings change to be slower, while subsequent inferences use the optimized compiled model and run faster.
""")

override_hires.change(on_change, override_hires, hires)

def device_change(choice):
if (model_state.device == choice):
return gr.update(value="Device selected is " + choice, visible=True)
@@ -1206,9 +1254,9 @@ def refiner_ckpt_change(choice):
else:
model_state.refiner_ckpt = choice
refiner_ckpt.change(refiner_ckpt_change, refiner_ckpt)
return [model_config, vae_ckpt, openvino_device, override_sampler, sampler_name, enable_caching, is_xl_ckpt, refiner_ckpt, refiner_frac]
return [model_config, vae_ckpt, openvino_device, override_sampler, sampler_name, enable_caching, override_hires, upscaler, hires_steps, d_strength, is_xl_ckpt, refiner_ckpt, refiner_frac]

def run(self, p, model_config, vae_ckpt, openvino_device, override_sampler, sampler_name, enable_caching, is_xl_ckpt, refiner_ckpt, refiner_frac):
def run(self, p, model_config, vae_ckpt, openvino_device, override_sampler, sampler_name, enable_caching, override_hires, upscaler, hires_steps, d_strength, is_xl_ckpt, refiner_ckpt, refiner_frac):
os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = str(openvino_device)

if enable_caching:
@@ -1225,14 +1273,12 @@ def run(self, p, model_config, vae_ckpt, openvino_device, override_sampler, samp
mode = 0
if self.is_txt2img:
mode = 0
processed = process_images_openvino(p, model_config, vae_ckpt, p.sampler_name, enable_caching, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac)
processed = process_images_openvino(p, model_config, vae_ckpt, p.sampler_name, enable_caching, override_hires, upscaler, hires_steps, d_strength, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac)
else:
if p.image_mask is None:
mode = 1
else:
mode = 2
p.init = functools.partial(init_new, p)
processed = process_images_openvino(p, model_config, vae_ckpt, p.sampler_name, enable_caching, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac)
processed = process_images_openvino(p, model_config, vae_ckpt, p.sampler_name, enable_caching, override_hires, upscaler, hires_steps, d_strength, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac)
return processed