-
Notifications
You must be signed in to change notification settings - Fork 69
/
norm.py
230 lines (200 loc) · 6.92 KB
/
norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
"""Neuron normalization methods."""
import torch
class MeanOnlyBatchNorm(torch.nn.Module):
    """Implements mean only batch norm with optional user defined quantization
    using pre-hook-function. The mean of batchnorm translates to negative bias
    of the neuron.

    Parameters
    ----------
    num_features : int
        number of features. It is automatically initialized on first run if
        the value is None. Default is None.
    momentum : float
        momentum of mean calculation. Defaults to 0.1.
    pre_hook_fx : function pointer or lambda
        pre-hook-function that is applied to the normalization output.
        User can provide a quantization method as needed. The ``bias``
        property calls it as ``pre_hook_fx(x, descale=True)``, so a custom
        hook must accept a ``descale`` keyword for ``bias`` to work.
        Defaults to None (identity).

    Attributes
    ----------
    num_features
    momentum
    pre_hook_fx
    running_mean : torch tensor
        running mean estimate.
    update : bool
        enable mean estimate update.
    """

    def __init__(self, num_features=None, momentum=0.1, pre_hook_fx=None):
        """ """
        super(MeanOnlyBatchNorm, self).__init__()
        self.num_features = num_features
        self.momentum = momentum
        if pre_hook_fx is None:
            # Identity hook. It must accept (and ignore) ``descale`` because
            # the ``bias`` property calls the hook with ``descale=True``.
            self.pre_hook_fx = lambda x, descale=False: x
        else:
            self.pre_hook_fx = pre_hook_fx
        self.register_buffer(
            'running_mean',
            torch.zeros(1 if num_features is None else num_features)
        )
        self.reset_parameters()
        self.update = True

    def reset_parameters(self):
        """Reset states."""
        self.running_mean.zero_()

    @property
    def bias(self):
        """Equivalent bias shift."""
        return -self.pre_hook_fx(self.running_mean, descale=True)

    def forward(self, inp):
        """Subtract the per-feature mean from ``inp``.

        Parameters
        ----------
        inp : torch tensor
            input with the batch on dim 0 and features on dim 1. Any number
            of trailing dimensions is allowed.

        Returns
        -------
        torch tensor
            ``inp`` minus the pre-hooked mean, broadcast over every
            non-feature dimension. In training mode with ``update`` enabled,
            the current batch mean is used and ``running_mean`` is updated.
            Otherwise ``running_mean`` is used.

        Raises
        ------
        ValueError
            if ``inp`` has fewer than 2 dimensions.
        """
        ndim = inp.dim()
        if ndim < 2:
            raise ValueError(
                f'Found unexpected number of dims {ndim} in input.'
            )
        if self.num_features is None:
            self.num_features = inp.shape[1]
        if self.training and self.update:
            # reshape (not view) so that non-contiguous inputs also work
            mean = torch.mean(
                inp.reshape(inp.shape[0], self.num_features, -1),
                dim=[0, 2]
            )
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean \
                    + self.momentum * mean
        else:
            mean = self.running_mean
        # (1, C, 1, ..., 1): broadcast the per-feature mean over the batch
        # and all trailing dimensions, whatever the input rank.
        shape = (1, -1) + (1,) * (ndim - 2)
        return inp - self.pre_hook_fx(mean.reshape(shape))
