
Commit

fix
fitz authored and fitz committed Nov 30, 2022
1 parent 9bb32eb commit 115e4e0
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions stable_diffusion_tf/stable_diffusion.py
@@ -1,3 +1,5 @@
+import logging
+
 import numpy as np
 from tqdm import tqdm
 import math
@@ -58,16 +60,19 @@ def generate(
         progress_call_back=None
     ):
         # 1) Tokenize the conditioning text (prompt); it must be shorter than 77 tokens
-        inputs: List[int] = self.tokenizer.encode(prompt)  # a list of ints, starting with 49406 and ending with 49407
-        assert len(inputs) < MAX_TEXT_LEN, "Prompt is too long (should be < 77 tokens)"
-        phrase = inputs + [49407] * (MAX_TEXT_LEN - len(inputs))  # pad to MAX_TEXT_LEN
-        phrase = np.array(phrase)[None].astype("int32")  # shape=(1, MAX_TEXT_LEN)
-        phrase = np.repeat(phrase, batch_size, axis=0)
-
-        # Encode prompt tokens (and their positions) into a "context vector"
         pos_ids = np.array(list(range(MAX_TEXT_LEN)))[None].astype("int32")  # shape=(1, MAX_TEXT_LEN)
         pos_ids = np.repeat(pos_ids, batch_size, axis=0)
-        context = self.text_encoder.predict_on_batch([phrase, pos_ids])
+        if prompt:
+            inputs: List[int] = self.tokenizer.encode(prompt)  # a list of ints, starting with 49406 and ending with 49407
+            assert len(inputs) < MAX_TEXT_LEN, "Prompt is too long (should be < 77 tokens)"
+            phrase = inputs + [49407] * (MAX_TEXT_LEN - len(inputs))  # pad to MAX_TEXT_LEN
+            phrase = np.array(phrase)[None].astype("int32")  # shape=(1, MAX_TEXT_LEN)
+            phrase = np.repeat(phrase, batch_size, axis=0)
+            # Encode prompt tokens (and their positions) into a "context vector"
+            context = self.text_encoder.predict_on_batch([phrase, pos_ids])
+        else:
+            logging.info('prompt is empty!')
+            context = None

         input_image_tensor, input_image_array = None, None
         if input_image is not None:
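
For illustration, a minimal standalone sketch of the padding scheme above, assuming a CLIP-style tokenizer where 49406 is the start-of-text token and 49407 serves as both end-of-text and padding token (the sample token IDs other than those two are made up):

import numpy as np

MAX_TEXT_LEN = 77  # CLIP context length used by Stable Diffusion

def pad_tokens(token_ids, batch_size):
    # Pad the tokenized prompt with the end-of-text token and tile across the batch.
    assert len(token_ids) < MAX_TEXT_LEN, "Prompt is too long (should be < 77 tokens)"
    padded = token_ids + [49407] * (MAX_TEXT_LEN - len(token_ids))
    padded = np.array(padded)[None].astype("int32")  # shape (1, 77)
    return np.repeat(padded, batch_size, axis=0)     # shape (batch_size, 77)

batch = pad_tokens([49406, 320, 1125, 49407], batch_size=2)  # token IDs illustrative
print(batch.shape)  # (2, 77)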
@@ -212,11 +217,14 @@ def get_model_output(
         unconditional_latent = self.diffusion_model.predict_on_batch(
             [latent, t_emb, unconditional_context]
         )
-        latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
-        return unconditional_latent + unconditional_guidance_scale * (
-            latent - unconditional_latent
-        )
+        if context is not None:  # `if context:` is ambiguous for a numpy array
+            latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
+            return unconditional_latent + unconditional_guidance_scale * (
+                latent - unconditional_latent
+            )
+        else:
+            return unconditional_latent

     def get_x_prev_and_pred_x0(self, x, e_t, index, a_t, a_prev, temperature, seed):
         sigma_t = 0
         sqrt_one_minus_at = math.sqrt(1 - a_t)
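
The return value in get_model_output is classifier-free guidance: the diffusion model is evaluated once with the unconditional context and once with the prompt context, and the prediction is pushed from the unconditional output toward the conditional one. A minimal sketch of that blend, with e_uncond and e_cond standing in for the two predict_on_batch outputs:

def guided_prediction(e_uncond, e_cond, unconditional_guidance_scale):
    # scale = 1 reproduces the conditional prediction; scale > 1 exaggerates
    # the direction (e_cond - e_uncond), i.e. the influence of the prompt.
    return e_uncond + unconditional_guidance_scale * (e_cond - e_uncond)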
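
The truncated hunk above shows the start of get_x_prev_and_pred_x0 with sigma_t = 0, i.e. a deterministic DDIM update. A sketch of the standard DDIM step this presumably implements (an assumption, since the function body is cut off here):

import math

def ddim_step(x, e_t, a_t, a_prev):
    # Deterministic DDIM update (sigma_t = 0): recover the predicted clean
    # latent from the noise estimate e_t, then re-noise it to level a_prev.
    pred_x0 = (x - math.sqrt(1 - a_t) * e_t) / math.sqrt(a_t)
    dir_xt = math.sqrt(1 - a_prev) * e_t
    return math.sqrt(a_prev) * pred_x0 + dir_xt, pred_x0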
