-
Notifications
You must be signed in to change notification settings - Fork 25
/
sampling.py
519 lines (426 loc) · 19.1 KB
/
sampling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""
import functools
import jax
import jax.numpy as jnp
import jax.random as random
import abc
import flax
from models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn
from scipy import integrate
import sde_lib
from utils import batch_mul, batch_add
from models import utils as mutils
# Name -> class registries. They are filled in as the `register_predictor` /
# `register_corrector` decorators are applied to the classes defined in this
# module, and are read back by `get_predictor` / `get_corrector`.
_PREDICTORS = {}
_CORRECTORS = {}
def register_predictor(cls=None, *, name=None):
  """Class decorator that adds a predictor class to the predictor registry.

  Works both bare (``@register_predictor``), registering the class under its
  own ``__name__``, and parameterized (``@register_predictor(name='...')``),
  registering it under ``name``.

  Args:
    cls: The decorated class when used bare; `None` when called with arguments.
    name: Optional registry key. Defaults to ``cls.__name__``.

  Returns:
    The class itself (bare form) or a decorator (parameterized form).

  Raises:
    ValueError: If the registry key is already in use.
  """

  def _register(cls):
    key = cls.__name__ if name is None else name
    if key in _PREDICTORS:
      raise ValueError(f'Already registered model with name: {key}')
    _PREDICTORS[key] = cls
    return cls

  return _register if cls is None else _register(cls)
def register_corrector(cls=None, *, name=None):
  """Class decorator that adds a corrector class to the corrector registry.

  Works both bare (``@register_corrector``), registering the class under its
  own ``__name__``, and parameterized (``@register_corrector(name='...')``),
  registering it under ``name``.

  Args:
    cls: The decorated class when used bare; `None` when called with arguments.
    name: Optional registry key. Defaults to ``cls.__name__``.

  Returns:
    The class itself (bare form) or a decorator (parameterized form).

  Raises:
    ValueError: If the registry key is already in use.
  """

  def _register(cls):
    key = cls.__name__ if name is None else name
    if key in _CORRECTORS:
      raise ValueError(f'Already registered model with name: {key}')
    _CORRECTORS[key] = cls
    return cls

  return _register if cls is None else _register(cls)
def get_predictor(name):
  """Return the predictor class registered under ``name``.

  Raises:
    KeyError: If no predictor was registered under ``name``.
  """
  predictor_cls = _PREDICTORS[name]
  return predictor_cls
def get_corrector(name):
  """Return the corrector class registered under ``name``.

  Raises:
    KeyError: If no corrector was registered under ``name``.
  """
  corrector_cls = _CORRECTORS[name]
  return corrector_cls
def get_sampling_fn(config, sde, model, shape, inverse_scaler, eps):
  """Create the sampling function selected by ``config.sampling.method``.

  The method name is matched case-insensitively: ``'ode'`` selects the
  black-box probability flow ODE sampler and ``'pc'`` the Predictor-Corrector
  sampler (predictor-only and corrector-only samplers are special cases).

  Args:
    config: A `ml_collections.ConfigDict` object that contains all configuration information.
    sde: A `sde_lib.SDE` object that represents the forward SDE.
    model: A `flax.linen.Module` object that represents the architecture of a time-dependent score-based model.
    shape: A sequence of integers representing the expected shape of a single sample.
    inverse_scaler: The inverse data normalizer function.
    eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.

  Returns:
    A function that takes random states and a replicated training state and outputs samples with the
    trailing dimensions matching `shape`.

  Raises:
    ValueError: If the configured sampling method is neither ``'ode'`` nor ``'pc'``.
    KeyError: (``'pc'`` only) if the configured predictor or corrector name is
      not registered.
  """
  sampler_name = config.sampling.method
  method = sampler_name.lower()

  if method == 'ode':
    # Probability flow ODE sampling with black-box ODE solvers.
    return get_ode_sampler(sde=sde,
                           model=model,
                           shape=shape,
                           inverse_scaler=inverse_scaler,
                           denoise=config.sampling.noise_removal,
                           eps=eps)

  if method == 'pc':
    predictor = get_predictor(config.sampling.predictor.lower())
    corrector = get_corrector(config.sampling.corrector.lower())
    return get_pc_sampler(sde=sde,
                          model=model,
                          shape=shape,
                          predictor=predictor,
                          corrector=corrector,
                          inverse_scaler=inverse_scaler,
                          snr=config.sampling.snr,
                          n_steps=config.sampling.n_steps_each,
                          probability_flow=config.sampling.probability_flow,
                          continuous=config.training.continuous,
                          denoise=config.sampling.noise_removal,
                          eps=eps)

  # The message reports the name exactly as configured (not lower-cased).
  raise ValueError(f"Sampler name {sampler_name} unknown.")
