forked from google-research/simclr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
165 lines (141 loc) · 6.17 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# coding=utf-8
# Copyright 2020 The SimCLR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data pipeline."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl import flags
import data_util as data_util
import tensorflow.compat.v1 as tf
# Module-level handle on the absl flag registry. The functions below read
# flags such as train_mode, image_size, train_split, eval_split and
# cache_dataset only when they are called, not at import time.
FLAGS = flags.FLAGS
def pad_to_batch(dataset, batch_size):
  """Pad Tensors to specified batch size.

  Every element yielded by `dataset` is zero-padded (via `tf.pad`) along its
  leading dimension so that each tensor has exactly `batch_size` rows, and the
  static leading dimension is set to `batch_size`. All tensors within a single
  element must share the same leading dimension.

  NOTE(review): an element whose leading dimension already exceeds
  `batch_size` yields a negative padding amount, which `tf.pad` presumably
  rejects at runtime — callers are expected to batch with at most
  `batch_size` rows first.

  Args:
    dataset: An instance of tf.data.Dataset.
    batch_size: The number of samples per batch of input requested.

  Returns:
    An instance of tf.data.Dataset that yields the same Tensors with the same
    structure as the original padded to batch_size along the leading
    dimension.

  Raises:
    ValueError: If the dataset does not comprise any tensors; if a tensor
      yielded by the dataset has an unknown number of dimensions or is a
      scalar; or if it can be statically determined that tensors comprising
      a single dataset element will have different leading dimensions.
  """

  def _pad_to_batch(*args):
    """Given Tensors yielded by a Dataset, pads all to the batch size."""
    flat_args = tf.nest.flatten(args)

    # Honour the documented contract explicitly instead of falling through to
    # an IndexError on flat_args[0] below. (tf.data.Dataset.map may already
    # fail earlier with an internal error for tensor-less elements.)
    if not flat_args:
      raise ValueError('Dataset elements do not comprise any tensors.')

    for tensor in flat_args:
      if tensor.shape.ndims is None:
        raise ValueError(
            'Unknown number of dimensions for tensor %s.' % tensor.name)
      if tensor.shape.ndims == 0:
        raise ValueError('Tensor %s is a scalar.' % tensor.name)

    # The first tensor's (dynamic) leading dimension is the reference that
    # every other tensor is checked against and padded relative to.
    first_tensor = flat_args[0]
    first_tensor_shape = tf.shape(first_tensor)
    first_tensor_batch_size = first_tensor_shape[0]
    difference = batch_size - first_tensor_batch_size

    for i, tensor in enumerate(flat_args):
      control_deps = []
      if i != 0:
        # Check that leading dimensions of this tensor matches the first,
        # either statically or dynamically. (If the first dimensions of both
        # tensors are statically known, then we have to check the static
        # shapes at graph construction time or else we will never get to the
        # dynamic assertion.)
        if (first_tensor.shape[:1].is_fully_defined() and
            tensor.shape[:1].is_fully_defined()):
          if first_tensor.shape[0] != tensor.shape[0]:
            raise ValueError(
                'Batch size of dataset tensors does not match. %s '
                'has shape %s, but %s has shape %s' % (
                    first_tensor.name, first_tensor.shape,
                    tensor.name, tensor.shape))
        else:
          curr_shape = tf.shape(tensor)
          control_deps = [tf.Assert(
              tf.equal(curr_shape[0], first_tensor_batch_size),
              ['Batch size of dataset tensors %s and %s do not match. '
               'Shapes are' % (tensor.name, first_tensor.name), curr_shape,
               first_tensor_shape])]

      with tf.control_dependencies(control_deps):
        # Pad to batch_size along leading dimension.
        flat_args[i] = tf.pad(
            tensor, [[0, difference]] + [[0, 0]] * (tensor.shape.ndims - 1))
      # tf.pad with a dynamic amount loses the static leading dimension;
      # restore it so downstream consumers see a fixed batch size.
      flat_args[i].set_shape([batch_size] + tensor.shape.as_list()[1:])

    return tf.nest.pack_sequence_as(args, flat_args)

  return dataset.map(_pad_to_batch)
def build_input_fn(builder, is_training):
  """Build input function.

  Args:
    builder: TFDS builder for specified dataset.
    is_training: Whether to build in training mode.

  Returns:
    A function `_input_fn(params)` that reads `params['batch_size']` and
    returns `(images, {'labels': labels, 'mask': mask})` for one batch, to be
    used as the input_fn in TPUEstimator.
  """

  def _input_fn(params):
    """Builds the dataset pipeline and returns the next (images, features)."""
    pretrain_preprocess = get_preprocess_fn(is_training, is_pretrain=True)
    finetune_preprocess = get_preprocess_fn(is_training, is_pretrain=False)
    num_classes = builder.info.features['label'].num_classes

    def _map_fn(image, label):
      """Produces multiple transformations of the same batch."""
      if FLAGS.train_mode != 'pretrain':
        # Fine-tuning / evaluation: a single view and a one-hot label.
        return (finetune_preprocess(image),
                tf.one_hot(label, num_classes),
                1.0)
      # Pretraining: two augmented views of the same image, concatenated along
      # the last axis; the real label is replaced by an all-zeros vector.
      views = [pretrain_preprocess(image) for _ in range(2)]
      return tf.concat(views, -1), tf.zeros([num_classes]), 1.0

    split = FLAGS.train_split if is_training else FLAGS.eval_split
    dataset = builder.as_dataset(
        split=split, shuffle_files=is_training, as_supervised=True)
    if FLAGS.cache_dataset:
      dataset = dataset.cache()

    if is_training:
      # Small images (image_size <= 32) get a larger shuffle buffer.
      multiplier = 50 if FLAGS.image_size <= 32 else 10
      dataset = dataset.shuffle(params['batch_size'] * multiplier).repeat(-1)

    dataset = dataset.map(
        _map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Evaluation keeps the final partial batch; pad_to_batch then fills it up
    # to a full batch.
    dataset = dataset.batch(params['batch_size'], drop_remainder=is_training)
    dataset = pad_to_batch(dataset, params['batch_size'])

    iterator = tf.data.make_one_shot_iterator(dataset)
    images, labels, mask = iterator.get_next()
    features = {'labels': labels, 'mask': mask}
    return images, features

  return _input_fn
def get_preprocess_fn(is_training, is_pretrain):
  """Get function that accepts an image and returns a preprocessed image.

  Binds `data_util.preprocess_image` to a square output of side
  `FLAGS.image_size` and to the requested mode.

  Args:
    is_training: Forwarded unchanged as `is_training`.
    is_pretrain: Forwarded as `color_distort`, so colour distortion is only
      requested for pretraining.

  Returns:
    A `functools.partial` of `data_util.preprocess_image` with every keyword
    argument bound.
  """
  side = FLAGS.image_size
  # Test-time cropping is disabled for small images (e.g. CIFAR).
  is_small_image = side <= 32
  return functools.partial(
      data_util.preprocess_image,
      height=side,
      width=side,
      is_training=is_training,
      color_distort=is_pretrain,
      test_crop=not is_small_image)