Skip to content

Commit

Permalink
Merge branch 'divamgupta:master' into costa
Browse files Browse the repository at this point in the history
  • Loading branch information
costiash authored Nov 11, 2022
2 parents 10a6354 + 9d92d82 commit 50b592f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
7 changes: 7 additions & 0 deletions img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
help="the prompt to render",
)

parser.add_argument(
"--negative-prompt",
type=str,
help="the negative prompt to use (if any)",
)

parser.add_argument(
"--steps",
type=int,
Expand Down Expand Up @@ -45,6 +51,7 @@

img = generator.generate(
args.prompt,
negative_prompt=args.negative_prompt,
num_steps=args.steps,
unconditional_guidance_scale=7.5,
temperature=1,
Expand Down
18 changes: 13 additions & 5 deletions stable_diffusion_tf/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, img_height=1000, img_width=1000, jit_compile=False, download_
def generate(
self,
prompt,
negative_prompt=None,
batch_size=1,
num_steps=25,
unconditional_guidance_scale=7.5,
Expand Down Expand Up @@ -85,13 +86,19 @@ def generate(
latent_mask_tensor = tf.cast(tf.repeat(latent_mask, batch_size , axis=0), self.dtype)


# Tokenize negative prompt or use default padding tokens
unconditional_tokens = _UNCONDITIONAL_TOKENS
if negative_prompt is not None:
inputs = self.tokenizer.encode(negative_prompt)
assert len(inputs) < 77, "Negative prompt is too long (should be < 77 tokens)"
unconditional_tokens = inputs + [49407] * (77 - len(inputs))

# Encode unconditional tokens (and their positions) into an
# "unconditional context vector"
unconditional_tokens = np.array(_UNCONDITIONAL_TOKENS)[None].astype("int32")
unconditional_tokens = np.array(unconditional_tokens)[None].astype("int32")
unconditional_tokens = np.repeat(unconditional_tokens, batch_size, axis=0)
self.unconditional_tokens = tf.convert_to_tensor(unconditional_tokens)
unconditional_context = self.text_encoder.predict_on_batch(
[self.unconditional_tokens, pos_ids]
[unconditional_tokens, pos_ids]
)
timesteps = np.arange(1, 1000, 1000 // num_steps)
input_img_noise_t = timesteps[ int(len(timesteps)*input_image_strength) ]
Expand Down Expand Up @@ -146,9 +153,10 @@ def timestep_embedding(self, timesteps, dim=320, max_period=10000):
embedding = np.concatenate([np.cos(args), np.sin(args)])
return tf.convert_to_tensor(embedding.reshape(1, -1),dtype=self.dtype)

def add_noise(self, x , t ):
def add_noise(self, x , t , noise=None ):
batch_size,w,h = x.shape[0] , x.shape[1] , x.shape[2]
noise = tf.random.normal((batch_size,w,h,4), dtype=self.dtype)
if noise is None:
noise = tf.random.normal((batch_size,w,h,4), dtype=self.dtype)
sqrt_alpha_prod = _ALPHAS_CUMPROD[t] ** 0.5
sqrt_one_minus_alpha_prod = (1 - _ALPHAS_CUMPROD[t]) ** 0.5

Expand Down
7 changes: 7 additions & 0 deletions text2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
help="the prompt to render",
)

parser.add_argument(
"--negative-prompt",
type=str,
help="the negative prompt to use (if any)",
)

parser.add_argument(
"--output",
type=str,
Expand Down Expand Up @@ -68,6 +74,7 @@
generator = StableDiffusion(img_height=args.H, img_width=args.W, jit_compile=False)
img = generator.generate(
args.prompt,
negative_prompt=args.negative_prompt,
num_steps=args.steps,
unconditional_guidance_scale=args.scale,
temperature=1,
Expand Down

0 comments on commit 50b592f

Please sign in to comment.