class Predictor(abc.ABC):
  """Abstract base class for a predictor algorithm.

  On construction the reverse-time SDE (or, with ``probability_flow=True``,
  the probability flow ODE) is obtained from ``sde.reverse`` and stored as
  ``self.rsde`` for subclasses to use.
  """

  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__()
    self.sde = sde
    self.score_fn = score_fn
    # Compute the reverse SDE/ODE once, up front.
    self.rsde = sde.reverse(score_fn, probability_flow)

  @abc.abstractmethod
  def update_fn(self, rng, x, t):
    """Perform one predictor update.

    Args:
      rng: A JAX random state.
      x: A JAX array representing the current state.
      t: A JAX array representing the current time step.

    Returns:
      x: A JAX array of the next state.
      x_mean: A JAX array. The next state without random noise. Useful for denoising.
    """
class Corrector(abc.ABC):
  """Abstract base class for a corrector algorithm.

  Stores the SDE, the score function, the target signal-to-noise ratio and
  the number of corrector iterations for subclasses to use.
  """

  def __init__(self, sde, score_fn, snr, n_steps):
    super().__init__()
    self.sde, self.score_fn = sde, score_fn
    self.snr, self.n_steps = snr, n_steps

  @abc.abstractmethod
  def update_fn(self, rng, x, t):
    """Perform one corrector update.

    Args:
      rng: A JAX random state.
      x: A JAX array representing the current state.
      t: A JAX array representing the current time step.

    Returns:
      x: A JAX array of the next state.
      x_mean: A JAX array. The next state without random noise. Useful for denoising.
    """
@register_predictor(name='euler_maruyama')
class EulerMaruyamaPredictor(Predictor):
  """Euler-Maruyama discretization of the reverse-time SDE (or probability flow ODE)."""

  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__(sde, score_fn, probability_flow)

  def update_fn(self, rng, x, t):
    """Take one Euler-Maruyama step backwards in time.

    Args:
      rng: A JAX random state used to draw the Gaussian increment.
      x: A JAX array representing the current state.
      t: A JAX array representing the current time step.

    Returns:
      x: A JAX array of the next state (drift step plus scaled noise).
      x_mean: A JAX array. The next state without random noise.
    """
    # Negative step: integration runs from T towards 0 in N uniform steps.
    # The step size uses the SDE's horizon `T` rather than a hard-coded 1, so
    # it stays consistent with the sampler's time grid on [eps, sde.T];
    # for T == 1 this equals the previous `-1. / N`.
    dt = -self.sde.T / self.rsde.N
    z = random.normal(rng, x.shape)
    drift, diffusion = self.rsde.sde(x, t)
    x_mean = x + drift * dt
    x = x_mean + batch_mul(diffusion, jnp.sqrt(-dt) * z)
    return x, x_mean
