From 5d8c787a7bee2a5ac0ae5e3ee91fe9c6e371bfa2 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 3 May 2023 17:20:22 -0400 Subject: [PATCH] restart server redesign --- extensions-builtin/sd-webui-controlnet | 2 +- .../stable-diffusion-webui-images-browser | 2 +- javascript/progressbar.js | 1 - javascript/ui.js | 30 ++++++- launch.py | 81 ++++++++++++++----- modules/devices.py | 15 ++-- modules/mac_specific.py | 12 +-- modules/scripts.py | 3 +- modules/shared.py | 9 ++- modules/ui.py | 17 ++-- webui.py | 23 +----- 11 files changed, 117 insertions(+), 78 deletions(-) diff --git a/extensions-builtin/sd-webui-controlnet b/extensions-builtin/sd-webui-controlnet index 8fd1fcdc5..23c0c8030 160000 --- a/extensions-builtin/sd-webui-controlnet +++ b/extensions-builtin/sd-webui-controlnet @@ -1 +1 @@ -Subproject commit 8fd1fcdc536792a957fc4734636765550edbbfcc +Subproject commit 23c0c80306861c1b90a9025dd1f52d3810a3c0d5 diff --git a/extensions-builtin/stable-diffusion-webui-images-browser b/extensions-builtin/stable-diffusion-webui-images-browser index a396a9f90..2f5bbd88e 160000 --- a/extensions-builtin/stable-diffusion-webui-images-browser +++ b/extensions-builtin/stable-diffusion-webui-images-browser @@ -1 +1 @@ -Subproject commit a396a9f90c6cd2fbd16fd2dc5aef3d210491bbc6 +Subproject commit 2f5bbd88e814d446f873ebfa1a752864cfd1cc53 diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 1eb49c35a..09c7539cf 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -38,7 +38,6 @@ function formatTime(secs) { function setTitle(progress) { var title = 'SD.Next' - console.log('progress:', progress) if (progress) title += ' ' + progress.split(' ')[0].trim(); if (document.title != title) document.title = title; } diff --git a/javascript/ui.js b/javascript/ui.js index db8e1f381..e5781767a 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -277,9 +277,33 @@ function update_token_counter(button_id) { token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); } +function monitor_server_status() { + document.open(); + document.write(` + + SD.Next + +

Waiting for server...

+ + + + `); + document.close(); +} + function restart_reload(){ - document.body.innerHTML='

Reloading...

'; - setTimeout(function(){location.reload()},8000) + document.body.style = "background: #222222; font-size: 1rem; font-family:monospace; margin-top:20%; color:lightgray; text-align:center" + document.body.innerHTML = "

Server shutdown in progress...

" + fetch('http://127.0.0.1:7860/sdapi/v1/progress') + .then((res) => setTimeout(restart_reload, 1000)) + .catch((e) => setTimeout(monitor_server_status, 500)) return [] } @@ -307,7 +331,7 @@ function create_theme_element() { } function preview_theme() { - const name = gradioApp().getElementById('setting_gradio_theme').querySelectorAll('span')[1].innerText; // ugly but we want current value without the need to set apply + const name = gradioApp().getElementById('setting_gradio_theme').querySelectorAll('input')?.[0].value || ''; if (name === 'black-orange' || name.startsWith('gradio/')) { el = document.getElementById('theme-preview') || create_theme_element(); el.style.display = el.style.display === 'block' ? 'none' : 'block'; diff --git a/launch.py b/launch.py index 7646dd8c2..29c6f90ef 100644 --- a/launch.py +++ b/launch.py @@ -1,10 +1,9 @@ -### majority of this file is superflous, but used by some extensions as helpers during extension installation - -import subprocess import os import sys +import time import shlex import logging +import subprocess commandline_args = os.environ.get('COMMANDLINE_ARGS', "") sys.argv += shlex.split(commandline_args) @@ -28,7 +27,7 @@ skip_install = False # parsed by some extensions -def commit_hash(): +def commit_hash(): # compatbility function global stored_commit_hash # pylint: disable=global-statement if stored_commit_hash is not None: return stored_commit_hash @@ -39,7 +38,7 @@ def commit_hash(): return stored_commit_hash -def run(command, desc=None, errdesc=None, custom_env=None, live=False): +def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compatbility function if desc is not None: installer.log.info(desc) if live: @@ -56,41 +55,74 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=False): return result.stdout.decode(encoding="utf8", errors="ignore") -def check_run(command): +def check_run(command): # compatbility function result = subprocess.run(command, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) return result.returncode == 0 -def is_installed(package): +def is_installed(package): # compatbility function return installer.installed(package) -def repo_dir(name): +def repo_dir(name): # compatbility function return os.path.join(script_path, dir_repos, name) -def run_python(code, desc=None, errdesc=None): +def run_python(code, desc=None, errdesc=None): # compatbility function return run(f'"{sys.executable}" -c "{code}"', desc, errdesc) -def run_pip(pkg, desc=None): +def run_pip(pkg, desc=None): # compatbility function if desc is None: desc = pkg index_url_line = f' --index-url {index_url}' if index_url != '' else '' return run(f'"{sys.executable}" -m pip {pkg} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") -def check_run_python(code): +def check_run_python(code): # compatbility function return check_run(f'"{sys.executable}" -c "{code}"') -def git_clone(url, tgt, _name, commithash=None): +def git_clone(url, tgt, _name, commithash=None): # compatbility function installer.clone(url, tgt, commithash) -def run_extension_installer(ext_dir): +def run_extension_installer(ext_dir): # compatbility function installer.run_extension_installer(ext_dir) + +def get_memory_stats(): + import psutil + def gb(val: float): + return round(val / 1024 / 1024 / 1024, 2) + process = psutil.Process(os.getpid()) + res = process.memory_info() + ram_total = 100 * res.rss / process.memory_percent() + return f'used: {gb(res.rss)} total: {gb(ram_total)}' + + +def start_server(immediate=True, server=None): + import gc + import importlib.util + collected = 0 + if server is not None: + server = None + collected = gc.collect() + if not immediate: + time.sleep(3) + installer.log.debug(f'Memory {get_memory_stats()} Collected {collected}') + module_spec = importlib.util.spec_from_file_location('webui', 'webui.py') + server = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(server) + if args.test: + installer.log.info("Test only") + server.wants_restart = False + else: + server = server.webui() + installer.log.info(f'Memory {get_memory_stats()}') + return server + + if __name__ == "__main__": if args.version: installer.add_args() @@ -105,9 +137,20 @@ def run_extension_installer(ext_dir): installer.log.info(f"Server arguments: {sys.argv[1:]}") installer.log.debug('Starting WebUI') logging.disable(logging.NOTSET if args.debug else logging.DEBUG) - if args.test: - installer.log.info("Test only") - import webui - exit(0) - import webui - webui.webui() + + instance = start_server(immediate=True, server=None) + while True: + try: + alive = instance.thread.is_alive() + except: + alive = False + if round(time.time()) % 30 == 0: + installer.log.debug(f'Server alive: {alive} Memory {get_memory_stats()}') + if not alive: + if instance.wants_restart: + installer.log.info('Server restarting...') + instance = start_server(immediate=False, server=instance) + else: + installer.log.info('Exiting...') + break + time.sleep(1) diff --git a/modules/devices.py b/modules/devices.py index 4db77cb96..529f75152 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,14 +1,14 @@ import sys import contextlib import torch -from modules import shared +from modules import cmd_args, shared try: - import intel_extension_for_pytorch as ipex + import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import except: pass if sys.platform == "darwin": - from modules import mac_specific + from modules import mac_specific # pylint: disable=ungrouped-imports def has_mps() -> bool: @@ -17,11 +17,10 @@ def has_mps() -> bool: else: return mac_specific.has_mps -def extract_device_id(args, name): +def extract_device_id(args, name): # pylint: disable=redefined-outer-name for x in range(len(args)): if name in args[x]: return args[x + 1] - return None @@ -48,7 +47,7 @@ def get_optimal_device_name(): if has_mps(): return "mps" try: - import torch_directml + import torch_directml # pylint: disable=import-error if torch_directml.is_available(): return get_dml_device_string() else: @@ -110,9 +109,7 @@ def set_cuda_params(): dtype_vae = torch.float32 unet_needs_upcast = shared.opts.upcast_sampling - -from modules.cmd_args import parser -args = parser.parse_args() +args = cmd_args.parser.parse_args() if args.use_ipex: cpu = torch.device("xpu") #Use XPU instead of CPU. %20 Perf improvement on weak CPUs. print("Using XPU instead of CPU.") diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 2455800d5..4c5efa544 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,11 +1,11 @@ +import platform +from packaging import version import torch try: - import intel_extension_for_pytorch as ipex + import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import except: pass -import platform from modules.sd_hijack_utils import CondFunc -from packaging import version # has_mps is only available in nightly pytorch (for now) and macOS 12.3+. @@ -22,7 +22,7 @@ def check_for_mps() -> bool: # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 -def cumsum_fix(input, cumsum_func, *args, **kwargs): +def cumsum_fix(input, cumsum_func, *args, **kwargs): # pylint: disable=redefined-builtin if input.device.type == 'mps': output_dtype = kwargs.get('dtype', input.dtype) if output_dtype == torch.int64: @@ -46,14 +46,14 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) - # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 + # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) elif version.parse(torch.__version__) > version.parse("1.13.1"): cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) - cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) + cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) # pylint: disable=unnecessary-lambda-assignment CondFunc('torch.cumsum', cumsum_fix_func, None) CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) diff --git a/modules/scripts.py b/modules/scripts.py index 47da0309b..1f244c66c 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -4,7 +4,6 @@ from collections import namedtuple import gradio as gr from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors -from installer import log AlwaysVisible = object() @@ -219,7 +218,7 @@ def register_scripts_from_module(module): for _key, script_class in module.__dict__.items(): if type(script_class) != type: continue - log.debug(f'Registering script: {scriptfile.path}') + # log.debug(f'Registering script: {scriptfile.path}') if issubclass(script_class, Script): scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): diff --git a/modules/shared.py b/modules/shared.py index 3c74d0b99..682fe9fef 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -697,19 +697,20 @@ def clear(self): mem_mon.start() -def restart_server(): +def restart_server(restart=True): if demo is None: return + log.info('Server shutdown requested') try: - import logging - log.setLevel(logging.DEBUG if cmd_opts.debug else logging.CRITICAL) + demo.server.wants_restart = restart demo.server.should_exit = True demo.server.force_exit = True demo.close(verbose=False) demo.server.close() except: pass - log.info('Server shutdown') + if restart: + log.info('Server will restart') def listfiles(dirname): diff --git a/modules/ui.py b/modules/ui.py index 412fafbee..32cb6a216 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1229,38 +1229,32 @@ def fun(): def run_settings(*args): changed = [] - for key, value, comp in zip(opts.data_labels.keys(), args, components): assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - for key, value, comp in zip(opts.data_labels.keys(), args, components): if comp == dummy_component: continue - if opts.set(key, value): changed.append(key) - try: opts.save(shared.config_filename) except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} Settings changed without save: {", ".join(changed)}' + return opts.dumpjson(), f'{len(changed)} Settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}' def run_settings_single(value, key): if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() - if not opts.set(key, value): return gr.update(value=getattr(opts, key)), opts.dumpjson() - opts.save(shared.config_filename) - return get_value_for_setting(key), opts.dumpjson() with gr.Blocks(analytics_enabled=False) as settings_interface: with gr.Row(): settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - restart_submit = gr.Button(value="Restart UI", variant='primary', elem_id="restart_submit") + restart_submit = gr.Button(value="Restart server", variant='primary', elem_id="restart_submit") + shutdown_submit = gr.Button(value="Shutdown server", variant='primary', elem_id="shutdown_submit") preview_theme = gr.Button(value="Preview theme", variant='primary', elem_id="settings_preview_theme") unload_sd_model = gr.Button(value='Unload checkpoint', variant='primary', elem_id="sett_unload_sd_model") reload_sd_model = gr.Button(value='Reload checkpoint', variant='primary', elem_id="sett_reload_sd_model") @@ -1392,7 +1386,8 @@ def reload_scripts(): inputs=components, outputs=[text_settings, result], ) - restart_submit.click(fn=shared.restart_server, _js="restart_reload") + restart_submit.click(fn=lambda x: shared.restart_server(restart=True), _js="restart_reload") + shutdown_submit.click(fn=lambda x: shared.restart_server(restart=False), _js="restart_reload") for i, k, item in quicksettings_list: component = component_dict[k] diff --git a/webui.py b/webui.py index 5b2157f0d..5309aa7df 100644 --- a/webui.py +++ b/webui.py @@ -1,7 +1,6 @@ import os import re import sys -import time import signal import asyncio import logging @@ -226,6 +225,7 @@ def start_ui(): show_api=True, favicon_path='automatic.ico', ) + shared.demo.server.wants_restart = False setup_middleware(app, cmd_opts) cmd_opts.autolaunch = False @@ -244,26 +244,7 @@ def webui(): start_ui() load_model() log.info(f"Startup time: {startup_timer.summary()}") - - while True: - try: - alive = shared.demo.server.thread.is_alive() - except: - alive = False - if not alive: - log.warning('Server restart') - startup_timer.reset() - start_ui() - log.info(f"Startup time: {startup_timer.summary()}") - time.sleep(1) - - """ - import sys - import types - from modules.paths_internal import script_path - libs = [name for name, m in sys.modules.items() if isinstance(m, types.ModuleType) and (getattr(m, '__file__', '') or '').startswith(script_path)] - print(libs) - """ + return shared.demo.server if __name__ == "__main__":