Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for multiple separate style*.csv files #14081

Closed
wants to merge 45 commits into from
Closed
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
85c0dd6
change default data locations
Jul 14, 2023
9f1f1ee
Merge branch 'master' of https://github.com/AUTOMATIC1111/stable-diff…
Jul 25, 2023
4921c36
allow controlnet extension
Jul 25, 2023
4a357cd
pull from upstream master of this fork
Jul 25, 2023
df97caa
HF_DATASETS_CACHE
Jul 27, 2023
873edba
Merge branch 'master' of https://github.com/AUTOMATIC1111/stable-diff…
Jul 27, 2023
24bc2a4
updated webui launch parameters
Jul 31, 2023
1fa6942
remove
Jul 31, 2023
86e9335
commandline arguments
Aug 24, 2023
42bea28
Merge branch 'master' of https://github.com/AUTOMATIC1111/stable-diff…
Aug 24, 2023
cd69688
Merge branch 'master' of https://github.com/AUTOMATIC1111/stable-diff…
Sep 1, 2023
3c9951e
ignore version number files
Sep 14, 2023
ba60f53
restore original
Oct 19, 2023
eed0cdc
ignore python defaults and "local" files
Oct 19, 2023
540f7f4
enable wildcard matching for styles CSV files
Oct 19, 2023
4bf0239
Merge branch 'master' of https://github.com/AUTOMATIC1111/stable-diff…
Nov 4, 2023
7f3d3ca
insert per-file dividers in styles list
Nov 6, 2023
1e80fb6
use list of regexes to rewrite styled prompts
Nov 7, 2023
f292a15
Merge branch 'dev' of https://github.com/AUTOMATIC1111/stable-diffusi…
Nov 7, 2023
05d7e40
ruff linting
Nov 7, 2023
503a5e3
Merge branch 'dev' into multiple-style-files
Nov 7, 2023
39b707b
Merge pull request #1 from cjj1977/multiple-style-files
Nov 7, 2023
9bcb3ec
ignore "do_not_save" styles when saving
Nov 7, 2023
362be28
Merge branch 'multiple-style-files' of https://github.com/cjj1977/sta…
Nov 7, 2023
1d92909
Merge pull request #2 from cjj1977/multiple-style-files
Nov 7, 2023
a3e0b65
clean_text function
Nov 8, 2023
799259c
Use clean_text()
Nov 8, 2023
65dc026
more visible divider in style list
Nov 8, 2023
6d73ea4
improve divider
Nov 8, 2023
8d7a46e
Merge branch 'dev' of https://github.com/AUTOMATIC1111/stable-diffusi…
Nov 23, 2023
cbfd194
Merge branch 'dev' of https://github.com/AUTOMATIC1111/stable-diffusi…
Nov 23, 2023
961f977
Merge branch 'dev' into multiple-style-files
Nov 23, 2023
acb34d1
Merge pull request #3 from cjj1977/multiple-style-files
Nov 23, 2023
4aeeb7c
ignore *venv/
Nov 23, 2023
19d8653
make minimal changes
Nov 23, 2023
70f2d31
make minimal changes
Nov 23, 2023
c95e8b4
minimise changes
Nov 23, 2023
ecff022
reinstate file
Nov 23, 2023
05b8d79
Merge branch 'dev' of https://github.com/AUTOMATIC1111/stable-diffusi…
Nov 27, 2023
55ddf94
sync gitignore from dev
Nov 27, 2023
0801f75
fix for Issue 14005
Nov 27, 2023
28a4f26
Merge pull request #4 from cjj1977/multiple-style-files
Nov 27, 2023
6a6483e
Merge branch 'dev' of https://github.com/cjj1977/stable-diffusion-web…
Nov 27, 2023
00f2c70
move local excludes to .git/info/exclude
Nov 27, 2023
70223f2
reset to match master
Nov 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 171 additions & 32 deletions modules/styles.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import fnmatch
import os
import os.path
import re
Expand All @@ -10,6 +11,23 @@ class PromptStyle(typing.NamedTuple):
name: str
prompt: str
negative_prompt: str
path: str = None


def clean_text(text: str) -> str:
"""
Iterating through a list of regular expressions and replacement strings, we
clean up the prompt and style text to make it easier to match against each
other.
"""
re_list = [
("multiple commas", re.compile("(,+\s+)+,?"), ", "),
("multiple spaces", re.compile("\s{2,}"), " "),
]
for _, regex, replace in re_list:
text = regex.sub(replace, text)

return text.strip(", ")


def merge_prompts(style_prompt: str, prompt: str) -> str:
Expand All @@ -26,41 +44,64 @@ def apply_styles_to_prompt(prompt, styles):
for style in styles:
prompt = merge_prompts(style, prompt)

return prompt
return clean_text(prompt)


re_spaces = re.compile(" +")
def unwrap_style_text_from_prompt(style_text, prompt):
"""
Checks the prompt to see if the style text is wrapped around it. If so,
returns True plus the prompt text without the style text. Otherwise, returns
False with the original prompt.


def extract_style_text_from_prompt(style_text, prompt):
stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
Note that the "cleaned" version of the style text is only used for matching
purposes here. It isn't returned; the original style text is not modified.
"""
stripped_prompt = clean_text(prompt)
stripped_style_text = clean_text(style_text)
if "{prompt}" in stripped_style_text:
left, right = stripped_style_text.split("{prompt}", 2)
# Work out whether the prompt is wrapped in the style text. If so, we
# return True and the "inner" prompt text that isn't part of the style.
try:
left, right = stripped_style_text.split("{prompt}", 2)
except ValueError as e:
# If the style text has multple "{prompt}"s, we can't split it into
# two parts. This is an error, but we can't do anything about it.
print("Unable to compare style text to prompt:`n{style_text}")
print(f"Error: {e}")
return False, prompt
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
return True, prompt
else:
# Work out whether the given prompt ends with the style text. If so, we
# return True and the prompt text up to where the style text starts.
if stripped_prompt.endswith(stripped_style_text):
prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]