@register_predictor(name='reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
  """Predictor that steps with the discretized reverse-time SDE.

  ``self.rsde.discretize(x, t)`` provides the per-step drift ``f`` and
  diffusion ``G``; the update is ``x - f`` plus ``G``-scaled Gaussian noise.
  """

  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__(sde, score_fn, probability_flow)

  def update_fn(self, rng, x, t):
    """Return ``(x - f + G * z, x - f)`` with ``z`` standard Gaussian noise."""
    drift_step, noise_scale = self.rsde.discretize(x, t)
    gaussian = random.normal(rng, x.shape)
    mean = x - drift_step
    return mean + batch_mul(noise_scale, gaussian), mean
@register_predictor(name='ancestral_sampling')
class AncestralSamplingPredictor(Predictor):
  """The ancestral sampling predictor. Currently only supports VE/VP SDEs."""

  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__(sde, score_fn, probability_flow)
    supported = isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.VESDE)
    if not supported:
      raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
    assert not probability_flow, "Probability flow not supported by ancestral sampling"

  def vesde_update_fn(self, rng, x, t):
    """Ancestral step for a VE SDE using its ``discrete_sigmas`` table."""
    sde = self.sde
    # Map continuous time t in [0, T] to an integer index into the tables.
    idx = (t * (sde.N - 1) / sde.T).astype(jnp.int32)
    sigma = sde.discrete_sigmas[idx]
    # At index 0 the previous noise level is taken to be zero; the `idx - 1`
    # lookup computed there is discarded by `jnp.where`.
    prev_sigma = jnp.where(idx == 0, jnp.zeros(t.shape), sde.discrete_sigmas[idx - 1])
    score = self.score_fn(x, t)
    x_mean = x + batch_mul(score, sigma ** 2 - prev_sigma ** 2)
    std = jnp.sqrt((prev_sigma ** 2 * (sigma ** 2 - prev_sigma ** 2)) / (sigma ** 2))
    noise = random.normal(rng, x.shape)
    return x_mean + batch_mul(std, noise), x_mean

  def vpsde_update_fn(self, rng, x, t):
    """Ancestral step for a VP SDE using its ``discrete_betas`` table."""
    sde = self.sde
    idx = (t * (sde.N - 1) / sde.T).astype(jnp.int32)
    beta = sde.discrete_betas[idx]
    score = self.score_fn(x, t)
    x_mean = batch_mul((x + batch_mul(beta, score)), 1. / jnp.sqrt(1. - beta))
    noise = random.normal(rng, x.shape)
    return x_mean + batch_mul(jnp.sqrt(beta), noise), x_mean

  def update_fn(self, rng, x, t):
    """Dispatch to the VE or VP update according to the SDE class."""
    sde = self.sde
    if isinstance(sde, sde_lib.VESDE):
      return self.vesde_update_fn(rng, x, t)
    if isinstance(sde, sde_lib.VPSDE):
      return self.vpsde_update_fn(rng, x, t)
    # __init__ rejects other SDE classes, so this is not normally reached.
    return None
@register_predictor(name='none')
class NonePredictor(Predictor):
  """An empty predictor that does nothing.

  Used as the predictor of a corrector-only sampler. It does not call
  ``Predictor.__init__``, so no reverse SDE is built and nothing is stored.
  """

  def __init__(self, sde, score_fn, probability_flow=False):
    # Intentionally a no-op: the arguments are accepted only for interface
    # compatibility with other predictors.
    pass

  def update_fn(self, rng, x, t):
    """Return the current state unchanged as both sample and mean."""
    return x, x
@register_corrector(name='langevin')
class LangevinCorrector(Corrector):
  """Langevin MCMC corrector whose step size is set from a target SNR.

  Supports VP, sub-VP and VE SDEs. The step size is recomputed at every
  iteration from the ratio of the mean noise norm to the mean score norm.
  """

  def __init__(self, sde, score_fn, snr, n_steps):
    super().__init__(sde, score_fn, snr, n_steps)
    if not (isinstance(sde, sde_lib.VPSDE)
            or isinstance(sde, sde_lib.VESDE)
            or isinstance(sde, sde_lib.subVPSDE)):
      raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

  def update_fn(self, rng, x, t):
    """Run ``self.n_steps`` Langevin iterations at time ``t``.

    Must be executed inside a `jax.pmap` with ``axis_name='batch'`` because the
    norms are averaged across devices with `jax.lax.pmean`.
    """
    sde = self.sde
    score_fn = self.score_fn
    target_snr = self.snr

    # Extra step-size factor: `sde.alphas` at the discrete time index for
    # VP-type SDEs, 1 otherwise.
    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
      idx = (t * (sde.N - 1) / sde.T).astype(jnp.int32)
      alpha = sde.alphas[idx]
    else:
      alpha = jnp.ones_like(t)

    def batch_mean_norm(arr):
      # Per-example L2 norm, averaged over the local batch and then across
      # devices on the 'batch' axis.
      per_example = jnp.linalg.norm(arr.reshape((arr.shape[0], -1)), axis=-1)
      return jax.lax.pmean(per_example.mean(), axis_name='batch')

    def langevin_step(_, carry):
      key, cur, _ = carry
      grad = score_fn(cur, t)
      key, step_key = jax.random.split(key)
      noise = jax.random.normal(step_key, cur.shape)
      grad_norm = batch_mean_norm(grad)
      noise_norm = batch_mean_norm(noise)
      step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
      cur_mean = cur + batch_mul(step_size, grad)
      cur = cur_mean + batch_mul(noise, jnp.sqrt(step_size * 2))
      return key, cur, cur_mean

    _, x, x_mean = jax.lax.fori_loop(0, self.n_steps, langevin_step, (rng, x, x))
    return x, x_mean
