Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: simplify and enhance prompt weight splitting #258

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 32 additions & 51 deletions ldm/simplet2i.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,22 +487,19 @@ def _get_uc_and_c(self, prompt, skip_normalize):

uc = self.model.get_learned_conditioning([''])

# weighted sub-prompts
subprompts, weights = T2I._split_weighted_subprompts(prompt)
if len(subprompts) > 1:
# get weighted sub-prompts
weighted_subprompts = T2I._split_weighted_subprompts(prompt, skip_normalize)

if len(weighted_subprompts) > 1:
# i dont know if this is correct.. but it works
c = torch.zeros_like(uc)
# get total weight for normalizing
totalWeight = sum(weights)
# normalize each "sub prompt" and add it
for i in range(0, len(subprompts)):
weight = weights[i]
if not skip_normalize:
weight = weight / totalWeight
self._log_tokenization(subprompts[i])
for i in range(0, len(weighted_subprompts)):
subprompt, weight = weighted_subprompts[i]
self._log_tokenization(subprompt)
c = torch.add(
c,
self.model.get_learned_conditioning([subprompts[i]]),
self.model.get_learned_conditioning([subprompt]),
alpha=weight,
)
else: # just standard 1 prompt
Expand Down Expand Up @@ -616,52 +613,36 @@ def _load_img(self, path, width, height):
image = torch.from_numpy(image)
return 2.0 * image - 1.0

def _split_weighted_subprompts(text):
def _split_weighted_subprompts(text, skip_normalize=False):
"""
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
remaining = len(text)
prompts = []
weights = []
while remaining > 0:
if ':' in text:
idx = text.index(':') # first occurrence from start
# grab up to index as sub-prompt
prompt = text[:idx]
remaining -= idx
# remove from main text
text = text[idx + 1 :]
# find value for weight
if ' ' in text:
idx = text.index(' ') # first occurence
else: # no space, read to end
idx = len(text)
if idx != 0:
try:
weight = float(text[:idx])
except: # couldn't treat as float
print(
f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
)
weight = 1.0
else: # no value found
weight = 1.0
# remove from main text
remaining -= idx
text = text[idx + 1 :]
# append the sub-prompt and its weight
prompts.append(prompt)
weights.append(weight)
else: # no : found
if len(text) > 0: # there is still text though
# take remainder as weight 1
prompts.append(text)
weights.append(1.0)
remaining = 0
return prompts, weights
prompt_parser = re.compile("""
(?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
) # end non-capture group
""", re.VERBOSE)
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print("Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
equal_weight = 1 / len(parsed_prompts)
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]

# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
Expand Down