forked from lucidrains/DALLE-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
generate.py
130 lines (84 loc) · 3.93 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import argparse
from pathlib import Path
from tqdm import tqdm
# torch
import torch
from einops import repeat
# vision imports
from PIL import Image
from torchvision.utils import make_grid, save_image
# dalle related classes and utils
from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer
# argument parsing
parser = argparse.ArgumentParser()

# checkpoints: the DALL-E checkpoint is mandatory, the VQGAN files are optional
parser.add_argument(
    '--dalle_path',
    required=True,
    type=str,
    help='path to your trained DALL-E',
)
parser.add_argument(
    '--vqgan_model_path',
    default=None,
    type=str,
    help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)',
)
parser.add_argument(
    '--vqgan_config_path',
    default=None,
    type=str,
    help='path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)',
)

# the prompt is the only other mandatory option
parser.add_argument(
    '--text',
    required=True,
    type=str,
    help='your text prompt',
)

# sampling and output options, each with a default
parser.add_argument(
    '--num_images',
    default=128,
    type=int,
    help='number of images',
)
parser.add_argument(
    '--batch_size',
    default=4,
    type=int,
    help='batch size',
)
parser.add_argument(
    '--top_k',
    default=0.9,
    type=float,
    help='top k filter threshold',
)
parser.add_argument(
    '--outputs_dir',
    default='./outputs',
    type=str,
    help='output directory',
)
parser.add_argument(
    '--bpe_path',
    type=str,
    help='path to your huggingface BPE json file',
)

# boolean switches: each one is False unless it is given on the command line
for switch in ('hug', 'chinese', 'taming', 'gentxt'):
    parser.add_argument(f'--{switch}', dest=switch, action='store_true')

args = parser.parse_args()
# helper fns
def exists(val):
    """Return True for any value other than None (falsy values included)."""
    if val is None:
        return False
    return True
# tokenizer
# A BPE file takes priority: it is loaded with HugTokenizer when --hug is set
# and with YttmTokenizer otherwise. Without a BPE file, --chinese selects
# ChineseTokenizer. With neither, the `tokenizer` name imported at the top of
# the file is left untouched.
if args.bpe_path is not None:
    if args.hug:
        klass = HugTokenizer
    else:
        klass = YttmTokenizer
    tokenizer = klass(args.bpe_path)
elif args.chinese:
    tokenizer = ChineseTokenizer()
# load DALL-E
dalle_path = Path(args.dalle_path)
if not dalle_path.exists():
    # explicit raise instead of `assert`: assertions are stripped under
    # `python -O`, which would let a missing checkpoint reach torch.load
    raise FileNotFoundError('trained DALL-E must exist')

# NOTE(security): torch.load unpickles the file, so only point --dalle_path at
# checkpoints you trust.
# map_location='cpu' makes loading independent of the device the checkpoint
# was saved from; load_state_dict below copies the weights into the model,
# which has already been moved to the GPU by then.
load_obj = torch.load(str(dalle_path), map_location='cpu')

# the checkpoint must contain these three entries (KeyError otherwise)
dalle_params = load_obj.pop('hparams')
vae_params = load_obj.pop('vae_params')
weights = load_obj.pop('weights')

dalle_params.pop('vae', None) # cleanup later

# choose the VAE: --taming wins, then the VAE hyper-parameters stored in the
# checkpoint, and finally the OpenAI discrete VAE as the fallback
if args.taming:
    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
elif vae_params is not None:
    vae = DiscreteVAE(**vae_params)
else:
    vae = OpenAIDiscreteVAE()

dalle = DALLE(vae = vae, **dalle_params).cuda()
dalle.load_state_dict(weights)
# generate images
image_size = vae.image_size

# several prompts may be passed in one --text value, separated by '|'
texts = args.text.split('|')

# tqdm wraps the list itself (not enumerate(...)) so it can show the total
for text in tqdm(texts):
    if args.gentxt:
        # generate_texts returns the tokens together with the generated
        # text(s); the first generated text replaces the prompt from here on
        text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres = args.top_k)
        text = gen_texts[0]
    else:
        text_tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()

    # one row of prompt tokens per requested image
    text_tokens = repeat(text_tokens, '() n -> b n', b = args.num_images)

    # generate in chunks of at most --batch_size rows, then join the results
    outputs = []
    for text_chunk in tqdm(text_tokens.split(args.batch_size), desc = f'generating images for - {text}'):
        output = dalle.generate_images(text_chunk, filter_thres = args.top_k)
        outputs.append(output)
    outputs = torch.cat(outputs)

    # save all images
    # the directory is named after the prompt: spaces become underscores and
    # the name is cut to 100 characters
    file_name = text
    outputs_dir = Path(args.outputs_dir) / file_name.replace(' ', '_')[:(100)]
    outputs_dir.mkdir(parents = True, exist_ok = True)

    for i, image in enumerate(tqdm(outputs, desc = 'saving images')):
        save_image(image, outputs_dir / f'{i}.jpg', normalize=True)

    # the caption is written once per prompt; utf-8 is set explicitly so that
    # non-ASCII prompts do not depend on the platform's default encoding
    with open(outputs_dir / 'caption.txt', 'w', encoding = 'utf-8') as f:
        f.write(file_name)

    print(f'created {args.num_images} images at "{str(outputs_dir)}"')