@register_corrector(name='ald')
class AnnealedLangevinDynamics(Corrector):
  """The original annealed Langevin dynamics corrector in NCSN/NCSNv2.

  We include this corrector only for completeness. It was not directly used in our paper.
  The step size is derived from the marginal standard deviation at ``t``
  (``sde.marginal_prob(x, t)[1]``) and the target SNR.
  """

  def __init__(self, sde, score_fn, snr, n_steps):
    super().__init__(sde, score_fn, snr, n_steps)
    if not (isinstance(sde, sde_lib.VPSDE)
            or isinstance(sde, sde_lib.VESDE)
            or isinstance(sde, sde_lib.subVPSDE)):
      raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

  def update_fn(self, rng, x, t):
    """Run ``self.n_steps`` annealed Langevin iterations at time ``t``."""
    sde = self.sde
    score_fn = self.score_fn

    # Extra step-size factor: `sde.alphas` at the discrete time index for
    # VP-type SDEs, 1 otherwise.
    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
      idx = (t * (sde.N - 1) / sde.T).astype(jnp.int32)
      alpha = sde.alphas[idx]
    else:
      alpha = jnp.ones_like(t)

    # The step size depends only on `t` (through std and alpha), so it is
    # the same for every iteration and can be computed once.
    std = sde.marginal_prob(x, t)[1]
    step_size = (self.snr * std) ** 2 * 2 * alpha
    noise_scale = jnp.sqrt(step_size * 2)

    def ald_step(_, carry):
      key, cur, _ = carry
      grad = score_fn(cur, t)
      key, step_key = jax.random.split(key)
      noise = jax.random.normal(step_key, cur.shape)
      cur_mean = cur + batch_mul(step_size, grad)
      cur = cur_mean + batch_mul(noise, noise_scale)
      return key, cur, cur_mean

    _, x, x_mean = jax.lax.fori_loop(0, self.n_steps, ald_step, (rng, x, x))
    return x, x_mean
@register_corrector(name='none')
class NoneCorrector(Corrector):
  """An empty corrector that does nothing.

  Used as the corrector of a predictor-only sampler. It does not call
  ``Corrector.__init__``, so no attributes are stored.
  """

  def __init__(self, sde, score_fn, snr, n_steps):
    # Intentionally a no-op: the arguments are accepted only for interface
    # compatibility with other correctors.
    pass

  def update_fn(self, rng, x, t):
    """Return the current state unchanged as both sample and mean."""
    return x, x
def shared_predictor_update_fn(rng, state, x, t, sde, model, predictor, probability_flow, continuous):
  """Instantiate the predictor for the current model state and run one update.

  A score function is built from ``state.params_ema`` and ``state.model_state``
  with ``train=False`` on every call, and the predictor is constructed around it.

  Args:
    rng: A JAX random state.
    state: Training state providing ``params_ema`` and ``model_state``.
    x: A JAX array representing the current state.
    t: A JAX array representing the current time step.
    sde: An `sde_lib.SDE` object representing the forward SDE.
    model: The score model architecture.
    predictor: A `Predictor` subclass, or `None` for a corrector-only sampler
      (in which case `NonePredictor` is used).
    probability_flow: Forwarded to the predictor constructor.
    continuous: Forwarded to `get_score_fn`.

  Returns:
    The ``(x, x_mean)`` pair returned by the predictor's ``update_fn``.
  """
  score_fn = mutils.get_score_fn(sde, model, state.params_ema, state.model_state,
                                 train=False, continuous=continuous)
  # Without a predictor this is a corrector-only sampler.
  predictor_cls = NonePredictor if predictor is None else predictor
  predictor_obj = predictor_cls(sde, score_fn, probability_flow)
  return predictor_obj.update_fn(rng, x, t)