class WgtScaleBatchNorm(torch.nn.Module):
    """Implements batch norm with variance scale in powers of 2. This allows
    eventual normalization to be implemented with bit-shift in a hardware
    friendly manner. Optional user defined quantization can be enabled using a
    pre-hook-function. The mean of batchnorm translates to negative bias of the
    neuron.

    Parameters
    ----------
    num_features : int
        number of features. It is automatically initialized on first run if
        the value is None. Default is None.
    momentum : float
        momentum of mean calculation. Defaults to 0.1.
    weight_exp_bits : int
        number of allowable bits for weight exponentiation. Defaults to 3.
    eps : float
        infinitesimal value. Defaults to 1e-5.
    pre_hook_fx : function pointer or lambda
        pre-hook-function that is applied to the normalization output.
        User can provide a quantization method as needed. The ``bias``
        property calls it as ``pre_hook_fx(x, descale=True)``, so a custom
        hook must accept a ``descale`` keyword for ``bias`` to work.
        Defaults to None (identity).

    Attributes
    ----------
    num_features
    momentum
    weight_exp_bits
    eps
    pre_hook_fx
    running_mean : torch tensor
        running mean estimate (one value per feature).
    running_var : torch tensor
        running variance estimate (a single value shared by all features).
    update : bool
        enable mean estimate update.
    """

    def __init__(
        self,
        num_features=None, momentum=0.1,
        weight_exp_bits=3, eps=1e-5,
        pre_hook_fx=None
    ):
        """ """
        super(WgtScaleBatchNorm, self).__init__()
        self.num_features = num_features
        self.momentum = momentum
        self.weight_exp_bits = weight_exp_bits
        self.eps = eps
        if pre_hook_fx is None:
            # Identity hook. It must accept (and ignore) ``descale`` because
            # the ``bias`` property calls the hook with ``descale=True``.
            self.pre_hook_fx = lambda x, descale=False: x
        else:
            self.pre_hook_fx = pre_hook_fx
        self.register_buffer(
            'running_mean',
            torch.zeros(1 if num_features is None else num_features)
        )
        self.register_buffer(
            'running_var',
            torch.zeros(1)
        )
        self.reset_parameters()
        self.update = True

    def reset_parameters(self):
        """Reset states."""
        self.running_mean.zero_()
        self.running_var.zero_()

    def std(self, var):
        """Power-of-two standard deviation for variance ``var``.

        Returns ``2 ** clamp(ceil(log2(sqrt(var + eps))), -weight_exp_bits,
        weight_exp_bits)``, so dividing by it amounts to a bit shift.
        """
        std = torch.sqrt(var + self.eps)
        return torch.pow(2., torch.ceil(torch.log2(std)).clamp(
            -self.weight_exp_bits, self.weight_exp_bits
        ))

    @property
    def bias(self):
        """Equivalent bias shift."""
        return -self.pre_hook_fx(self.running_mean, descale=True)

    @property
    def weight_exp(self):
        """Equivalent weight exponent value."""
        # NOTE(review): unlike ``std``, this exponent is not clamped to
        # [-weight_exp_bits, weight_exp_bits]; confirm callers expect that.
        return torch.ceil(torch.log2(torch.sqrt(self.running_var + self.eps)))

    def forward(self, inp):
        """Normalize ``inp`` by its per-feature mean and a power-of-two std.

        Parameters
        ----------
        inp : torch tensor
            input with the batch on dim 0 and features on dim 1. Any number
            of trailing dimensions is allowed.

        Returns
        -------
        torch tensor
            ``(inp - pre_hook_fx(mean)) / std(var)``, broadcast over every
            non-feature dimension. In training mode with ``update`` enabled,
            batch statistics are used and the running estimates are updated.
            Otherwise the running estimates are used.

        Raises
        ------
        ValueError
            if ``inp`` has fewer than 2 dimensions.
        """
        ndim = inp.dim()
        if ndim < 2:
            raise ValueError(f'Found unexpected number of dims {ndim}.')
        if self.num_features is None:
            self.num_features = inp.shape[1]
        if self.training and self.update:
            # reshape (not view) so that non-contiguous inputs also work
            mean = torch.mean(
                inp.reshape(inp.shape[0], self.num_features, -1),
                dim=[0, 2]
            )
            # One variance over the whole input (not per feature), so a
            # single power-of-two scale is shared by all features.
            var = torch.var(inp, unbiased=False)
            # number of elements torch.var reduced over
            n = inp.numel()
            with torch.no_grad():
                # Bessel's correction n / (n - 1) turns the biased batch
                # variance into an unbiased running estimate. It is skipped
                # for a single element to avoid dividing by zero.
                unbiased_var = var * n / (n - 1) if n > 1 else var
                self.running_mean = (1 - self.momentum) * self.running_mean \
                    + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var \
                    + self.momentum * unbiased_var
        else:
            mean = self.running_mean
            var = self.running_var
        std = self.std(var)
        # (1, C, 1, ..., 1): broadcast shape for any input rank >= 2.
        shape = (1, -1) + (1,) * (ndim - 2)
        return (
            inp - self.pre_hook_fx(mean.reshape(shape))
        ) / std.reshape(shape)