Skip to content

Commit

Permalink
Update stable_diffusion.py
Browse files Browse the repository at this point in the history
  • Loading branch information
costiash authored Oct 30, 2022
1 parent b5b0fa5 commit 41973c2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions stable_diffusion_tf/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def generate(
input_image = input_image.resize((self.img_width, self.img_height))

elif type(input_image) is np.ndarray:
input_image = np.resize(input_image, (self.img_width, self.img_height, input_image.shape[2]))
input_image = np.resize(input_image, (self.img_height, self.img_width, input_image.shape[2]))

input_image_array = np.array(input_image, dtype=np.float32)[None,...,:3]
input_image_tensor = tf.cast((input_image_array / 255.0) * 2 - 1, self.dtype)
Expand Down Expand Up @@ -200,7 +200,7 @@ def get_x_prev_and_pred_x0(self, x, e_t, index, a_t, a_prev, temperature, seed):

def load_weights_from_pytorch_ckpt(self , pytorch_ckpt_path):
import torch
pt_weights = torch.load(pytorch_ckpt_path)
pt_weights = torch.load(pytorch_ckpt_path, map_location="cpu")
for module_name in ['text_encoder', 'diffusion_model', 'decoder', 'encoder' ]:
module_weights = []
for i , (key , perm ) in enumerate(PYTORCH_CKPT_MAPPING[module_name]):
Expand Down

0 comments on commit 41973c2

Please sign in to comment.