def shared_corrector_update_fn(rng, state, x, t, sde, model, corrector, continuous, snr, n_steps):
  """Instantiate the corrector for the current model state and run one update.

  A score function is built from ``state.params_ema`` and ``state.model_state``
  with ``train=False`` on every call, and the corrector is constructed around it.

  Args:
    rng: A JAX random state.
    state: Training state providing ``params_ema`` and ``model_state``.
    x: A JAX array representing the current state.
    t: A JAX array representing the current time step.
    sde: An `sde_lib.SDE` object representing the forward SDE.
    model: The score model architecture.
    corrector: A `Corrector` subclass, or `None` for a predictor-only sampler
      (in which case `NoneCorrector` is used).
    continuous: Forwarded to `get_score_fn`.
    snr: Target signal-to-noise ratio, forwarded to the corrector.
    n_steps: Number of corrector iterations, forwarded to the corrector.

  Returns:
    The ``(x, x_mean)`` pair returned by the corrector's ``update_fn``.
  """
  score_fn = mutils.get_score_fn(sde, model, state.params_ema, state.model_state,
                                 train=False, continuous=continuous)
  # Without a corrector this is a predictor-only sampler.
  corrector_cls = NoneCorrector if corrector is None else corrector
  corrector_obj = corrector_cls(sde, score_fn, snr, n_steps)
  return corrector_obj.update_fn(rng, x, t)
def get_pc_sampler(sde, model, shape, predictor, corrector, inverse_scaler, snr,
                   n_steps=1, probability_flow=False, continuous=False,
                   denoise=True, eps=1e-3):
  """Create a Predictor-Corrector (PC) sampler.

  Args:
    sde: An `sde_lib.SDE` object representing the forward SDE.
    model: A `flax.linen.Module` object that represents the architecture of a time-dependent score-based model.
    shape: A sequence of integers. The expected shape of a single sample.
    predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
    corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.
    inverse_scaler: The inverse data normalizer.
    snr: A `float` number. The signal-to-noise ratio for configuring correctors.
    n_steps: An integer. The number of corrector steps per predictor update.
    probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.
    continuous: `True` indicates that the score model was continuously trained.
    denoise: If `True`, add one-step denoising to the final samples.
    eps: A `float` number. The reverse-time SDE and ODE are integrated to `eps` to avoid numerical issues.

  Returns:
    A `jax.pmap`-ed sampling function (device axis named ``'batch'``) that takes
    random states and a replicated training state and returns samples as well
    as the number of function evaluations during sampling.
  """
  # Bind everything except the per-step arguments (rng, state, x, t).
  predictor_step = functools.partial(shared_predictor_update_fn,
                                     sde=sde,
                                     model=model,
                                     predictor=predictor,
                                     probability_flow=probability_flow,
                                     continuous=continuous)
  corrector_step = functools.partial(shared_corrector_update_fn,
                                     sde=sde,
                                     model=model,
                                     corrector=corrector,
                                     continuous=continuous,
                                     snr=snr,
                                     n_steps=n_steps)

  def pc_sampler(rng, state):
    """The PC sampler function.

    Args:
      rng: A JAX random state.
      state: A `flax.struct.dataclass` object that represents the training state of a score-based model.

    Returns:
      Samples, number of function evaluations.
    """
    # Initial sample drawn from the SDE prior.
    rng, prior_key = random.split(rng)
    x = sde.prior_sampling(prior_key, shape)
    # N times going from T down to eps.
    timesteps = jnp.linspace(sde.T, eps, sde.N)

    def pc_step(i, carry):
      key, cur, _ = carry
      vec_t = jnp.ones(shape[0]) * timesteps[i]
      key, corrector_key = random.split(key)
      cur, cur_mean = corrector_step(corrector_key, state, cur, vec_t)
      key, predictor_key = random.split(key)
      cur, cur_mean = predictor_step(predictor_key, state, cur, vec_t)
      return key, cur, cur_mean

    _, x, x_mean = jax.lax.fori_loop(0, sde.N, pc_step, (rng, x, x))
    # Denoising returns the final predictor step's mean, i.e. that step taken
    # without adding noise.
    samples = x_mean if denoise else x
    # Each of the N iterations evaluates the score n_steps times in the
    # corrector plus once in the predictor.
    return inverse_scaler(samples), sde.N * (n_steps + 1)

  return jax.pmap(pc_sampler, axis_name='batch')
