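"""generate_image_pairs.py

Generate side-by-side image pairs with a pretrained BlendGAN generator: for each
latent code the script renders the plain sample and the same sample with a style
embedding blended in, then writes the two renders concatenated into one JPG per
pair under --outdir.

Example invocation (a sketch; the checkpoint, style image, and output paths are
placeholders, not files shipped with this script):

    python generate_image_pairs.py \
        --ckpt pretrained_models/blendgan.pt \
        --style_img assets/style.jpg \
        --outdir results/pairs --pics 20
"""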
import argparse
import os
import random

import cv2
import numpy as np
import torch
from tqdm import tqdm

from model import Generator
from utils import ten2cv, cv2ten

# Fix all random seeds so repeated runs produce the same latent codes and pairs.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def generate(args, g_ema, device, mean_latent, sample_style, add_weight_index):
    # Optionally reuse a fixed set of latent codes saved earlier with torch.save.
    if args.sample_zs is not None:
        sample_zs = torch.load(args.sample_zs)
    else:
        sample_zs = None

    with torch.no_grad():
        g_ema.eval()
        for i in tqdm(range(args.pics)):
            if sample_zs is not None:
                sample_z = sample_zs[i]
            else:
                sample_z = torch.randn(1, args.latent, device=device)

            # Render the same latent twice: once as a plain sample and once with
            # the style embedding applied (blend controlled by add_weight_index).
            sample1, _ = g_ema([sample_z],
                               truncation=args.truncation, truncation_latent=mean_latent,
                               return_latents=False, randomize_noise=False)
            sample2, _ = g_ema([sample_z], z_embed=sample_style, add_weight_index=add_weight_index,
                               truncation=args.truncation, truncation_latent=mean_latent,
                               return_latents=False, randomize_noise=False)

            sample1 = ten2cv(sample1)
            sample2 = ten2cv(sample2)
            # Save the pair side by side as a single image.
            out = np.concatenate([sample1, sample2], axis=1)
            cv2.imwrite(f'{args.outdir}/{str(i).zfill(6)}.jpg', out)


if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--pics', type=int, default=20, help='N_PICS')
    parser.add_argument('--truncation', type=float, default=0.75)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--ckpt', type=str, default='', help='path to BlendGAN checkpoint')
    parser.add_argument('--style_img', type=str, default=None, help='path to style image')
    parser.add_argument('--sample_zs', type=str, default=None, help='path to saved latent codes (optional)')
    parser.add_argument('--add_weight_index', type=int, default=6)
    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('--outdir', type=str, default='')
    args = parser.parse_args()

    outdir = args.outdir
    if not os.path.exists(outdir):
        os.makedirs(outdir, exist_ok=True)

    args.latent = 512
    args.n_mlp = 8

    # Load the BlendGAN checkpoint; reuse its stored latent average and
    # truncation value when they are present.
    checkpoint = torch.load(args.ckpt)
    model_dict = checkpoint['g_ema']
    if "latent_avg" in checkpoint.keys():
        latent_avg = checkpoint["latent_avg"]
    else:
        latent_avg = None
    if "truncation" in checkpoint.keys():
        args.truncation = checkpoint["truncation"]
    print('ckpt: ', args.ckpt)
    print('truncation: ', args.truncation)

    g_ema = Generator(
        args.size, args.latent, args.n_mlp,
        channel_multiplier=args.channel_multiplier, load_pretrained_vgg=False
    ).to(device)
    g_ema.load_state_dict(model_dict)

    # Truncation trick: prefer the mean latent stored in the checkpoint,
    # otherwise estimate it from truncation_mean random samples.
    if args.truncation < 1:
        if latent_avg is not None:
            mean_latent = latent_avg
            print('### use mean_latent in ckpt["latent_avg"]')
        else:
            with torch.no_grad():
                mean_latent = g_ema.mean_latent(args.truncation_mean)
            print("### generate mean_latent with 'g_ema.mean_latent'")
    else:
        mean_latent = None
        print('### args.truncation = 1, mean_latent is None')

    # Style embedding: extract it from the given style image, or draw a random
    # style code if no image is provided.
    if args.style_img is not None:
        img = cv2.imread(args.style_img, 1)
        img = cv2ten(img, device)
        sample_style = g_ema.get_z_embed(img)
    else:
        sample_style = torch.randn(1, args.latent, device=device)

    generate(args, g_ema, device, mean_latent, sample_style, args.add_weight_index)

    print('Done!')
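
# -----------------------------------------------------------------------------
# Note on the output format (an illustration, not part of the original script):
# each saved JPG holds the two renders concatenated along the width
# (np.concatenate(..., axis=1)), assuming ten2cv returns H x W x C arrays, so a
# pair can be split back into its halves when building a paired dataset:
#
#     pair = cv2.imread('results/pairs/000000.jpg')  # placeholder path
#     w = pair.shape[1] // 2
#     plain, stylized = pair[:, :w], pair[:, w:]
# -----------------------------------------------------------------------------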