-
Notifications
You must be signed in to change notification settings - Fork 0
/
unet.py
583 lines (515 loc) · 21.9 KB
/
unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
"""
Defines the models including the U-Net architecture that will be used for gradual denoising
of the latent space.
"""
import jax
import flax
from nn import normalization, timestep_embedding
from flax import linen as nn
import jax.numpy as jnp
from abc import abstractmethod
from typing import Collection
class TimeStepBlock(nn.Module):
    """
    Marker interface for modules that are conditioned on a timestep embedding.

    Implementations receive the feature tensor as the first argument of
    ``__call__`` and the timestep embedding as the second. Containers such as
    TimeStepEmbedSequential use ``isinstance(..., TimeStepBlock)`` to decide
    whether to pass the embedding along.

    NOTE(review): ``nn.Module`` is presumably not an ``abc.ABC``, so the
    ``@abstractmethod`` below may not actually block instantiation — confirm.
    """

    @abstractmethod
    def __call__(self, x, emb):
        """
        Transform ``x`` conditioned on the timestep embedding ``emb``.

        Subclasses must override this; the base implementation has no body.
        """
class TimeStepEmbedSequential(nn.Sequential, TimeStepBlock):
    """
    Apply ``layers`` in order while threading the timestep embedding through.

    Each layer that is a TimeStepBlock is invoked as ``layer(h, emb)``; every
    other layer (plain modules or functions such as activations) is invoked
    as ``layer(h)``. The output of one layer feeds the next.
    """

    def __call__(self, x, emb):
        h = x
        for child in self.layers:
            needs_emb = isinstance(child, TimeStepBlock)
            h = child(h, emb) if needs_emb else child(h)
        return h