def get_ode_sampler(sde, model, shape, inverse_scaler,
                    denoise=False, rtol=1e-5, atol=1e-5, method='RK45', eps=1e-3):
  """Probability flow ODE sampler with the black-box ODE solver.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    model: A `flax.linen.Module` object that represents the architecture of the score-based model.
    shape: A sequence of integers. The expected shape of a single sample.
    inverse_scaler: The inverse data normalizer.
    denoise: If `True`, add one-step denoising to final samples.
    rtol: A `float` number. The relative tolerance level of the ODE solver.
    atol: A `float` number. The absolute tolerance level of the ODE solver.
    method: A `str`. The algorithm used for the black-box ODE solver.
      See the documentation of `scipy.integrate.solve_ivp`.
    eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.

  Returns:
    A sampling function that takes random states, and a replicated training state and returns samples
    as well as the number of function evaluations during sampling.
  """

  @jax.pmap
  def denoise_update_fn(rng, state, x):
    score_fn = get_score_fn(sde, model, state.params_ema, state.model_state, train=False, continuous=True)
    # One reverse diffusion predictor step at t = eps; its noise-free mean is kept.
    predictor_obj = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
    vec_eps = jnp.ones((x.shape[0],)) * eps
    _, x_mean = predictor_obj.update_fn(rng, x, vec_eps)
    return x_mean

  @jax.pmap
  def drift_fn(state, x, t):
    """Get the drift function of the reverse-time probability flow ODE."""
    score_fn = get_score_fn(sde, model, state.params_ema, state.model_state, train=False, continuous=True)
    rsde = sde.reverse(score_fn, probability_flow=True)
    return rsde.sde(x, t)[0]

  def ode_sampler(prng, pstate, z=None):
    """The probability flow ODE sampler with black-box ODE solver.

    Args:
      prng: An array of random state. The leading dimension equals the number of devices.
      pstate: Replicated training state for running on multiple devices.
      z: If present, generate samples from latent code `z`.

    Returns:
      Samples, and the number of function evaluations.
    """
    # Samples carry a leading device axis so the pmapped helpers can run them.
    n_devices = jax.local_device_count()
    full_shape = (n_devices,) + shape

    rng = flax.jax_utils.unreplicate(prng)
    rng, prior_key = random.split(rng)
    if z is None:
      # If not present, sample the latent code from the prior distribution of the SDE.
      x = sde.prior_sampling(prior_key, full_shape)
    else:
      x = z

    def ode_func(t, x_flat):
      # solve_ivp works on flat NumPy vectors; restore the device/batch layout
      # for the pmapped drift and flatten the result again.
      x_dev = from_flattened_numpy(x_flat, full_shape)
      vec_t = jnp.ones((x_dev.shape[0], x_dev.shape[1])) * t
      return to_flattened_numpy(drift_fn(pstate, x_dev, vec_t))

    # Black-box ODE solver for the probability flow ODE, integrated from T down to eps.
    solution = integrate.solve_ivp(ode_func, (sde.T, eps), to_flattened_numpy(x),
                                   rtol=rtol, atol=atol, method=method)
    nfe = solution.nfev
    x = jnp.asarray(solution.y[:, -1]).reshape(full_shape)

    # Denoising is equivalent to running one predictor step without adding noise.
    if denoise:
      rng, *device_keys = random.split(rng, n_devices + 1)
      x = denoise_update_fn(jnp.asarray(device_keys), pstate, x)

    return inverse_scaler(x), nfe

  return ode_sampler