From 115e4e02ba8a57c1bcd88c24cc2d5a68b1fc6241 Mon Sep 17 00:00:00 2001
From: fitz
Date: Wed, 30 Nov 2022 23:55:17 +0800
Subject: [PATCH] fix: handle empty prompt in generate() and get_model_output()

---
 stable_diffusion_tf/stable_diffusion.py | 34 +++++++++++++++----------
 1 file changed, 21 insertions(+), 13 deletions(-)

diff --git a/stable_diffusion_tf/stable_diffusion.py b/stable_diffusion_tf/stable_diffusion.py
index 2fae8a1..a9a34d3 100644
--- a/stable_diffusion_tf/stable_diffusion.py
+++ b/stable_diffusion_tf/stable_diffusion.py
@@ -1,3 +1,5 @@
+import logging
+
 import numpy as np
 from tqdm import tqdm
 import math
@@ -58,16 +60,19 @@ def generate(
         progress_call_back=None
     ):
         # 1) Tokenize the conditioning text (prompt); it must be shorter than 77 tokens
-        inputs: List[int] = self.tokenizer.encode(prompt)  # list of ints, starting with 49406 (BOS) and ending with 49407 (EOS)
-        assert len(inputs) < MAX_TEXT_LEN, "Prompt is too long (should be < 77 tokens)"
-        phrase = inputs + [49407] * (MAX_TEXT_LEN - len(inputs))  # pad to MAX_TEXT_LEN
-        phrase = np.array(phrase)[None].astype("int32")  # shape=(1, MAX_TEXT_LEN)
-        phrase = np.repeat(phrase, batch_size, axis=0)
-
-        # Encode prompt tokens (and their positions) into a "context vector"
         pos_ids = np.array(list(range(MAX_TEXT_LEN)))[None].astype("int32")  # shape=(1, MAX_TEXT_LEN)
         pos_ids = np.repeat(pos_ids, batch_size, axis=0)
-        context = self.text_encoder.predict_on_batch([phrase, pos_ids])
+        if prompt:
+            inputs: List[int] = self.tokenizer.encode(prompt)  # list of ints, starting with 49406 (BOS) and ending with 49407 (EOS)
+            assert len(inputs) < MAX_TEXT_LEN, "Prompt is too long (should be < 77 tokens)"
+            phrase = inputs + [49407] * (MAX_TEXT_LEN - len(inputs))  # pad to MAX_TEXT_LEN
+            phrase = np.array(phrase)[None].astype("int32")  # shape=(1, MAX_TEXT_LEN)
+            phrase = np.repeat(phrase, batch_size, axis=0)
+            # Encode prompt tokens (and their positions) into a "context vector"
+            context = self.text_encoder.predict_on_batch([phrase, pos_ids])
+        else:
+            logging.info("prompt is empty; skipping text encoding")
+            context = None
 
         input_image_tensor, input_image_array = None, None
         if input_image is not None:
@@ -212,11 +217,14 @@ def get_model_output(
         unconditional_latent = self.diffusion_model.predict_on_batch(
             [latent, t_emb, unconditional_context]
         )
-        latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
-        return unconditional_latent + unconditional_guidance_scale * (
-            latent - unconditional_latent
-        )
-
+        if context is not None:
+            latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
+            return unconditional_latent + unconditional_guidance_scale * (
+                latent - unconditional_latent
+            )
+        else:
+            return unconditional_latent
+
     def get_x_prev_and_pred_x0(self, x, e_t, index, a_t, a_prev, temperature, seed):
         sigma_t = 0
         sqrt_one_minus_at = math.sqrt(1 - a_t)
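
Note (commentary, not part of the diff): the guard in get_model_output() has
to compare `context` with None explicitly. Whenever a prompt is given,
`context` is a NumPy array returned by predict_on_batch(), and truth-testing
a multi-element array raises a ValueError, so a bare `if context:` would
crash on every conditional step. A minimal sketch of the pitfall (the array
shape is an assumption for illustration):

    import numpy as np

    context = np.zeros((1, 77, 768))  # assumed shape of a CLIP context batch
    try:
        if context:  # truth-testing a multi-element array is ambiguous
            pass
    except ValueError as err:
        print(err)  # "The truth value of an array with more than one element..."

    if context is not None:  # unambiguous: only detects the empty-prompt sentinel
        print("conditional branch taken")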
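
Usage sketch (hypothetical: the class name, constructor arguments, and
generate() keywords below are assumptions about this fork's API, not
something the diff confirms):

    from stable_diffusion_tf.stable_diffusion import StableDiffusion

    generator = StableDiffusion(img_height=512, img_width=512)  # assumed constructor

    # Non-empty prompt: the text encoder builds a context vector and
    # get_model_output() blends conditional and unconditional latents
    # via classifier-free guidance.
    images = generator.generate("a watercolor lighthouse at dusk", num_steps=25)

    # Empty prompt: context is set to None, an info message is logged, and
    # the sampler returns the purely unconditional prediction.
    images = generator.generate("", num_steps=25)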