class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling, optionally followed by a convolution.

    channels: channel count of the input (and of the output unless
        out_channels is given).
    use_conv: if True, apply a size-3, stride-1, SAME-padded convolution
        after resizing.
    dims: rank of the signal (1, 2 or 3). Inputs are channels-last:
        (B, L, C), (B, H, W, C) or (B, D, H, W, C). For 3D only the two
        innermost spatial axes (H, W) are doubled; D is kept.
    out_channels: output channels of the optional convolution; defaults to
        ``channels``. Ignored when use_conv is False.
    Raises ValueError for any other value of dims.
    (unit-tested)
    """
    channels: int
    use_conv: bool
    dims: int = 2
    out_channels: int = None

    @nn.compact
    def __call__(self, x):
        # Work out the doubled target shape for each supported rank.
        if self.dims == 1:
            batch, length, feats = x.shape
            target = (batch, length * 2, feats)
        elif self.dims == 2:
            batch, height, width, feats = x.shape
            target = (batch, height * 2, width * 2, feats)
        elif self.dims == 3:
            batch, depth, height, width, feats = x.shape
            target = (batch, depth, height * 2, width * 2, feats)
        else:
            raise ValueError(f"Unsupported dimensions: {self.dims}")
        x = jax.image.resize(x, shape=target, method='nearest')

        if not self.use_conv:
            return x

        features = self.channels if self.out_channels is None else self.out_channels
        conv = nn.Conv(features=features,
                       kernel_size=(3,) * self.dims,
                       strides=1,
                       padding='SAME')
        return conv(x)
class Downsample(nn.Module):
    """
    A 2x downsampling module with an optional convolution.

    channels: channel count of the input (and of the output unless
        out_channels is given).
    use_conv: if True, downsample with a size-3, SAME-padded strided
        convolution; otherwise use non-overlapping 2-wide average pooling.
    dims: rank of the signal (1, 2 or 3), channels-last layout. For 3D, only
        the two innermost spatial axes (H, W) are halved; depth is kept.
    out_channels: output channels of the convolution; defaults to
        ``channels``. Only honoured when use_conv is True (pooling keeps the
        input channel count).
    Raises ValueError for any other value of dims, on both paths.
    (unit-tested)
    """
    channels: int
    use_conv: bool
    dims: int = 2
    out_channels: int = None

    @nn.compact
    def __call__(self, x):
        # Validate up front so the pooling path cannot silently pass the
        # input through unchanged for an unsupported rank.
        if self.dims not in (1, 2, 3):
            raise ValueError(f"Unsupported dimensions: {self.dims}")

        out_channels = self.channels if self.out_channels is None else self.out_channels

        # 3D keeps the depth axis and halves H and W; 1D/2D halve every
        # spatial axis.
        stride = (1, 2, 2) if self.dims == 3 else (2,) * self.dims

        if self.use_conv:
            return nn.Conv(features=out_channels,
                           kernel_size=(3,) * self.dims,
                           strides=stride,
                           padding='SAME')(x)

        # Same (default VALID) pooling for every rank, so 3D agrees with
        # 1D/2D instead of padding odd-sized axes.
        return nn.avg_pool(x, window_shape=stride, strides=stride)
class Identity(nn.Module):
    """
    No-op module that hands its input back untouched.

    Fills a layer slot where no transformation is wanted, for example the
    ResBlock skip connection when the channel count already matches, or the
    resampling slots of a ResBlock that neither upsamples nor downsamples.
    """

    @nn.compact
    def __call__(self, x):
        # Nothing to compute: the very object passed in is returned.
        return x
class ResBlock(TimeStepBlock):
    """
    A residual block that can optionally change the number of channels.

    channels: the number of input channels.
    emb_channels: the number of timestep embedding channels (not read here;
        the embedding projection infers its input width from ``emb``).
    dropout: dropout probability (currently unused; the Dropout layer in
        out_layers is disabled).
    out_channels: if specified, the number of output channels; defaults to
        ``channels``.
    use_conv: if True and out_channels differs from channels, the skip
        connection uses a size-3 convolution instead of a 1x1 convolution.
    dims: determines if the signal is 1D, 2D, or 3D.
    use_scale_shift_norm: FiLM-style conditioning; not implemented, applying
        the block with this set raises NotImplementedError.
    up: if True, use this block for upsampling.
    down: if True, use this block for downsampling.
    (unit-tested)
    """
    channels: int
    emb_channels: int
    dropout: float = 0.0
    out_channels: int = None
    use_conv: bool = True
    dims: int = 2
    use_scale_shift_norm: bool = False
    up: bool = False
    down: bool = False

    def setup(self):
        # A plain local is fine inside setup; only dataclass fields are frozen.
        out_channels = self.out_channels or self.channels
        kernel_size = tuple([3 for i in range(self.dims)])

        # in_rest and in_conv together form the "in layers"; they are kept
        # separate so resampling can happen between them.
        self.in_rest = nn.Sequential([
            normalization(self.channels),
            jax.nn.silu,
        ])
        self.in_conv = nn.Conv(features=out_channels,
                               kernel_size=kernel_size,
                               padding='SAME')
        self.updown = self.up or self.down  # whether to upsample or downsample

        # (1) resampling layers for the residual path (h) and the skip path (x)
        if self.up:
            self.h_upd = Upsample(self.channels, use_conv=False, dims=self.dims)
            self.x_upd = Upsample(self.channels, use_conv=False, dims=self.dims)
        elif self.down:
            self.h_upd = Downsample(self.channels, use_conv=False, dims=self.dims)
            self.x_upd = Downsample(self.channels, use_conv=False, dims=self.dims)
        else:
            self.h_upd = Identity()
            self.x_upd = Identity()

        # (2) timestep-embedding projection (twice as wide for scale/shift)
        self.emb_layers = nn.Sequential([
            jax.nn.silu,
            nn.Dense(features=(
                2 * out_channels if self.use_scale_shift_norm else out_channels)
            )
        ])

        # (3) output layers
        self.out_layers = nn.Sequential([
            normalization(out_channels),
            jax.nn.silu,
            # nn.Dropout(self.dropout),
            nn.Conv(features=out_channels,
                    kernel_size=kernel_size,
                    padding='SAME')
        ])

        # (4) skip connection: identity when the widths match, otherwise a
        # convolution that maps channels -> out_channels.
        if out_channels == self.channels:
            self.skip_connection = Identity()
        elif self.use_conv:
            self.skip_connection = nn.Conv(features=out_channels,
                                           kernel_size=kernel_size,
                                           padding='SAME')
        else:
            self.skip_connection = nn.Conv(features=out_channels,
                                           kernel_size=tuple([1 for i in range(self.dims)]),
                                           padding='SAME')

    def __call__(self, x, emb):
        """
        Args:
            x: channels-last feature tensor with ``channels`` channels.
            emb: timestep embedding; batch axis first, embedding axis last.
                It may already carry singleton spatial axes, e.g.
                (N, 1, 1, D), or be just (N, D).
        Returns: ``skip_connection(x) + residual`` with out_channels channels
            (spatially resampled when up/down is set).
        """
        h = self.in_rest(x)
        if self.updown:
            # Resample between the norm/activation and the convolution, and
            # resample the skip input so both paths keep matching shapes.
            h = self.h_upd(h)
            x = self.x_upd(x)
        h = self.in_conv(h)

        emb_out = self.emb_layers(emb)
        # Insert singleton spatial axes after the batch axis so the embedding
        # broadcasts over every spatial position of h, not over trailing axes.
        while emb_out.ndim < h.ndim:
            emb_out = jnp.expand_dims(emb_out, axis=1)

        if self.use_scale_shift_norm:
            raise NotImplementedError("scale-shift normalization not implemented")
        h = h + emb_out
        h = self.out_layers(h)

        return self.skip_connection(x) + h
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    in_channels: channels in the input Tensor (not read here; the first
        convolution infers its input width).
    model_channels: base channel count for the model.
    out_channels: channels in the output Tensor.
    num_res_blocks: number of residual blocks per downsample.
    attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    dropout: the dropout probability (forwarded to every ResBlock).
    channel_mult: channel multiplier for each level of the UNet.
    conv_resample: if True, use learned convolutions for upsampling and downsampling.
    dims: determines if the signal is 1D, 2D, or 3D; forwarded to every
        convolution, resampler and ResBlock.
    num_heads: the number of attention heads in each attention layer.
    num_head_channels: forwarded to every AttentionBlock.
    use_scale_shift_norm: forwarded to every ResBlock (ResBlock does not
        implement it and raises NotImplementedError when applied with it set).
    resblock_updown: use residual blocks for up/downsampling.
    use_new_attention_order: forwarded to every AttentionBlock.
    (unit-tested)
    Deleted pieces:
    - checkpointing
    - num_classes (not planning on doing class-conditional generation)
    """
    in_channels: int
    model_channels: int
    out_channels: int
    num_res_blocks: int
    attention_resolutions: Collection[int]
    dropout: float = 0.0
    channel_mult: Collection[int] = (1, 2, 4, 8)
    conv_resample: bool = True
    dims: int = 2
    num_heads: int = 1
    num_head_channels: int = -1
    use_scale_shift_norm: bool = False
    resblock_updown: bool = False
    use_new_attention_order: bool = False

    def setup(self):
        time_embed_dim = self.model_channels * 4
        kernel_size = tuple([3 for i in range(self.dims)])

        def res_block(channels, out_channels=None, up=False, down=False):
            # Every ResBlock shares the UNet-wide embedding width, dropout,
            # dimensionality and conditioning mode; only channels/resampling vary.
            return ResBlock(channels=channels,
                            emb_channels=time_embed_dim,
                            dropout=self.dropout,
                            out_channels=out_channels,
                            dims=self.dims,
                            use_scale_shift_norm=self.use_scale_shift_norm,
                            up=up,
                            down=down)

        def attention_block(channels):
            # Every AttentionBlock shares the UNet-wide attention settings.
            return AttentionBlock(channels=channels,
                                  num_heads=self.num_heads,
                                  num_head_channels=self.num_head_channels,
                                  use_new_attention_order=self.use_new_attention_order)

        # Projects the sinusoidal embedding to 4 * model_channels.
        self.time_embed = nn.Sequential([
            nn.Dense(time_embed_dim),
            jax.nn.silu,
            nn.Dense(time_embed_dim),
        ])

        # ---- Input (encoder) blocks ----
        ch = int(self.channel_mult[0] * self.model_channels)
        input_blocks = [
            TimeStepEmbedSequential([nn.Conv(features=ch,
                                             kernel_size=kernel_size,
                                             padding='SAME')])
        ]
        feature_size = ch
        # Channel count of every input block output, consumed in reverse by
        # the output blocks' skip connections.
        input_block_channels = [ch]
        ds = 1  # current downsampling rate
        for level, mult in enumerate(self.channel_mult):
            for _ in range(self.num_res_blocks):
                out_ch = int(mult * self.model_channels)
                layers = [res_block(ch, out_channels=out_ch)]
                ch = out_ch
                if ds in self.attention_resolutions:
                    layers.append(attention_block(ch))
                input_blocks.append(TimeStepEmbedSequential(layers))
                feature_size += ch
                input_block_channels.append(ch)
            # Downsample between levels, but not after the last one.
            if level != len(self.channel_mult) - 1:
                out_ch = ch
                if self.resblock_updown:
                    down_layer = res_block(ch, out_channels=out_ch, down=True)
                else:
                    down_layer = Downsample(channels=ch,
                                            use_conv=self.conv_resample,
                                            dims=self.dims,
                                            out_channels=out_ch)
                input_blocks.append(TimeStepEmbedSequential([down_layer]))
                ch = out_ch
                input_block_channels.append(ch)
                ds *= 2
                feature_size += ch
        self.input_blocks = input_blocks

        # ---- Middle block: ResBlock -> AttentionBlock -> ResBlock ----
        self.middle_block = TimeStepEmbedSequential([
            res_block(ch),
            attention_block(ch),
            res_block(ch),
        ])
        feature_size += ch

        # ---- Output (decoder) blocks: mirror of the input blocks ----
        output_blocks = []
        for level, mult in reversed(list(enumerate(self.channel_mult))):
            for i in range(self.num_res_blocks + 1):
                ich = input_block_channels.pop()  # width of the matching skip
                out_ch = int(mult * self.model_channels)
                layers = [res_block(ich + ch, out_channels=out_ch)]
                ch = out_ch
                if ds in self.attention_resolutions:
                    layers.append(attention_block(ch))
                # Upsample at the end of every level except the outermost.
                if level and i == self.num_res_blocks:
                    out_ch = ch
                    if self.resblock_updown:
                        layers.append(res_block(ch, out_channels=out_ch, up=True))
                    else:
                        layers.append(Upsample(channels=ch,
                                               use_conv=self.conv_resample,
                                               dims=self.dims,
                                               out_channels=out_ch))
                    ds //= 2
                output_blocks.append(TimeStepEmbedSequential(layers))
                feature_size += ch
        self.output_blocks = output_blocks

        # ---- Final projection to out_channels ----
        self.out = nn.Sequential([
            normalization(ch),
            jax.nn.silu,
            nn.Conv(features=self.out_channels,
                    kernel_size=kernel_size,
                    padding='SAME'),
        ])

        # Total channel count across all blocks (assigned once).
        self._feature_size = feature_size

    def __call__(self, x, timesteps):
        """
        Args:
            x: [N x ... (spatial dims) x C] channels-last Tensor of inputs
            timesteps: a 1D batch of timesteps
        Returns: [N x ... (spatial dims) x out_channels] Tensor of outputs
        """
        # Encoder activations, consumed in reverse by the symmetric skips.
        hs = []

        emb = self.time_embed(timestep_embedding(timesteps,
                                                 dim=self.model_channels))  # (N, time_embed_dim)
        # Add one singleton axis per spatial dim so emb broadcasts over h.
        for _ in range(self.dims):
            emb = jnp.expand_dims(emb, axis=1)

        h = x
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = jnp.concatenate([h, hs.pop()], axis=-1)  # concatenate along channel dimension
            h = module(h, emb)
        return self.out(h)
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    channels: the per-head feature width and the output channel count.
    num_heads: the number of attention heads; each head works on the full
        ``channels`` width.
    num_head_channels: accepted for interface compatibility; not read here.
    use_new_attention_order: accepted for interface compatibility; not read here.
    The output projection is zero-initialised, so at initialisation the block
    returns its input unchanged (residual identity).
    (unit-tested)
    """
    channels: int
    num_heads: int = 1
    num_head_channels: int = -1
    use_new_attention_order: bool = False

    @nn.compact
    def __call__(self, x):
        B, *spatial, C = x.shape
        heads = self.num_heads
        width = self.channels

        # Flatten every spatial axis into one sequence axis: (B, T, C).
        x = jnp.reshape(x, (B, -1, C))
        T = x.shape[1]

        h = normalization(C)(x)
        # 1x1 conv -> (B, T, heads * width * 3)
        qkv = nn.Conv(features=heads * width * 3, kernel_size=(1,))(h)

        # QKVAttention consumes (B, D, H * T * 3): per-head width on axis 1,
        # and (head, position, q/k/v) packed in the last axis. Rearrange the
        # channels-last conv output into that layout so attention runs over
        # the T spatial positions (not over channels).
        qkv = qkv.reshape((B, T, heads, width, 3))
        qkv = qkv.transpose((0, 3, 2, 1, 4))  # (B, width, heads, T, 3)
        qkv = qkv.reshape((B, width, heads * T * 3))

        h = QKVAttention(heads)(qkv)  # (B, heads * T, width), head-major
        # Merge heads back onto the channel axis: (B, T, heads * width).
        h = h.reshape((B, heads, T, width)).transpose((0, 2, 1, 3))
        h = h.reshape((B, T, heads * width))

        # Zero-initialised projection (the "zero module" of the reference
        # implementation) so the residual branch starts as a no-op.
        h = nn.Conv(features=width,
                    kernel_size=(1,),
                    kernel_init=jax.nn.initializers.zeros)(h)

        # Residual connection and restore the original spatial shape.
        return (x + h).reshape((B, *spatial, C))
class QKVAttention(nn.Module):
    """
    Multi-head scaled dot-product attention over a packed query/key/value tensor.

    num_heads: the number of heads packed into the input tensor.
    """
    num_heads: int = 1

    @nn.compact
    def __call__(self, qkv):
        """
        Args:
            qkv: packed tensor of shape (B, D, H * T * 3). Axis 1 is the
                per-head feature width D; the last axis unpacks (row-major)
                into (H, T, 3): head index, element index, and the
                query/key/value selector (innermost).
        Returns:
            a tensor of shape (B, H * T, D): each head's (T, D) attention
            output stacked head-major along axis 1 ((B, T, D) when H == 1).
        """
        batch, width, packed = qkv.shape
        heads = self.num_heads
        length = packed // (heads * 3)

        # (B, D, H*T*3) -> (B, D, H, T, 3) -> (B, 3, H, T, D)
        unpacked = jnp.reshape(qkv, (batch, width, heads, length, 3))
        unpacked = jnp.transpose(unpacked, (0, 4, 2, 3, 1))

        query = unpacked[:, 0]  # (B, H, T, D)
        key = unpacked[:, 1]
        value = unpacked[:, 2]

        # Map the per-example computation over the batch axis.
        return jax.vmap(self._forward)(query, key, value)

    @staticmethod
    def _forward(q, k, v) -> jnp.ndarray:
        """
        Attend independently within every head of one batch element.

        Args:
            q: query tensor (H, T, D)
            k: key tensor (H, T, D)
            v: value tensor (H, T, D)
        Returns: a tensor of shape (H * T, D): the per-head (T, D) outputs
            concatenated along the first axis.

        The batch dimension is not handled here; callers vmap over it.
        (unit-tested)
        """
        per_head = jax.vmap(QKVAttention.scaled_dot_product_attention)(q, k, v)  # (H, T, D)
        return jnp.concatenate(per_head, axis=0)

    @staticmethod
    def scaled_dot_product_attention(q, k, v) -> jnp.ndarray:
        """
        Single-head attention: softmax(q @ k.T / sqrt(D)) @ v.

        Args:
            q: query tensor (T, D)
            k: key tensor (S, D)
            v: value tensor (S, Dv)
        Returns: a tensor of shape (T, Dv).

        D is taken from the second axis of q. Vmapped over heads by _forward.
        (unit-tested)
        """
        _, depth = q.shape
        logits = jnp.matmul(q, k.T) / jnp.sqrt(depth)
        weights = jax.nn.softmax(logits, axis=-1)  # (T, S), rows sum to 1
        return jnp.matmul(weights, v)