-
Notifications
You must be signed in to change notification settings - Fork 25
/
objects_room.py
90 lines (77 loc) · 3.49 KB
/
objects_room.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Objects Room dataset reader."""
import functools
import tensorflow.compat.v1 as tf
# Compression identifier passed to `tf.data.TFRecordDataset` when reading.
COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP')

# Spatial size shared by the 'image' and 'mask' features.
IMAGE_SIZE = [64, 64]

# Per-variant upper bound on the number of entities (foreground objects plus
# background elements) in a scene. Each scene comes with exactly this many
# segmentation masks, so the value also fixes the leading dimension of the
# 'mask' feature.
MAX_NUM_ENTITIES = {
    'train': 7,
    'six_objects': 10,
    'empty_room': 4,
    'identical_color': 10,
}

# Features that are stored as raw byte strings and decoded to uint8 after
# parsing.
BYTE_FEATURES = ['mask', 'image']
def feature_descriptions(max_num_entities):
    """Build the parsing spec for one serialized scene.

    Args:
      max_num_entities: int. Upper bound on the number of foreground and
        background entities per image; it becomes the leading dimension of
        the 'mask' feature, i.e. the number of segmentation masks per scene.

    Returns:
      A dict mapping the feature names 'image' and 'mask' to
      `tf.FixedLenFeature` descriptors of dtype `tf.string`, suitable for
      parsing a `tf.Example`.
    """
    # Image: spatial size followed by 3 channels.
    image_shape = IMAGE_SIZE + [3]
    # Masks: one single-channel map per entity slot.
    mask_shape = [max_num_entities] + IMAGE_SIZE + [1]
    return {
        'image': tf.FixedLenFeature(image_shape, tf.string),
        'mask': tf.FixedLenFeature(mask_shape, tf.string),
    }
def _decode(example_proto, features):
    """Parse one serialized `tf.Example` and decode its byte features.

    Args:
      example_proto: a serialized `tf.Example`, as yielded by the TFRecords
        reader.
      features: dict of feature descriptors, as returned by
        `feature_descriptions`.

    Returns:
      The parsed feature dict, with every entry named in `BYTE_FEATURES`
      replaced by its `tf.uint8` decoding.
    """
    parsed = tf.parse_single_example(example_proto, features)
    for name in BYTE_FEATURES:
        # `decode_raw` appends a trailing axis holding the decoded values.
        # Squeezing it presumes each string element encodes a single uint8.
        as_uint8 = tf.decode_raw(parsed[name], tf.uint8)
        parsed[name] = tf.squeeze(as_uint8, axis=-1)
    return parsed
def dataset(tfrecords_path, dataset_variant, read_buffer_size=None,
            map_parallel_calls=None):
    """Build a `tf.data` pipeline that reads and decodes a TFRecords file.

    Args:
      tfrecords_path: str. Location of the GZIP-compressed TFRecords file.
      dataset_variant: str. A key of `MAX_NUM_ENTITIES` ('train',
        'six_objects', 'empty_room' or 'identical_color'). It selects the
        number of masks expected per scene, so a valid key that does not
        match the file will cause the records to be parsed incorrectly.
      read_buffer_size: int. Size of the read buffer in bytes; forwarded as
        `buffer_size` to `tf.data.TFRecordDataset`.
      map_parallel_calls: int. Number of elements to decode in parallel;
        forwarded as `num_parallel_calls` to `tf.data.Dataset.map`.

    Returns:
      An unbatched dataset whose elements are dicts with decoded 'image' and
      'mask' entries.

    Raises:
      ValueError: if `dataset_variant` is not a key of `MAX_NUM_ENTITIES`.
    """
    if dataset_variant not in MAX_NUM_ENTITIES:
        supported = list(MAX_NUM_ENTITIES.keys())
        raise ValueError(
            'Invalid `dataset_variant` provided. The supported values'
            ' are: {}'.format(supported))

    # The variant determines how many masks each record is parsed into.
    decode_fn = functools.partial(
        _decode,
        features=feature_descriptions(MAX_NUM_ENTITIES[dataset_variant]))

    records = tf.data.TFRecordDataset(
        tfrecords_path,
        compression_type=COMPRESSION_TYPE,
        buffer_size=read_buffer_size)
    return records.map(decode_fn, num_parallel_calls=map_parallel_calls)