tpu_normalization.py
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Normamlization methods that implements cross replica nomalization for TPU."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.keras import layers as keras_layers
from tensorflow.python.ops import math_ops


def cross_replica_average(t, num_groups=1):
  """Calculates the average value of the input tensor across TPU replicas."""
  num_shards = tpu_function.get_tpu_context().number_of_shards
  num_shards_per_group = 1
  group_assignment = None
  if num_groups > 0:
    if num_shards % num_groups != 0:
      raise ValueError('num_shards (%d) must be divisible by num_groups (%d)' %
                       (num_shards, num_groups))
    num_shards_per_group = num_shards // num_groups
    # Assign consecutive replica ids to the same group, e.g. for 8 shards and
    # 2 groups: [[0, 1, 2, 3], [4, 5, 6, 7]].
    group_assignment = [[
        x for x in range(num_shards) if x // num_shards_per_group == y
    ] for y in range(num_groups)]
  # cross_replica_sum adds `t` over the replicas in each group; dividing by
  # the group size turns that sum into a per-group mean.
  return tpu_ops.cross_replica_sum(t, group_assignment) / math_ops.cast(
      num_shards_per_group, t.dtype)
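

# A minimal, hedged sketch of the grouping rule used above, runnable without
# a TPU. `_example_group_assignment` is illustrative only and is not part of
# this module's API.
def _example_group_assignment(num_shards, num_groups):
  """Returns consecutive replica-id groups, e.g. [[0, 1], [2, 3]] for (4, 2)."""
  num_shards_per_group = num_shards // num_groups
  return [[x for x in range(num_shards) if x // num_shards_per_group == y]
          for y in range(num_groups)]
# _example_group_assignment(8, 2) == [[0, 1, 2, 3], [4, 5, 6, 7]], so each
# replica averages its statistics with the three other replicas in its group;
# with the default num_groups=1 every replica joins one global average.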


class BatchNormalization(keras_layers.BatchNormalization, tf.layers.Layer):
  """Batch Normalization layer that supports cross replica computation on TPU.

  This class extends the keras.BatchNormalization implementation by supporting
  cross replica means and variances. The base class implementation only
  computes moments over the per-replica (per TPU core) mini-batch.

  For detailed information on arguments and implementation, refer to:
  https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

  Arguments:
    fused: if `None` or `True`, use a faster, fused implementation if possible.
      If `False`, use the system recommended implementation.
    cross_replica_average_fn: A function that takes a tensor and returns its
      mean across all replicas. Currently only the TPU version supports this
      feature. If specified, `fused` must be `False`.
  """

  def __init__(self, fused=None, cross_replica_average_fn=None, **kwargs):
    # Forward `fused` so the base class honors it rather than using its own
    # default.
    super(BatchNormalization, self).__init__(fused=fused, **kwargs)
    self.cross_replica_average_fn = cross_replica_average_fn
    if fused and cross_replica_average_fn is not None:
      raise ValueError('fused must be `False` when specifying'
                       ' cross_replica_average_fn')

  def _moments(self, inputs, reduction_axes, keep_dims):
    # Compute the per-replica moments first, then average them across
    # replicas so every core normalizes with the same statistics.
    mean, variance = super(BatchNormalization, self)._moments(
        inputs, reduction_axes, keep_dims=keep_dims)
    if self.cross_replica_average_fn:
      mean = self.cross_replica_average_fn(mean)
      variance = self.cross_replica_average_fn(variance)
    return (mean, variance)
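

# A hedged usage sketch (variable names are illustrative, not part of the
# original file): build the layer directly with a two-group cross-replica
# average. This mirrors what the functional wrapper below does via
# functools.partial.
#
#   bn = BatchNormalization(
#       fused=False,
#       cross_replica_average_fn=functools.partial(
#           cross_replica_average, num_groups=2))
#   outputs = bn(inputs, training=True)  # inside a TPU replica context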


def cross_replica_batch_normalization(inputs,
                                      training=False,
                                      num_distributed_groups=1,
                                      **kwargs):
  """Functional interface for the cross replica batch normalization layer.

  For detailed information on arguments and implementation, refer to:
  https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

  Arguments:
    inputs: Tensor input.
    training: Either a Python boolean, or a TensorFlow boolean scalar tensor
      (e.g. a placeholder). Whether to return the output in training mode
      (normalized with statistics of the current batch) or in inference mode
      (normalized with moving statistics). **NOTE**: make sure to set this
      parameter correctly, or else your training/inference will not work
      properly.
    num_distributed_groups: Number of groups into which the replicas are
      evenly split for distributed batch normalization. Use 1 for global
      batch norm and -1 or None for per-replica batch norm.
    **kwargs: Arguments passed through to BatchNormalization.

  Returns:
    Output tensor.

  Raises:
    ValueError: if eager execution is enabled.
  """
  layer = BatchNormalization(
      cross_replica_average_fn=functools.partial(
          cross_replica_average, num_groups=num_distributed_groups),
      **kwargs)
  return layer.apply(inputs, training=training)
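

# A hedged end-to-end sketch (layer names and shapes are illustrative): inside
# a TPU model_fn, normalize convolution activations with statistics averaged
# across all replicas (num_distributed_groups=1, i.e. batch norm over the full
# distributed batch).
#
#   net = tf.layers.conv2d(features, filters=64, kernel_size=3, use_bias=False)
#   net = cross_replica_batch_normalization(
#       net, training=is_training, num_distributed_groups=1, fused=False)
#   net = tf.nn.relu(net)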