if prompt.endswith(', '):
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
if prompt.endswith(", "):
prompt = prompt[:-2]

return True, prompt

return False, prompt


def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
"""
Takes a style and compares it to the prompt and negative prompt. If the style
matches, returns True plus the prompt and negative prompt with the style text
removed. Otherwise, returns False with the original prompt and negative prompt.
"""
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt

match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
match_positive, extracted_positive = unwrap_style_text_from_prompt(
style.prompt, prompt
)
if not match_positive:
return False, prompt, negative_prompt

match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
match_negative, extracted_negative = unwrap_style_text_from_prompt(
style.negative_prompt, negative_prompt
)
if not match_negative:
return False, prompt, negative_prompt

Expand All @@ -69,25 +110,88 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):

class StyleDatabase:
def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "")
self.no_style = PromptStyle("None", "", "", None)
self.styles = {}
self.path = path

folder, file = os.path.split(self.path)
self.default_file = file.split("*")[0] + ".csv"
if self.default_file == ".csv":
self.default_file = "styles.csv"
self.default_path = os.path.join(folder, self.default_file)

self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

self.reload()

def reload(self):
"""
Clears the style database and reloads the styles from the CSV file(s)
matching the path used to initialize the database.
"""
self.styles.clear()

if not os.path.exists(self.path):
path, filename = os.path.split(self.path)

if "*" in filename:
fileglob = filename.split("*")[0] + "*.csv"
filelist = []
for file in os.listdir(path):
if fnmatch.fnmatch(file, fileglob):
filelist.append(file)
# Add a visible divider to the style list
half_len = round(len(file) / 2)
divider = f"{'-' * (20 - half_len)} {file.upper()}"
divider = f"{divider} {'-' * (40 - len(divider))}"
self.styles[divider] = PromptStyle(
f"{divider}", None, None, "do_not_save"
)
# Add styles from this CSV file
self.load_from_csv(os.path.join(path, file))
if len(filelist) == 0:
print(f"No styles found in {path} matching {fileglob}")
return
elif not os.path.exists(self.path):
print(f"Style database not found: {self.path}")
return
else:
self.load_from_csv(self.path)

with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
def load_from_csv(self, path: str):
with open(path, "r", encoding="utf-8-sig", newline="") as file:
reader = csv.DictReader(file, skipinitialspace=True)
for row in reader:
# Ignore empty rows or rows starting with a comment
if not row or row["name"].startswith("#"):
continue
# Support loading old CSV format with "name, text"-columns
prompt = row["prompt"] if "prompt" in row else row["text"]
negative_prompt = row.get("negative_prompt", "")
self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
# Add style to database
self.styles[row["name"]] = PromptStyle(
row["name"], prompt, negative_prompt, path
)

def get_style_paths(self) -> list():
"""
Returns a list of all distinct paths, including the default path, of
files that styles are loaded from."""
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
self.styles[style.name] = style._replace(path=self.default_path)

# Create a list of all distinct paths, including the default path
style_paths = set()
style_paths.add(self.default_path)
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)

# Remove any paths for styles that are just list dividers
style_paths.remove("do_not_save")

return list(style_paths)

def get_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).prompt for x in styles]
Expand All @@ -96,20 +200,53 @@ def get_negative_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

def apply_styles_to_prompt(self, prompt, styles):
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
return apply_styles_to_prompt(
prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
)

def apply_negative_styles_to_prompt(self, prompt, styles):
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])

def save_styles(self, path: str) -> None:
# Always keep a backup file around
if os.path.exists(path):
shutil.copy(path, f"{path}.bak")

with open(path, "w", encoding="utf-8-sig", newline='') as file:
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader()
writer.writerows(style._asdict() for k, style in self.styles.items())
return apply_styles_to_prompt(
prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
)

def save_styles(self, path: str = None) -> None:
# The path argument is deprecated, but kept for backwards compatibility
_ = path

# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
self.styles[style.name] = style._replace(path=self.default_path)

# Create a list of all distinct paths, including the default path
style_paths = set()
style_paths.add(self.default_path)
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)

# Remove any paths for styles that are just list dividers
style_paths.remove("do_not_save")

csv_names = [os.path.split(path)[1].lower() for path in style_paths]

for style_path in style_paths:
# Always keep a backup file around
if os.path.exists(style_path):
shutil.copy(style_path, f"{style_path}.bak")

# Write the styles to the CSV file
with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
writer.writeheader()
for style in (s for s in self.styles.values() if s.path == style_path):
# Skip style list dividers, e.g. "STYLES.CSV"
if style.name.lower().strip("# ") in csv_names:
continue
# Write style fields, ignoring the path field
writer.writerow(
{k: v for k, v in style._asdict().items() if k != "path"}
)

def extract_styles_from_prompt(self, prompt, negative_prompt):
extracted = []
Expand All @@ -120,7 +257,9 @@ def extract_styles_from_prompt(self, prompt, negative_prompt):
found_style = None

for style in applicable_styles:
is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
is_match, new_prompt, new_neg_prompt = extract_original_prompts(
style, prompt, negative_prompt
)
if is_match:
found_style = style
prompt = new_prompt
Expand Down