-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_vanilla_pipeline.py
112 lines (93 loc) · 4.18 KB
/
run_vanilla_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import pprint
import os
from typing import List
import torch
from PIL import Image
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from utils.ptp_utils import *
from utils.magnet import *
from pytorch_lightning import seed_everything
import time
import argparse
def sanitize_filename_component(text, max_bytes=150):
    """Return *text* made safe to embed in a single file name.

    Prompts are used to build output file names. Path separators and other
    characters rejected by common filesystems (plus ASCII control characters)
    are replaced with ``_``. The result is truncated to at most *max_bytes*
    UTF-8 bytes so long prompts cannot exceed typical filename-length limits;
    a multi-byte character cut in half by the truncation is dropped.

    Args:
        text: Arbitrary prompt text.
        max_bytes: Maximum UTF-8 length of the returned string.

    Returns:
        The sanitized (possibly shortened) string.
    """
    unsafe = set('<>:"/\\|?*')
    cleaned = "".join("_" if (ch in unsafe or ord(ch) < 32) else ch for ch in text)
    truncated = cleaned.encode("utf-8", errors="ignore")[:max_bytes]
    return truncated.decode("utf-8", errors="ignore")


@torch.no_grad()
def main(opt):
    """Generate images for each prompt with Magnet and, optionally, vanilla SD.

    Args:
        opt: Parsed command-line namespace. Reads ``sd_path``, ``magnet_path``,
            ``prompts``, ``N``, ``K``, ``L``, ``cfg_scale``, ``ddim_steps`` and
            ``run_sd``.

    Images are written to ``outputs/`` as
    ``{prompt}_seed{seed}_{index}_{method}.png``. The prompt part is passed
    through ``sanitize_filename_component`` so it cannot break the path.
    Prompts for which Magnet fails are reported and skipped.
    """
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    pipe = StableDiffusionPipeline.from_pretrained(opt.sd_path).to(device)
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    # NOTE(review): `stanza` is not imported explicitly in this file; it is
    # presumably re-exported by `from utils.magnet import *` — confirm.
    constituency_parser = stanza.Pipeline(
        lang='en', processors='tokenize,pos,constituency', download_method=None
    )
    candidates, candidate_embs = prepare_candidates(offline_file=opt.magnet_path)
    candidates = candidates.to(device)
    candidate_embs = candidate_embs.to(device)

    output_path = "outputs"
    os.makedirs(output_path, exist_ok=True)

    default_prompts = [
        "a lone, green fire hydrant sits in red grass",
        "some blue bananas with little yellow stickers on them",
        "a bowl of broccoli and red rice with a white sauce"
    ]
    test_prompts = default_prompts if opt.prompts is None else opt.prompts
    num_images_per_prompt = opt.N
    method_list = ["sd", "magnet"] if opt.run_sd else ["magnet"]

    for bid, prompt in enumerate(test_prompts):
        # The seed depends only on the prompt index, so every method run for a
        # given prompt starts from the same seed and the outputs are comparable.
        cur_seed = 14273 + bid * 55
        file_prompt = sanitize_filename_component(prompt)
        for method in method_list:
            if method == "magnet":
                try:
                    magnet_embeds = get_magnet_direction(
                        constituency_parser,
                        tokenizer,
                        text_encoder,
                        prompt,
                        candidates,
                        candidate_embs,
                        alpha_lambda=opt.L, K=opt.K, neighbor="feature"
                    )
                except Exception as err:
                    # Catch ordinary errors only, so Ctrl-C still aborts the run.
                    print(f"Fail to apply Magnet at prompt: {prompt} ({err!r})")
                    continue
            else:
                magnet_embeds = None

            # Re-seed right before sampling so each method sees the same RNG state.
            seed_everything(int(cur_seed))
            # Vanilla SD encodes the raw prompt; Magnet supplies precomputed embeddings.
            if magnet_embeds is None:
                prompt_kwargs = {"prompt": prompt}
            else:
                prompt_kwargs = {"prompt_embeds": magnet_embeds}
            outputs = pipe(
                guidance_scale=opt.cfg_scale,
                guidance_rescale=0.,
                num_inference_steps=opt.ddim_steps,
                num_images_per_prompt=num_images_per_prompt,
                **prompt_kwargs
            ).images

            for i, image in enumerate(outputs[:num_images_per_prompt]):
                image.save(os.path.join(
                    output_path, f'{file_prompt}_seed{cur_seed}_{i}_{method}.png'
                ))
if __name__ == '__main__':
    # Command-line entry point: parse options, validate them, then run `main`.
    parser = argparse.ArgumentParser(
        description='Generate images with Magnet, optionally alongside vanilla Stable Diffusion.'
    )
    parser.add_argument('--sd_path', type=str, required=True,
                        help='Path to pre-trained Stable Diffusion.')
    parser.add_argument('--magnet_path', type=str,
                        help='Path to the local file of the candidate embedding to save time.'
                        )
    parser.add_argument('--prompts', nargs='+', default=None,
                        help='Prompt list to generate a batch of images.')
    # Magnet settings
    parser.add_argument('--K', type=int, default=5,
                        help='Hyperparameter of Magnet for the number of neighbor objects.')
    parser.add_argument('--L', type=float, default=0.6,
                        help='Hyperparameter of Magnet for adaptive strength.')
    # SD settings
    parser.add_argument('--cfg_scale', type=float, default=7.5,
                        help='Classifier-free guidance scale passed to the pipeline.')
    parser.add_argument('--ddim_steps', type=int, default=50,
                        help='Number of denoising (inference) steps passed to the pipeline.')
    parser.add_argument('--N', type=int, default=1,
                        help='Number of images for each prompt.')
    parser.add_argument('--run_sd', action='store_true',
                        help='Also generate images with vanilla Stable Diffusion for comparison.')
    args = parser.parse_args()
    # Reject values the pipeline cannot use before loading any model weights.
    if args.N < 1:
        parser.error('--N must be a positive integer.')
    if args.ddim_steps < 1:
        parser.error('--ddim_steps must be a positive integer.')
    main(args)