parametric.py
import tensorflow as tf
from tensorflow.contrib import slim

from mayo.util import memoize_method
from mayo.net.tf.gate.base import GateParameterValueError
from mayo.net.tf.gate.sparse import SparseRegularizedGatedConvolutionBase


class ParametricGatedConvolution(SparseRegularizedGatedConvolutionBase):
    """Parametric batch normalization with gating."""

    def _update_defaults(self, defaults):
        super()._update_defaults(defaults)
        # FIXME hacky normalizer customization
        defaults['norm'] = 'batch'
        defaults['parametric_beta'] = False
        defaults['add_beta_first'] = False
    def normalize(self, tensor):
        if self.normalizer_fn is not slim.batch_norm:
            raise GateParameterValueError(
                'Policy "{}" expects slim.batch_norm to be used, '
                'but it is absent in {}.'.format(self.policy, self.node))
        if not self.normalizer_params.get('scale', False):
            raise GateParameterValueError(
                'Policy "parametric_gamma" expects `scale` to be used '
                'in slim.batch_norm.')
        if self.norm == 'batch':
            # batch-normalize without scale and shift; gamma and beta are
            # instead supplied by the gating network
            normalizer_params = dict(self.normalizer_params, **{
                'scale': False,
                'center': False,
                'activation_fn': None,
                'scope': '{}/BatchNorm'.format(self.scope),
                'is_training': self.is_training,
            })
            return self.constructor.instantiate_batch_normalization(
                None, tensor, normalizer_params)
        if self.norm == 'channel':
            # standardize each channel over its spatial axes
            norm_mean, norm_var = tf.nn.moments(
                tensor, axes=[1, 2], keep_dims=True)
            return (tensor - norm_mean) / tf.sqrt(norm_var)
        raise GateParameterValueError('Unrecognized normalization policy.')
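
    # For reference (an illustrative sketch, not from the original source):
    # the 'channel' branch above is roughly the NumPy expression
    #     (x - x.mean(axis=(1, 2), keepdims=True)) /
    #         np.sqrt(x.var(axis=(1, 2), keepdims=True))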

    @memoize_method
    def beta(self):
        tensor = self._predictor('gate/beta')
        self._register('beta', tensor)
        return tensor

    def _add_beta(self, tensor):
        if not self.normalizer_params.get('center', True):
            # no shift requested; return the input unchanged
            return tensor
        if self.parametric_beta:
            beta = self.beta()
        else:
            # constant beta
            beta_scope = '{}/gate/shift'.format(self.scope)
            beta = tf.get_variable(
                beta_scope, shape=tensor.shape[-1], dtype=tf.float32,
                initializer=tf.constant_initializer(0.1),
                trainable=self.trainable)
        beta = self.actives() * beta if self.enable else beta
        return tensor + beta

    def activate(self, tensor):
        # gating happens before activation:
        #     output = relu(
        #         actives(gamma(x)) * gamma(x) * norm(conv(x)) +
        #         actives(gamma(x)) * beta)
        gamma = self.gate()
        if self.add_beta_first:
            tensor = self._add_beta(tensor)
        tensor *= self.actives() * gamma if self.enable else gamma
        if not self.add_beta_first:
            tensor = self._add_beta(tensor)
        return super().activate(tensor)
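

if __name__ == '__main__':
    # A minimal standalone sketch (an illustration, not part of mayo's API)
    # of the gating arithmetic implemented in `activate` above:
    #     out = relu(actives * gamma * norm + actives * beta)
    # The shapes and the thresholding rule for `actives` below are
    # assumptions chosen for this demo.
    norm = tf.random_normal([1, 4, 4, 8])       # stand-in for norm(conv(x))
    gamma = tf.random_uniform([1, 1, 1, 8])     # per-channel predicted scales
    beta = tf.fill([8], 0.1)                    # constant shift, cf. _add_beta
    actives = tf.cast(gamma > 0.5, tf.float32)  # hypothetical gating mask
    out = tf.nn.relu(actives * gamma * norm + actives * beta)
    with tf.Session() as sess:
        print(sess.run(out).shape)  # (1, 4, 4, 8)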