-
Notifications
You must be signed in to change notification settings - Fork 18
/
sample.py
executable file
·83 lines (68 loc) · 3 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python3
"""Unconditional sampling from a diffusion model."""
import argparse
from functools import partial
from pathlib import Path
import jax
import jax.numpy as jnp
from PIL import Image
from tqdm import tqdm, trange
from diffusion import get_model, get_models, load_params, sampling, utils
# Absolute directory holding this script; the default checkpoint path is built beneath it.
MODULE_DIR = Path(__file__).resolve().parents[0]
def main():
    """Sample images from a diffusion model and save them as PNG files.

    Command-line options select the model, checkpoint, number of images,
    batch size, number of timesteps, eta and PRNG seed. Images are written to
    the current working directory as ``out_00000.png``, ``out_00001.png``, ...

    When ``--init`` is given, that image is resized to the spatial size in
    ``model.shape``. It is then mixed with noise at the first schedule entry
    below ``--starting-timestep``, and sampling runs only over the remaining
    entries. Ctrl-C stops sampling quietly; images already saved are kept.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('--batch-size', '-bs', type=int, default=1,
                   help='the number of images per batch')
    p.add_argument('--checkpoint', type=str,
                   help='the checkpoint to use')
    p.add_argument('--eta', type=float, default=1.,
                   help='the amount of noise to add during sampling (0-1)')
    p.add_argument('--init', type=str,
                   help='the init image')
    p.add_argument('--model', type=str, choices=get_models(), required=True,
                   help='the model to use')
    p.add_argument('-n', type=int, default=1,
                   help='the number of images to sample')
    p.add_argument('--seed', type=int, default=0,
                   help='the random seed')
    p.add_argument('--starting-timestep', '-st', type=float, default=0.9,
                   help='the timestep to start at (used with init images)')
    p.add_argument('--steps', type=int, default=1000,
                   help='the number of timesteps')
    args = p.parse_args()

    # Fail fast with a usage message instead of an opaque error later:
    # range() rejects a zero step, and < 1 steps gives an empty schedule.
    if args.batch_size < 1:
        p.error('--batch-size must be at least 1')
    if args.steps < 1:
        p.error('--steps must be at least 1')

    # The timestep schedule is identical for every batch, so build it once.
    steps = utils.get_ddpm_schedule(jnp.linspace(1, 0, args.steps + 1)[:-1])
    if args.init:
        # With an init image, sampling starts partway along the schedule.
        steps = steps[steps < args.starting_timestep]
        if steps.size == 0:
            p.error('--starting-timestep is not above any timestep in the schedule; '
                    'increase it or use more --steps')

    model = get_model(args.model)
    checkpoint = args.checkpoint
    if not checkpoint:
        checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pkl'
    params = load_params(checkpoint)

    init = alpha = sigma = None
    if args.init:
        _, y, x = model.shape
        # Context manager closes the underlying file handle promptly.
        with Image.open(args.init) as im:
            init = im.convert('RGB').resize((x, y), Image.LANCZOS)
        # [None] adds a leading axis of size 1 so the image broadcasts against
        # the (n, *model.shape) noise batch (assumes from_pil_image returns
        # an array shaped like model.shape — TODO confirm).
        init = utils.from_pil_image(init)[None]
        # Mixing weights for the first timestep that will actually be run.
        alpha, sigma = utils.t_to_alpha_sigma(steps[0])

    sample_step = partial(sampling.jit_sample_step, extra_args={})

    def run(key, n):
        """Draw one batch of ``n`` samples, deriving all randomness from ``key``."""
        tqdm.write('Sampling...')
        key, subkey = jax.random.split(key)
        noise = jax.random.normal(subkey, [n, *model.shape])
        _, subkey = jax.random.split(key)
        if init is not None:
            noise = init * alpha + noise * sigma
        return sampling.sample_loop(model, params, subkey, noise, steps, args.eta, sample_step)

    def run_all(key, n, batch_size):
        """Sample ``n`` images in batches of at most ``batch_size`` and save each one."""
        for i in trange(0, n, batch_size):
            # Hand the batch its own subkey. Passing `key` itself would reuse
            # randomness, because `run` splits it exactly as the next loop
            # iteration does.
            key, subkey = jax.random.split(key)
            cur_batch_size = min(n - i, batch_size)
            outs = run(subkey, cur_batch_size)
            for j, out in enumerate(outs):
                utils.to_pil_image(out).save(f'out_{i + j:05}.png')

    key = jax.random.PRNGKey(args.seed)
    try:
        run_all(key, args.n, args.batch_size)
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends sampling without a traceback; images that
        # were already saved stay on disk.
        pass


if __name__ == '__main__':
    main()