Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge dev branch #5011

Merged
merged 8 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions extensions/coqui_tts/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,13 @@
from pathlib import Path

import gradio as gr
from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer

from modules import chat, shared, ui_chat
from modules.logging_colors import logger
from modules.ui import create_refresh_button
from modules.utils import gradio

try:
from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer
except ModuleNotFoundError:
logger.error(
"Could not find the TTS module. Make sure to install the requirements for the coqui_tts extension."
"\n"
"\nLinux / Mac:\npip install -r extensions/coqui_tts/requirements.txt\n"
"\nWindows:\npip install -r extensions\\coqui_tts\\requirements.txt\n"
"\n"
"If you used the one-click installer, paste the command above in the terminal window launched after running the \"cmd_\" script. On Windows, that's \"cmd_windows.bat\"."
)

raise

os.environ["COQUI_TOS_AGREED"] = "1"

params = {
Expand Down
2 changes: 2 additions & 0 deletions extensions/openai/script.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
import logging
import os
import traceback
from threading import Thread
Expand Down Expand Up @@ -367,6 +368,7 @@ def on_start(public_url: str):
if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')

logging.getLogger("uvicorn.error").propagate = False
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)


Expand Down
2 changes: 1 addition & 1 deletion modules/GPTQ_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def load_quantized(model_name):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
pt_path = find_quantized_model_file(model_name)
if not pt_path:
logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.")
exit()
else:
logger.info(f"Found the following quantized model: {pt_path}")
Expand Down
2 changes: 1 addition & 1 deletion modules/LoRA.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def add_lora_transformers(lora_names):

# Add a LoRA when another LoRA is already present
if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
logger.info(f"Adding the LoRA(s) named {added_set} to the model")
for lora in added_set:
shared.model.load_adapter(get_lora_path(lora), lora)

Expand Down
3 changes: 2 additions & 1 deletion modules/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
else:
renderer = chat_renderer
if state['context'].strip() != '':
messages.append({"role": "system", "content": state['context']})
context = replace_character_names(state['context'], state['name1'], state['name2'])
messages.append({"role": "system", "content": context})

insert_pos = len(messages)
for user_msg, assistant_msg in reversed(history):
Expand Down
9 changes: 7 additions & 2 deletions modules/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,14 @@ def load_extensions():
for i, name in enumerate(shared.args.extensions):
if name in available_extensions:
if name != 'api':
logger.info(f'Loading the extension "{name}"...')
logger.info(f'Loading the extension "{name}"')
try:
exec(f"import extensions.{name}.script")
try:
exec(f"import extensions.{name}.script")
except ModuleNotFoundError:
logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n\nIf you used the one-click installer, paste the command above in the terminal window opened after launching the cmd script for your OS.")
raise

extension = getattr(extensions, name).script
apply_settings(extension, name)
if extension not in setup_called and hasattr(extension, "setup"):
Expand Down
176 changes: 63 additions & 113 deletions modules/logging_colors.py
Original file line number Diff line number Diff line change
@@ -1,117 +1,67 @@
# Copied from https://stackoverflow.com/a/1336640

import logging
import platform

logging.basicConfig(
format='%(asctime)s %(levelname)s:%(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)


def add_coloring_to_emit_windows(fn):
# add methods we need to the class
def _out_handle(self):
import ctypes
return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
out_handle = property(_out_handle)

def _set_color(self, code):
import ctypes

# Constants from the Windows API
self.STD_OUTPUT_HANDLE = -11
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)

setattr(logging.StreamHandler, '_set_color', _set_color)

def new(*args):
FOREGROUND_BLUE = 0x0001 # text color contains blue.
FOREGROUND_GREEN = 0x0002 # text color contains green.
FOREGROUND_RED = 0x0004 # text color contains red.
FOREGROUND_INTENSITY = 0x0008 # text color is intensified.
FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
# winbase.h
# STD_INPUT_HANDLE = -10
# STD_OUTPUT_HANDLE = -11
# STD_ERROR_HANDLE = -12

# wincon.h
# FOREGROUND_BLACK = 0x0000
FOREGROUND_BLUE = 0x0001
FOREGROUND_GREEN = 0x0002
# FOREGROUND_CYAN = 0x0003
FOREGROUND_RED = 0x0004
FOREGROUND_MAGENTA = 0x0005
FOREGROUND_YELLOW = 0x0006
# FOREGROUND_GREY = 0x0007
FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.

# BACKGROUND_BLACK = 0x0000
# BACKGROUND_BLUE = 0x0010
# BACKGROUND_GREEN = 0x0020
# BACKGROUND_CYAN = 0x0030
# BACKGROUND_RED = 0x0040
# BACKGROUND_MAGENTA = 0x0050
BACKGROUND_YELLOW = 0x0060
# BACKGROUND_GREY = 0x0070
BACKGROUND_INTENSITY = 0x0080 # background color is intensified.

