restart server redesign

vladmandic committed May 3, 2023
1 parent 0af6c70 commit 5d8c787

Showing 11 changed files with 117 additions and 78 deletions.
2 changes: 1 addition & 1 deletion extensions-builtin/sd-webui-controlnet
Submodule sd-webui-controlnet updated 1 file
+1 −5 README.md
1 change: 0 additions & 1 deletion javascript/progressbar.js
@@ -38,7 +38,6 @@ function formatTime(secs) {

function setTitle(progress) {
var title = 'SD.Next'
console.log('progress:', progress)
if (progress) title += ' ' + progress.split(' ')[0].trim();
if (document.title != title) document.title = title;
}
30 changes: 27 additions & 3 deletions javascript/ui.js
@@ -277,9 +277,33 @@ function update_token_counter(button_id) {
token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
}

function monitor_server_status() {
document.open();
document.write(`
<html>
<head><title>SD.Next</title></head>
<body style="background: #222222; font-size: 1rem; font-family:monospace; margin-top:20%; color:lightgray; text-align:center">
<h1>Waiting for server...</h1>
<script>
function monitor_server_status() {
fetch('http://127.0.0.1:7860/sdapi/v1/progress')
.then((res) => { !res?.ok ? setTimeout(monitor_server_status, 1000) : location.reload(); })
.catch((e) => setTimeout(monitor_server_status, 1000))
}
window.onload = () => monitor_server_status();
</script>
</body>
</html>
`);
document.close();
}

function restart_reload(){
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
setTimeout(function(){location.reload()},8000)
document.body.style = "background: #222222; font-size: 1rem; font-family:monospace; margin-top:20%; color:lightgray; text-align:center"
document.body.innerHTML = "<h1>Server shutdown in progress...</h1>"
fetch('http://127.0.0.1:7860/sdapi/v1/progress')
.then((res) => setTimeout(restart_reload, 1000))
.catch((e) => setTimeout(monitor_server_status, 500))
return []
}

@@ -307,7 +331,7 @@ function create_theme_element() {
}

function preview_theme() {
const name = gradioApp().getElementById('setting_gradio_theme').querySelectorAll('span')[1].innerText; // ugly but we want current value without the need to set apply
const name = gradioApp().getElementById('setting_gradio_theme').querySelectorAll('input')?.[0].value || '';
if (name === 'black-orange' || name.startsWith('gradio/')) {
el = document.getElementById('theme-preview') || create_theme_element();
el.style.display = el.style.display === 'block' ? 'none' : 'block';
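
The new monitor_server_status() page replaces the old fixed 8-second reload: it polls http://127.0.0.1:7860/sdapi/v1/progress once per second and reloads the UI as soon as the endpoint answers, and restart_reload() now shows a shutdown message before handing off to the same polling. The same wait-until-ready check can be done from a script; the following is a minimal sketch using only the Python standard library (the URL is the hard-coded address from the diff, the helper name is illustrative):

```python
# Minimal sketch: wait for the SD.Next server to come back after a restart.
# Polls the same /sdapi/v1/progress endpoint the new JavaScript uses.
import time
import urllib.error
import urllib.request


def wait_for_server(url="http://127.0.0.1:7860/sdapi/v1/progress", interval=1.0, timeout=2.0):
    """Block until the progress endpoint answers with HTTP 200."""
    while True:
        try:
            with urllib.request.urlopen(url, timeout=timeout) as res:
                if res.status == 200:
                    return
        except (urllib.error.URLError, OSError):
            pass  # server is still down or restarting
        time.sleep(interval)


if __name__ == "__main__":
    wait_for_server()
    print("server is reachable again")
```
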
81 changes: 62 additions & 19 deletions launch.py
@@ -1,10 +1,9 @@
### majority of this file is superfluous, but used by some extensions as helpers during extension installation

import subprocess
import os
import sys
import time
import shlex
import logging
import subprocess

commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
sys.argv += shlex.split(commandline_args)
@@ -28,7 +27,7 @@
skip_install = False # parsed by some extensions


def commit_hash():
def commit_hash(): # compatibility function
global stored_commit_hash # pylint: disable=global-statement
if stored_commit_hash is not None:
return stored_commit_hash
@@ -39,7 +38,7 @@ def commit_hash():
return stored_commit_hash


def run(command, desc=None, errdesc=None, custom_env=None, live=False):
def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compatibility function
if desc is not None:
installer.log.info(desc)
if live:
@@ -56,41 +55,74 @@ def run(command, desc=None, errdesc=None, custom_env=None, live=False):
return result.stdout.decode(encoding="utf8", errors="ignore")


def check_run(command):
def check_run(command): # compatibility function
result = subprocess.run(command, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
return result.returncode == 0


def is_installed(package):
def is_installed(package): # compatibility function
return installer.installed(package)


def repo_dir(name):
def repo_dir(name): # compatibility function
return os.path.join(script_path, dir_repos, name)


def run_python(code, desc=None, errdesc=None):
def run_python(code, desc=None, errdesc=None): # compatibility function
return run(f'"{sys.executable}" -c "{code}"', desc, errdesc)


def run_pip(pkg, desc=None):
def run_pip(pkg, desc=None): # compatibility function
if desc is None:
desc = pkg
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
return run(f'"{sys.executable}" -m pip {pkg} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")


def check_run_python(code):
def check_run_python(code): # compatibility function
return check_run(f'"{sys.executable}" -c "{code}"')


def git_clone(url, tgt, _name, commithash=None):
def git_clone(url, tgt, _name, commithash=None): # compatibility function
installer.clone(url, tgt, commithash)


def run_extension_installer(ext_dir):
def run_extension_installer(ext_dir): # compatibility function
installer.run_extension_installer(ext_dir)


def get_memory_stats():
import psutil
def gb(val: float):
return round(val / 1024 / 1024 / 1024, 2)
process = psutil.Process(os.getpid())
res = process.memory_info()
ram_total = 100 * res.rss / process.memory_percent()
return f'used: {gb(res.rss)} total: {gb(ram_total)}'


def start_server(immediate=True, server=None):
import gc
import importlib.util
collected = 0
if server is not None:
server = None
collected = gc.collect()
if not immediate:
time.sleep(3)
installer.log.debug(f'Memory {get_memory_stats()} Collected {collected}')
module_spec = importlib.util.spec_from_file_location('webui', 'webui.py')
server = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(server)
if args.test:
installer.log.info("Test only")
server.wants_restart = False
else:
server = server.webui()
installer.log.info(f'Memory {get_memory_stats()}')
return server


if __name__ == "__main__":
if args.version:
installer.add_args()
@@ -105,9 +137,20 @@ def run_extension_installer(ext_dir):
installer.log.info(f"Server arguments: {sys.argv[1:]}")
installer.log.debug('Starting WebUI')
logging.disable(logging.NOTSET if args.debug else logging.DEBUG)
if args.test:
installer.log.info("Test only")
import webui
exit(0)
import webui
webui.webui()

instance = start_server(immediate=True, server=None)
while True:
try:
alive = instance.thread.is_alive()
except:
alive = False
if round(time.time()) % 30 == 0:
installer.log.debug(f'Server alive: {alive} Memory {get_memory_stats()}')
if not alive:
if instance.wants_restart:
installer.log.info('Server restarting...')
instance = start_server(immediate=False, server=instance)
else:
installer.log.info('Exiting...')
break
time.sleep(1)
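
launch.py no longer exits right after importing webui: start_server() (re)loads webui.py through importlib, and the new __main__ loop checks the server thread once per second, restarting it when wants_restart is set and exiting otherwise. A stripped-down sketch of that watch-and-restart pattern, with a stand-in Server class instead of the real gradio/uvicorn objects:

```python
# Sketch of the watch-and-restart loop introduced above, using a stand-in Server class.
import threading
import time


class Server:
    """Stand-in for the object returned by webui.webui(): a worker thread plus a restart flag."""

    def __init__(self):
        self.wants_restart = False
        self.thread = threading.Thread(target=self._work, daemon=True)
        self.thread.start()

    def _work(self):
        time.sleep(5)  # pretend to serve requests for a while, then stop


def start_server(previous=None):
    # In launch.py the previous instance is discarded and webui.py is re-imported here.
    return Server()


if __name__ == "__main__":
    instance = start_server()
    while True:
        alive = instance.thread.is_alive()
        if not alive:
            if instance.wants_restart:
                instance = start_server(previous=instance)  # server asked to be restarted
            else:
                break  # clean shutdown requested
        time.sleep(1)
    print("exiting")
```
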
15 changes: 6 additions & 9 deletions modules/devices.py
@@ -1,14 +1,14 @@
import sys
import contextlib
import torch
from modules import shared
from modules import cmd_args, shared
try:
import intel_extension_for_pytorch as ipex
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
except:
pass

if sys.platform == "darwin":
from modules import mac_specific
from modules import mac_specific # pylint: disable=ungrouped-imports


def has_mps() -> bool:
@@ -17,11 +17,10 @@ def has_mps() -> bool:
else:
return mac_specific.has_mps

def extract_device_id(args, name):
def extract_device_id(args, name): # pylint: disable=redefined-outer-name
for x in range(len(args)):
if name in args[x]:
return args[x + 1]

return None


@@ -48,7 +47,7 @@ def get_optimal_device_name():
if has_mps():
return "mps"
try:
import torch_directml
import torch_directml # pylint: disable=import-error
if torch_directml.is_available():
return get_dml_device_string()
else:
@@ -110,9 +109,7 @@ def set_cuda_params():
dtype_vae = torch.float32
unet_needs_upcast = shared.opts.upcast_sampling


from modules.cmd_args import parser
args = parser.parse_args()
args = cmd_args.parser.parse_args()
if args.use_ipex:
cpu = torch.device("xpu") # Use XPU instead of CPU. 20% perf improvement on weak CPUs.
print("Using XPU instead of CPU.")
12 changes: 6 additions & 6 deletions modules/mac_specific.py
@@ -1,11 +1,11 @@
import platform
from packaging import version
import torch
try:
import intel_extension_for_pytorch as ipex
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
except:
pass
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version


# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
@@ -22,7 +22,7 @@ def check_for_mps() -> bool:


# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
def cumsum_fix(input, cumsum_func, *args, **kwargs): # pylint: disable=redefined-builtin
if input.device.type == 'mps':
output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64:
@@ -46,14 +46,14 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) # pylint: disable=unnecessary-lambda-assignment
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
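
The surrounding context shows how these MPS workarounds are installed: CondFunc(path, sub_func, cond_func) swaps a dotted function path for a wrapper that calls the substitute only when the condition on the arguments holds. Below is a deliberately simplified, module-level illustration of that conditional-patching idea (not the actual modules.sd_hijack_utils.CondFunc, which also handles class attributes such as torch.Tensor.to):

```python
# Illustration of conditional monkey-patching, in the spirit of CondFunc (not the real implementation).
import importlib
import math


def cond_patch(dotted_name, sub_func, cond_func):
    """Replace a module-level function so sub_func(orig, ...) runs only when cond_func(orig, ...) is true."""
    module_name, attr = dotted_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    orig = getattr(module, attr)

    def wrapper(*args, **kwargs):
        if cond_func is None or cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    setattr(module, attr, wrapper)


# Example: round math.pow results, but only when both arguments are ints.
cond_patch(
    "math.pow",
    lambda orig, x, y: round(orig(x, y)),
    lambda orig, x, y: isinstance(x, int) and isinstance(y, int),
)

print(math.pow(2, 10))     # 1024 -> patched path
print(math.pow(2.0, 0.5))  # 1.41... -> original path
```
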
3 changes: 1 addition & 2 deletions modules/scripts.py
@@ -4,7 +4,6 @@
from collections import namedtuple
import gradio as gr
from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors
from installer import log

AlwaysVisible = object()

@@ -219,7 +218,7 @@ def register_scripts_from_module(module):
for _key, script_class in module.__dict__.items():
if type(script_class) != type:
continue
log.debug(f'Registering script: {scriptfile.path}')
# log.debug(f'Registering script: {scriptfile.path}')
if issubclass(script_class, Script):
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
9 changes: 5 additions & 4 deletions modules/shared.py
@@ -697,19 +697,20 @@ def clear(self):
mem_mon.start()


def restart_server():
def restart_server(restart=True):
if demo is None:
return
log.info('Server shutdown requested')
try:
import logging
log.setLevel(logging.DEBUG if cmd_opts.debug else logging.CRITICAL)
demo.server.wants_restart = restart
demo.server.should_exit = True
demo.server.force_exit = True
demo.close(verbose=False)
demo.server.close()
except:
pass
log.info('Server shutdown')
if restart:
log.info('Server will restart')


def listfiles(dirname):
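
restart_server() now distinguishes a restart from a plain shutdown through the new restart argument: it records the request on the server object as wants_restart and then sets should_exit and force_exit, which are standard uvicorn.Server shutdown flags (demo.server is the uvicorn server gradio starts); wants_restart is the custom marker read by the loop in launch.py. A self-contained sketch of the same shutdown mechanism against a bare uvicorn server rather than gradio's embedded one:

```python
# Sketch: stopping a uvicorn server the same way shared.restart_server() stops gradio's embedded one.
import threading
import time

import uvicorn


async def app(scope, receive, send):
    # Tiny ASGI app that answers every HTTP request with 200 OK.
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})


config = uvicorn.Config(app, host="127.0.0.1", port=8000, log_level="warning", lifespan="off")
server = uvicorn.Server(config)

thread = threading.Thread(target=server.run, daemon=True)
thread.start()
time.sleep(2)  # let the server handle requests for a moment

server.wants_restart = True  # custom flag, exactly as the diff adds it
server.should_exit = True    # ask uvicorn to stop its main loop
server.force_exit = True     # do not wait for open connections to drain
thread.join()
print("server stopped")
```
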
17 changes: 6 additions & 11 deletions modules/ui.py
@@ -1229,38 +1229,32 @@ def fun():

def run_settings(*args):
changed = []

for key, value, comp in zip(opts.data_labels.keys(), args, components):
assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"

for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp == dummy_component:
continue

if opts.set(key, value):
changed.append(key)

try:
opts.save(shared.config_filename)
except RuntimeError:
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
return opts.dumpjson(), f'{len(changed)} Settings changed without save: {", ".join(changed)}'
return opts.dumpjson(), f'{len(changed)} Settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}'

def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()

if not opts.set(key, value):
return gr.update(value=getattr(opts, key)), opts.dumpjson()

opts.save(shared.config_filename)

return get_value_for_setting(key), opts.dumpjson()

with gr.Blocks(analytics_enabled=False) as settings_interface:
with gr.Row():
settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
restart_submit = gr.Button(value="Restart UI", variant='primary', elem_id="restart_submit")
restart_submit = gr.Button(value="Restart server", variant='primary', elem_id="restart_submit")
shutdown_submit = gr.Button(value="Shutdown server", variant='primary', elem_id="shutdown_submit")
preview_theme = gr.Button(value="Preview theme", variant='primary', elem_id="settings_preview_theme")
unload_sd_model = gr.Button(value='Unload checkpoint', variant='primary', elem_id="sett_unload_sd_model")
reload_sd_model = gr.Button(value='Reload checkpoint', variant='primary', elem_id="sett_reload_sd_model")
@@ -1392,7 +1386,8 @@ def reload_scripts():
inputs=components,
outputs=[text_settings, result],
)
restart_submit.click(fn=shared.restart_server, _js="restart_reload")
restart_submit.click(fn=lambda x: shared.restart_server(restart=True), _js="restart_reload")
shutdown_submit.click(fn=lambda x: shared.restart_server(restart=False), _js="restart_reload")

for i, k, item in quicksettings_list:
component = component_dict[k]
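
In modules/ui.py the old "Restart UI" button becomes "Restart server", a separate "Shutdown server" button is added, and both call shared.restart_server() with different restart values, while _js="restart_reload" runs the browser-side reload page added earlier in javascript/ui.js. A minimal sketch of that two-button wiring in gradio 3.x style (placeholder callback rather than the real SD.Next handler):

```python
# Minimal sketch of the Restart/Shutdown button wiring, with a placeholder callback.
import gradio as gr


def restart_server(restart=True):
    # Placeholder for shared.restart_server(); the real one flags the uvicorn server to exit.
    print(f"server exit requested, restart={restart}")


with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        restart_submit = gr.Button(value="Restart server", variant="primary", elem_id="restart_submit")
        shutdown_submit = gr.Button(value="Shutdown server", variant="primary", elem_id="shutdown_submit")
    # _js names a JavaScript function already loaded on the page (restart_reload in javascript/ui.js).
    restart_submit.click(fn=lambda: restart_server(restart=True), _js="restart_reload")
    shutdown_submit.click(fn=lambda: restart_server(restart=False), _js="restart_reload")

demo.launch()
```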