-
Notifications
You must be signed in to change notification settings - Fork 0
/
modules.py
22 lines (18 loc) · 927 Bytes
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from neural_tangents import stax
def ResNetBlock(out_chan, W_std, b_std, strides=(1,1), channel_mismatch=False):
    """Build a two-convolution residual block out of ``stax`` layers.

    The input is duplicated with ``stax.FanOut(2)``; one copy goes through the
    main branch (ReLU -> 3x3 conv -> ReLU -> 3x3 conv), the other through the
    shortcut branch, and the two results are added with ``stax.FanInSum()``.

    Args:
        out_chan: Number of output channels for every convolution in the block.
        W_std: Forwarded as ``W_std`` to every ``stax.Conv``.
        b_std: Forwarded as ``b_std`` to every ``stax.Conv``.
        strides: Strides of the first main-branch convolution and, when
            ``channel_mismatch`` is true, of the shortcut convolution. The
            second main-branch convolution always uses ``(1, 1)``.
        channel_mismatch: If true, the shortcut is a 3x3 ``stax.Conv`` with
            the same ``strides``; otherwise it is ``stax.Identity()``.

    Returns:
        The layer returned by ``stax.serial`` for the combined block.
    """
    def conv3x3(conv_strides):
        # Every convolution in the block shares kernel size, padding and
        # weight/bias scales; only the strides differ.
        return stax.Conv(
            out_chan,
            (3, 3),
            strides=conv_strides,
            padding='SAME',
            W_std=W_std,
            b_std=b_std,
        )

    if channel_mismatch:
        skip_path = conv3x3(strides)
    else:
        skip_path = stax.Identity()

    residual_path = stax.serial(
        stax.Relu(),
        conv3x3(strides),
        stax.Relu(),
        conv3x3((1, 1)),
    )

    return stax.serial(
        stax.FanOut(2),
        stax.parallel(residual_path, skip_path),
        stax.FanInSum(),
    )
def ResNetGroup(n, out_chan, W_std, b_std, strides=(1,1), channel_mismatch=None):
    """Chain residual blocks into a single ``stax.serial`` group.

    The first block uses the caller-supplied ``strides``; each of the
    remaining ``n - 1`` blocks uses strides ``(1, 1)`` and an identity
    shortcut. One leading block is always created, so ``n <= 1`` yields a
    group with a single block.

    Args:
        n: Total number of blocks in the group.
        out_chan: Number of output channels passed to every block.
        W_std: Forwarded as ``W_std`` to every block.
        b_std: Forwarded as ``b_std`` to every block.
        strides: Strides for the first block only.
        channel_mismatch: Whether the first block uses a convolutional
            shortcut instead of an identity. ``None`` (the default) enables
            it only when some stride differs from 1; pass ``True`` explicitly
            when the group changes the channel count at unit stride.

    Returns:
        The layer returned by ``stax.serial`` over all blocks.
    """
    if channel_mismatch is None:
        # With a non-unit stride the main branch of the first block is
        # downsampled; an identity shortcut would keep the input's spatial
        # size and the two branches could not be summed. Unit strides keep
        # the identity shortcut, matching the previous behaviour.
        channel_mismatch = strides is not None and any(s != 1 for s in strides)

    blocks = [
        ResNetBlock(
            out_chan,
            W_std=W_std,
            b_std=b_std,
            strides=strides,
            channel_mismatch=channel_mismatch,
        )
    ]
    blocks += [
        ResNetBlock(out_chan, W_std=W_std, b_std=b_std, strides=(1, 1))
        for _ in range(n - 1)
    ]
    return stax.serial(*blocks)