-
Notifications
You must be signed in to change notification settings - Fork 12
/
generate_gif.py
118 lines (95 loc) · 5.18 KB
/
generate_gif.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Generate GIF using pretrained network pickle."""
import os
import click
import dnnlib
import numpy as np
from PIL import Image
import torch
import legacy
#----------------------------------------------------------------------------
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seed', help='Random seed', default=0, type=int)
@click.option('--num-rows', help='Number of rows', default=2, type=int)
@click.option('--num-cols', help='Number of columns', default=2, type=int)
@click.option('--resolution', help='Resolution of the output images', default=128, type=int)
@click.option('--num-phases', help='Number of phases', default=5, type=int)
@click.option('--transition-frames', help='Number of transition frames per phase', default=20, type=int)
@click.option('--static-frames', help='Number of static frames per phase', default=5, type=int)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--output', type=str, help='Output GIF filename', required=True)
# DFMGAN args
@click.option('--latent-mode', help='randomly sampled latent codes', type=click.Choice(['both', 'content', 'defect', 'nores', 'none']), default='both', show_default=True)
def generate_gif(
    network_pkl: str,
    seed: int,
    num_rows: int,
    num_cols: int,
    resolution: int,
    num_phases: int,
    transition_frames: int,
    static_frames: int,
    truncation_psi: float,
    noise_mode: str,
    output: str,
    latent_mode: str,
):
    """Generate gif using pretrained network pickle.

    Examples:

    \b
    python generate_gif.py --output=obama.gif --seed=0 --num-rows=1 --num-cols=8 \\
        --network=https://hanlab.mit.edu/projects/data-efficient-gans/models/DiffAugment-stylegan2-100-shot-obama.pkl
    """
    # Parameter notes (kept out of the docstring because click prints the
    # docstring verbatim as the --help text):
    #   network_pkl        path or URL handed to dnnlib.util.open_url.
    #   seed               seeds NumPy's global RNG used to draw the latents.
    #   num_rows/num_cols  grid layout; rows * cols images are rendered per frame.
    #   resolution         side length, in pixels, of each grid cell after resizing.
    #   num_phases         number of latent key-points; phase i blends from
    #                      key-point i-1 to key-point i (phase 0 starts at the
    #                      last one, so the animation loops).
    #   transition_frames  interpolated frames per phase.
    #   static_frames      repeats of the phase's end frame.
    #   noise_mode         forwarded to G.synthesis.
    #   output             GIF path; '.gif' is appended when missing.
    #   latent_mode        'defect'/'none' freeze the content latent,
    #                      'content'/'none' freeze the defect latent, and
    #                      'nores' passes fix_residual_to_zero=True to synthesis.
    #
    # NOTE(review): truncation_psi is accepted but never forwarded to the
    # network, so --trunc currently has no effect. Presumably it should be
    # passed to G.mapping / G.defect_mapping -- confirm their signatures.

    # Fail fast with a readable message instead of an IndexError on an empty
    # frame list after all the (expensive) GPU work has been done.
    if num_rows < 1 or num_cols < 1:
        raise click.UsageError('--num-rows and --num-cols must both be at least 1.')
    if num_phases < 1 or (transition_frames < 1 and static_frames < 1):
        raise click.UsageError(
            '--num-phases must be at least 1 and at least one of '
            '--transition-frames / --static-frames must be positive.')

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    # Generators without a `transfer` attribute are treated as plain
    # (non-transfer) networks rather than crashing with AttributeError.
    transfer = (getattr(G, 'transfer', 'none') != 'none')

    outdir = os.path.dirname(output)
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    np.random.seed(seed)
    output_seq = []
    batch_size = num_rows * num_cols
    latent_size = G.z_dim
    # Draw order matters for reproducibility: content latents first, then
    # (only for transfer networks) the defect latents.
    latents = [np.random.randn(batch_size, latent_size) for _ in range(num_phases)]
    if transfer:
        latents_defect = [np.random.randn(batch_size, latent_size) for _ in range(num_phases)]

    vary_content = latent_mode not in ('defect', 'none')
    vary_defect = latent_mode not in ('content', 'none')

    def to_image_grid(outputs):
        """Tile a batch of images into one num_rows x num_cols PIL image."""
        # Split the batch axis into (rows, cols), then concatenate twice along
        # axis 1: first stacking rows vertically, then columns horizontally.
        outputs = np.reshape(outputs, [num_rows, num_cols, *outputs.shape[1:]])
        outputs = np.concatenate(outputs, axis=1)
        outputs = np.concatenate(outputs, axis=1)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        return Image.fromarray(outputs).resize((resolution * num_cols, resolution * num_rows), Image.LANCZOS)

    def render(dlatents, defectlatents=None):
        """Synthesize one frame; defectlatents is only given for transfer networks."""
        if defectlatents is None:
            images = G.synthesis(dlatents, noise_mode=noise_mode)
        elif latent_mode == 'nores':
            images = G.synthesis(dlatents, defectlatents, noise_mode=noise_mode, fix_residual_to_zero=True)
        else:
            images = G.synthesis(dlatents, defectlatents, noise_mode=noise_mode)
        # Scale to 0..255 and move channels last for PIL
        # (assumes synthesis returns channels-first images in about [-1, 1] -- TODO confirm).
        images = (images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
        return to_image_grid(images)

    def map_keypoints(mapping, zs, vary):
        """Map each key-point latent once; a frozen latent reuses key-point 0."""
        if not vary:
            zs = zs[:1]
        ws = [mapping(torch.from_numpy(z).to(device), None) for z in zs]
        return ws if vary else ws * num_phases

    # Pure inference: no autograd graph is needed, and .numpy() would fail on
    # tensors that require grad.
    with torch.no_grad():
        content_ws = map_keypoints(G.mapping, latents, vary_content)
        defect_ws = map_keypoints(G.defect_mapping, latents_defect, vary_defect) if transfer else None

        for i in range(num_phases):
            # Index i - 1 wraps to the last key-point for i == 0, closing the loop.
            dlatents0, dlatents1 = content_ws[i - 1], content_ws[i]
            if transfer:
                defectlatents0, defectlatents1 = defect_ws[i - 1], defect_ws[i]
            else:
                defectlatents0 = defectlatents1 = None

            for j in range(transition_frames):
                dlatents = (dlatents0 * (transition_frames - j) + dlatents1 * j) / transition_frames
                defectlatents = None
                if transfer:
                    defectlatents = (defectlatents0 * (transition_frames - j) + defectlatents1 * j) / transition_frames
                output_seq.append(render(dlatents, defectlatents))

            # Hold on the phase's end point; skip the render when nothing is kept.
            if static_frames > 0:
                output_seq.extend([render(dlatents1, defectlatents1)] * static_frames)

    if not output.endswith('.gif'):
        output += '.gif'
    output_seq[0].save(output, save_all=True, append_images=output_seq[1:], optimize=False, duration=50, loop=0)
#----------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point. generate_gif is a click command, so it is called
    # without arguments and click supplies them from the command line.
    generate_gif() # pylint: disable=no-value-for-parameter
#----------------------------------------------------------------------------