levelno = args[1].levelno
if (levelno >= 50):
color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
elif (levelno >= 40):
color = FOREGROUND_RED | FOREGROUND_INTENSITY
elif (levelno >= 30):
color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
elif (levelno >= 20):
color = FOREGROUND_GREEN
elif (levelno >= 10):
color = FOREGROUND_MAGENTA
else:
color = FOREGROUND_WHITE
args[0]._set_color(color)

ret = fn(*args)
args[0]._set_color(FOREGROUND_WHITE)
# print "after"
return ret
return new


def add_coloring_to_emit_ansi(fn):
# add methods we need to the class
def new(*args):
levelno = args[1].levelno
if (levelno >= 50):
color = '\x1b[31m' # red
elif (levelno >= 40):
color = '\x1b[31m' # red
elif (levelno >= 30):
color = '\x1b[33m' # yellow
elif (levelno >= 20):
color = '\x1b[32m' # green
elif (levelno >= 10):
color = '\x1b[35m' # pink
else:
color = '\x1b[0m' # normal
args[1].msg = color + args[1].msg + '\x1b[0m' # normal
# print "after"
return fn(*args)
return new

logger = logging.getLogger('text-generation-webui')

if platform.system() == 'Windows':
# Windows does not support ANSI escapes and we are using API calls to set the console color
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
else:
# all non-Windows platforms are supporting ANSI escapes so we use them
logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
# log = logging.getLogger()
# log.addFilter(log_filter())
# //hdlr = logging.StreamHandler()
# //hdlr.setFormatter(formatter())

logger = logging.getLogger('text-generation-webui')
logger.setLevel(logging.DEBUG)
def setup_logging():
'''
Copied from: https://github.com/vladmandic/automatic

All credits to vladmandic.
'''

class RingBuffer(logging.StreamHandler):
def __init__(self, capacity):
super().__init__()
self.capacity = capacity
self.buffer = []
self.formatter = logging.Formatter('{ "asctime":"%(asctime)s", "created":%(created)f, "facility":"%(name)s", "pid":%(process)d, "tid":%(thread)d, "level":"%(levelname)s", "module":"%(module)s", "func":"%(funcName)s", "msg":"%(message)s" }')

def emit(self, record):
msg = self.format(record)
# self.buffer.append(json.loads(msg))
self.buffer.append(msg)
if len(self.buffer) > self.capacity:
self.buffer.pop(0)

def get(self):
return self.buffer

from rich.console import Console
from rich.logging import RichHandler
from rich.pretty import install as pretty_install
from rich.theme import Theme
from rich.traceback import install as traceback_install

level = logging.DEBUG
logger.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd`
console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
"traceback.border": "black",
"traceback.border.syntax_error": "black",
"inspect.value.border": "black",
}))
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
pretty_install(console=console)
traceback_install(console=console, extra_lines=1, max_frames=10, width=console.width, word_wrap=False, indent_guides=False, suppress=[])
while logger.hasHandlers() and len(logger.handlers) > 0:
logger.removeHandler(logger.handlers[0])

# handlers
rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=level, console=console)
rh.setLevel(level)
logger.addHandler(rh)

rb = RingBuffer(100) # 100 entries default in log ring buffer
rb.setLevel(level)
logger.addHandler(rb)
logger.buffer = rb.buffer

# overrides
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("diffusers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("lycoris").handlers = logger.handlers


setup_logging()
10 changes: 7 additions & 3 deletions modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@


def load_model(model_name, loader=None):
logger.info(f"Loading {model_name}...")
logger.info(f"Loading {model_name}")
t0 = time.time()

shared.is_seq2seq = False
Expand Down Expand Up @@ -413,8 +413,12 @@ def ExLlamav2_HF_loader(model_name):


def HQQ_loader(model_name):
from hqq.engine.hf import HQQModelForCausalLM
from hqq.core.quantize import HQQLinear, HQQBackend
try:
from hqq.core.quantize import HQQBackend, HQQLinear
from hqq.engine.hf import HQQModelForCausalLM
except ModuleNotFoundError:
logger.error("HQQ is not installed. You can install it with:\n\npip install hqq")
return None

logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}")

Expand Down
34 changes: 19 additions & 15 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,22 +204,26 @@
if hasattr(args, arg):
provided_arguments.append(arg)

# Deprecation warnings
deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast']
for k in deprecated_args:
if getattr(args, k):
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')

# Security warnings
if args.trust_remote_code:
logger.warning('trust_remote_code is enabled. This is dangerous.')
if 'COLAB_GPU' not in os.environ and not args.nowebui:
if args.share:
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
if args.multi_user:
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')


def do_cmd_flags_warnings():

# Deprecation warnings
for k in deprecated_args:
if getattr(args, k):
logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')

# Security warnings
if args.trust_remote_code:
logger.warning('trust_remote_code is enabled. This is dangerous.')
if 'COLAB_GPU' not in os.environ and not args.nowebui:
if args.share:
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
if args.multi_user:
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')


def fix_loader_name(name):
Expand Down
Loading